diff --git a/crates/backend/src/ast.rs b/crates/backend/src/ast.rs index 570f0451..6e1be828 100644 --- a/crates/backend/src/ast.rs +++ b/crates/backend/src/ast.rs @@ -15,6 +15,7 @@ pub struct NapiFn { pub vis: syn::Visibility, pub parent: Option, pub strict: bool, + pub return_if_invalid: bool, pub js_mod: Option, pub ts_generic_types: Option, pub ts_args_type: Option, diff --git a/crates/backend/src/codegen/fn.rs b/crates/backend/src/codegen/fn.rs index d4d6a5db..e66efee3 100644 --- a/crates/backend/src/codegen/fn.rs +++ b/crates/backend/src/codegen/fn.rs @@ -188,7 +188,17 @@ impl NapiFn { } } _ => { - let type_check = if self.strict { + let type_check = if self.return_if_invalid { + quote! { + if let Ok(maybe_promise) = <#ty as napi::bindgen_prelude::ValidateNapiValue>::validate(env, cb.get_arg(#index)) { + if !maybe_promise.is_null() { + return Ok(maybe_promise); + } + } else { + return Ok(std::ptr::null_mut()); + } + } + } else if self.strict { quote! { let maybe_promise = <#ty as napi::bindgen_prelude::ValidateNapiValue>::validate(env, cb.get_arg(#index))?; if !maybe_promise.is_null() { diff --git a/crates/macro/src/parser/attrs.rs b/crates/macro/src/parser/attrs.rs index c317a87a..54598ba1 100644 --- a/crates/macro/src/parser/attrs.rs +++ b/crates/macro/src/parser/attrs.rs @@ -50,6 +50,7 @@ macro_rules! attrgen { (readonly, Readonly(Span)), (skip, Skip(Span)), (strict, Strict(Span)), + (return_if_invalid, ReturnIfInvalid(Span)), (object, Object(Span)), (namespace, Namespace(Span, String, Span)), (iterator, Iterator(Span)), diff --git a/crates/macro/src/parser/mod.rs b/crates/macro/src/parser/mod.rs index 34fa64fb..58f2ccd3 100644 --- a/crates/macro/src/parser/mod.rs +++ b/crates/macro/src/parser/mod.rs @@ -670,6 +670,7 @@ fn napi_fn_from_decl( comments: extract_doc_comments(&attrs), attrs, strict: opts.strict().is_some(), + return_if_invalid: opts.return_if_invalid().is_some(), js_mod: opts.namespace().map(|(m, _)| m.to_owned()), ts_generic_types: opts.ts_generic_types().map(|(m, _)| m.to_owned()), ts_args_type: opts.ts_args_type().map(|(m, _)| m.to_owned()), diff --git a/examples/napi/__test__/strict.spec.ts b/examples/napi/__test__/strict.spec.ts index d469ae2a..62b98e86 100644 --- a/examples/napi/__test__/strict.spec.ts +++ b/examples/napi/__test__/strict.spec.ts @@ -17,6 +17,8 @@ import { validateSymbol, validateNull, validateUndefined, + returnUndefinedIfInvalid, + returnUndefinedIfInvalidPromise, } from '../index' test('should validate array', (t) => { @@ -166,3 +168,15 @@ test('should validate undefined', (t) => { message: 'Expect value to be Undefined, but received Number', }) }) + +test('should return undefined if arg is invalid', (t) => { + t.is(returnUndefinedIfInvalid(true), false) + // @ts-expect-error + t.is(returnUndefinedIfInvalid(1), undefined) +}) + +test('should return Promise.reject() if arg is not Promise', async (t) => { + t.is(await returnUndefinedIfInvalidPromise(Promise.resolve(true)), false) + // @ts-expect-error + await t.throwsAsync(() => returnUndefinedIfInvalidPromise(1)) +}) diff --git a/examples/napi/__test__/typegen.spec.ts.md b/examples/napi/__test__/typegen.spec.ts.md index 4271220a..5702f86f 100644 --- a/examples/napi/__test__/typegen.spec.ts.md +++ b/examples/napi/__test__/typegen.spec.ts.md @@ -110,6 +110,8 @@ Generated by [AVA](https://avajs.dev). export function validatePromise(p: Promise): Promise␊ export function validateString(s: string): string␊ export function validateSymbol(s: symbol): boolean␊ + export function returnUndefinedIfInvalid(input: boolean): boolean␊ + export function returnUndefinedIfInvalidPromise(input: Promise): Promise␊ export function tsRename(a: { foo: number }): string[]␊ export function overrideIndividualArgOnFunction(notOverridden: string, f: () => string, notOverridden2: number): string␊ export function overrideIndividualArgOnFunctionWithCbArg(callback: (town: string, name?: string | undefined | null) => string, notOverridden: number): object␊ diff --git a/examples/napi/__test__/typegen.spec.ts.snap b/examples/napi/__test__/typegen.spec.ts.snap index a5f80b44..3a445810 100644 Binary files a/examples/napi/__test__/typegen.spec.ts.snap and b/examples/napi/__test__/typegen.spec.ts.snap differ diff --git a/examples/napi/index.d.ts b/examples/napi/index.d.ts index 5d537f7f..8c4bf1b7 100644 --- a/examples/napi/index.d.ts +++ b/examples/napi/index.d.ts @@ -100,6 +100,8 @@ export function validateNumber(i: number): number export function validatePromise(p: Promise): Promise export function validateString(s: string): string export function validateSymbol(s: symbol): boolean +export function returnUndefinedIfInvalid(input: boolean): boolean +export function returnUndefinedIfInvalidPromise(input: Promise): Promise export function tsRename(a: { foo: number }): string[] export function overrideIndividualArgOnFunction(notOverridden: string, f: () => string, notOverridden2: number): string export function overrideIndividualArgOnFunctionWithCbArg(callback: (town: string, name?: string | undefined | null) => string, notOverridden: number): object diff --git a/examples/napi/src/fn_strict.rs b/examples/napi/src/fn_strict.rs index 2b2a9cf0..5ce758a4 100644 --- a/examples/napi/src/fn_strict.rs +++ b/examples/napi/src/fn_strict.rs @@ -87,3 +87,14 @@ fn validate_string(s: String) -> String { fn validate_symbol(_s: JsSymbol) -> bool { true } + +#[napi(return_if_invalid)] +fn return_undefined_if_invalid(input: bool) -> bool { + !input +} + +#[napi(return_if_invalid)] +async fn return_undefined_if_invalid_promise(input: Promise) -> Result { + let input_value = input.await?; + Ok(!input_value) +}