diff --git a/crates/napi/src/bindgen_runtime/js_values/promise.rs b/crates/napi/src/bindgen_runtime/js_values/promise.rs index f6168d1f..09fe1fd8 100644 --- a/crates/napi/src/bindgen_runtime/js_values/promise.rs +++ b/crates/napi/src/bindgen_runtime/js_values/promise.rs @@ -1,7 +1,11 @@ -use std::ffi::{c_void, CStr}; +use std::ffi::CStr; use std::future; use std::pin::Pin; use std::ptr; +use std::sync::{ + atomic::{AtomicBool, Ordering}, + Arc, +}; use std::task::{Context, Poll}; use tokio::sync::oneshot::{channel, Receiver, Sender}; @@ -12,6 +16,13 @@ use super::{FromNapiValue, TypeName, ValidateNapiValue}; pub struct Promise { value: Pin>>>, + aborted: Arc, +} + +impl Drop for Promise { + fn drop(&mut self) { + self.aborted.store(true, Ordering::SeqCst); + } } impl TypeName for Promise { @@ -90,7 +101,8 @@ impl FromNapiValue for Promise { let mut promise_after_then = ptr::null_mut(); let mut then_js_cb = ptr::null_mut(); let (tx, rx) = channel(); - let tx_ptr = Box::into_raw(Box::new(tx)); + let aborted = Arc::new(AtomicBool::new(false)); + let tx_ptr = Box::into_raw(Box::new((tx, aborted.clone()))); check_status!( unsafe { sys::napi_create_function( @@ -98,7 +110,7 @@ impl FromNapiValue for Promise { then_c_string.as_ptr(), 4, Some(then_callback::), - tx_ptr as *mut _, + tx_ptr.cast(), &mut then_js_cb, ) }, @@ -133,7 +145,7 @@ impl FromNapiValue for Promise { catch_c_string.as_ptr(), 5, Some(catch_callback::), - tx_ptr as *mut c_void, + tx_ptr.cast(), &mut catch_js_cb, ) }, @@ -154,6 +166,7 @@ impl FromNapiValue for Promise { )?; Ok(Promise { value: Box::pin(rx), + aborted, }) } } @@ -193,8 +206,12 @@ unsafe extern "C" fn then_callback( get_cb_status == sys::Status::napi_ok, "Get callback info from Promise::then failed" ); + let (sender, aborted) = + *unsafe { Box::from_raw(data as *mut (Sender<*mut Result>, Arc)) }; + if aborted.load(Ordering::SeqCst) { + return this; + } let resolve_value_t = Box::new(unsafe { T::from_napi_value(env, resolved_value[0]) }); - let sender = unsafe { Box::from_raw(data as *mut Sender<*mut Result>) }; sender .send(Box::into_raw(resolve_value_t)) .expect("Send Promise resolved value error"); @@ -224,7 +241,11 @@ unsafe extern "C" fn catch_callback( "Get callback info from Promise::catch failed" ); let rejected_value = rejected_value[0]; - let sender = unsafe { Box::from_raw(data as *mut Sender<*mut Result>) }; + let (sender, aborted) = + *unsafe { Box::from_raw(data as *mut (Sender<*mut Result>, Arc)) }; + if aborted.load(Ordering::SeqCst) { + return this; + } sender .send(Box::into_raw(Box::new(Err(Error::from(unsafe { JsUnknown::from_raw_unchecked(env, rejected_value) diff --git a/examples/napi/__test__/typegen.spec.ts.md b/examples/napi/__test__/typegen.spec.ts.md index 4bc5c775..849279e2 100644 --- a/examples/napi/__test__/typegen.spec.ts.md +++ b/examples/napi/__test__/typegen.spec.ts.md @@ -211,6 +211,8 @@ Generated by [AVA](https://avajs.dev). export function acceptThreadsafeFunction(func: (err: Error | null, value: number) => any): void␊ export function acceptThreadsafeFunctionFatal(func: (value: number) => any): void␊ export function acceptThreadsafeFunctionTupleArgs(func: (err: Error | null, arg0: number, arg1: boolean, arg2: string) => any): void␊ + export function tsfnReturnPromise(func: (err: Error | null, value: number) => any): Promise␊ + export function tsfnReturnPromiseTimeout(func: (err: Error | null, value: number) => any): Promise␊ export function getBuffer(): Buffer␊ export function appendBuffer(buf: Buffer): Buffer␊ export function getEmptyBuffer(): Buffer␊ diff --git a/examples/napi/__test__/typegen.spec.ts.snap b/examples/napi/__test__/typegen.spec.ts.snap index 8bcb74de..7ab51122 100644 Binary files a/examples/napi/__test__/typegen.spec.ts.snap and b/examples/napi/__test__/typegen.spec.ts.snap differ diff --git a/examples/napi/__test__/values.spec.ts b/examples/napi/__test__/values.spec.ts index 474f2742..40b807a0 100644 --- a/examples/napi/__test__/values.spec.ts +++ b/examples/napi/__test__/values.spec.ts @@ -123,6 +123,8 @@ import { acceptThreadsafeFunctionTupleArgs, promiseInEither, runScript, + tsfnReturnPromise, + tsfnReturnPromiseTimeout, } from '../' test('export const', (t) => { @@ -849,6 +851,34 @@ Napi4Test('accept ThreadsafeFunction tuple args', async (t) => { }) }) +test('threadsafe function return Promise and await in Rust', async (t) => { + const value = await tsfnReturnPromise((err, value) => { + if (err) { + throw err + } + return Promise.resolve(value + 2) + }) + t.is(value, 5) + await t.throwsAsync( + () => + tsfnReturnPromiseTimeout((err, value) => { + if (err) { + throw err + } + return new Promise((resolve) => { + setTimeout(() => { + resolve(value + 2) + }, 300) + }) + }), + { + message: 'Timeout', + }, + ) + // trigger Promise.then in Rust after `Promise` is dropped + await new Promise((resolve) => setTimeout(resolve, 400)) +}) + Napi4Test('object only from js', (t) => { return new Promise((resolve, reject) => { receiveObjectOnlyFromJs({ diff --git a/examples/napi/index.d.ts b/examples/napi/index.d.ts index 3f63dfb5..ce0d13ad 100644 --- a/examples/napi/index.d.ts +++ b/examples/napi/index.d.ts @@ -201,6 +201,8 @@ export function tsfnAsyncCall(func: (...args: any[]) => any): Promise export function acceptThreadsafeFunction(func: (err: Error | null, value: number) => any): void export function acceptThreadsafeFunctionFatal(func: (value: number) => any): void export function acceptThreadsafeFunctionTupleArgs(func: (err: Error | null, arg0: number, arg1: boolean, arg2: string) => any): void +export function tsfnReturnPromise(func: (err: Error | null, value: number) => any): Promise +export function tsfnReturnPromiseTimeout(func: (err: Error | null, value: number) => any): Promise export function getBuffer(): Buffer export function appendBuffer(buf: Buffer): Buffer export function getEmptyBuffer(): Buffer diff --git a/examples/napi/src/threadsafe_function.rs b/examples/napi/src/threadsafe_function.rs index a6bfea0f..a8573ec1 100644 --- a/examples/napi/src/threadsafe_function.rs +++ b/examples/napi/src/threadsafe_function.rs @@ -126,3 +126,24 @@ pub fn accept_threadsafe_function_tuple_args(func: ThreadsafeFunction<(u32, bool ); }); } + +#[napi] +pub async fn tsfn_return_promise(func: ThreadsafeFunction) -> Result { + let val = func.call_async::>(Ok(1)).await?.await?; + Ok(val + 2) +} + +#[napi] +pub async fn tsfn_return_promise_timeout(func: ThreadsafeFunction) -> Result { + use tokio::time::{self, Duration}; + let promise = func.call_async::>(Ok(1)).await?; + let sleep = time::sleep(Duration::from_millis(200)); + tokio::select! { + _ = sleep => { + return Err(Error::new(Status::GenericFailure, "Timeout".to_owned())); + } + value = promise => { + return Ok(value? + 2); + } + } +}