fix(napi): panic when Promise callbacks trigger after Promise is dropped (#1469) (#1516)

Co-authored-by: Alberto Pose <albepose@amazon.com>
This commit is contained in:
Alberto Pose 2023-03-14 07:32:17 +00:00 committed by GitHub
parent 894f082429
commit ffc4980d52
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 82 additions and 6 deletions

View file

@ -1,7 +1,11 @@
use std::ffi::{c_void, CStr}; use std::ffi::CStr;
use std::future; use std::future;
use std::pin::Pin; use std::pin::Pin;
use std::ptr; use std::ptr;
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use tokio::sync::oneshot::{channel, Receiver, Sender}; use tokio::sync::oneshot::{channel, Receiver, Sender};
@ -12,6 +16,13 @@ use super::{FromNapiValue, TypeName, ValidateNapiValue};
pub struct Promise<T: FromNapiValue> { pub struct Promise<T: FromNapiValue> {
value: Pin<Box<Receiver<*mut Result<T>>>>, value: Pin<Box<Receiver<*mut Result<T>>>>,
aborted: Arc<AtomicBool>,
}
impl<T: FromNapiValue> Drop for Promise<T> {
fn drop(&mut self) {
self.aborted.store(true, Ordering::SeqCst);
}
} }
impl<T: FromNapiValue> TypeName for Promise<T> { impl<T: FromNapiValue> TypeName for Promise<T> {
@ -90,7 +101,8 @@ impl<T: FromNapiValue> FromNapiValue for Promise<T> {
let mut promise_after_then = ptr::null_mut(); let mut promise_after_then = ptr::null_mut();
let mut then_js_cb = ptr::null_mut(); let mut then_js_cb = ptr::null_mut();
let (tx, rx) = channel(); let (tx, rx) = channel();
let tx_ptr = Box::into_raw(Box::new(tx)); let aborted = Arc::new(AtomicBool::new(false));
let tx_ptr = Box::into_raw(Box::new((tx, aborted.clone())));
check_status!( check_status!(
unsafe { unsafe {
sys::napi_create_function( sys::napi_create_function(
@ -98,7 +110,7 @@ impl<T: FromNapiValue> FromNapiValue for Promise<T> {
then_c_string.as_ptr(), then_c_string.as_ptr(),
4, 4,
Some(then_callback::<T>), Some(then_callback::<T>),
tx_ptr as *mut _, tx_ptr.cast(),
&mut then_js_cb, &mut then_js_cb,
) )
}, },
@ -133,7 +145,7 @@ impl<T: FromNapiValue> FromNapiValue for Promise<T> {
catch_c_string.as_ptr(), catch_c_string.as_ptr(),
5, 5,
Some(catch_callback::<T>), Some(catch_callback::<T>),
tx_ptr as *mut c_void, tx_ptr.cast(),
&mut catch_js_cb, &mut catch_js_cb,
) )
}, },
@ -154,6 +166,7 @@ impl<T: FromNapiValue> FromNapiValue for Promise<T> {
)?; )?;
Ok(Promise { Ok(Promise {
value: Box::pin(rx), value: Box::pin(rx),
aborted,
}) })
} }
} }
@ -193,8 +206,12 @@ unsafe extern "C" fn then_callback<T: FromNapiValue>(
get_cb_status == sys::Status::napi_ok, get_cb_status == sys::Status::napi_ok,
"Get callback info from Promise::then failed" "Get callback info from Promise::then failed"
); );
let (sender, aborted) =
*unsafe { Box::from_raw(data as *mut (Sender<*mut Result<T>>, Arc<AtomicBool>)) };
if aborted.load(Ordering::SeqCst) {
return this;
}
let resolve_value_t = Box::new(unsafe { T::from_napi_value(env, resolved_value[0]) }); let resolve_value_t = Box::new(unsafe { T::from_napi_value(env, resolved_value[0]) });
let sender = unsafe { Box::from_raw(data as *mut Sender<*mut Result<T>>) };
sender sender
.send(Box::into_raw(resolve_value_t)) .send(Box::into_raw(resolve_value_t))
.expect("Send Promise resolved value error"); .expect("Send Promise resolved value error");
@ -224,7 +241,11 @@ unsafe extern "C" fn catch_callback<T: FromNapiValue>(
"Get callback info from Promise::catch failed" "Get callback info from Promise::catch failed"
); );
let rejected_value = rejected_value[0]; let rejected_value = rejected_value[0];
let sender = unsafe { Box::from_raw(data as *mut Sender<*mut Result<T>>) }; let (sender, aborted) =
*unsafe { Box::from_raw(data as *mut (Sender<*mut Result<T>>, Arc<AtomicBool>)) };
if aborted.load(Ordering::SeqCst) {
return this;
}
sender sender
.send(Box::into_raw(Box::new(Err(Error::from(unsafe { .send(Box::into_raw(Box::new(Err(Error::from(unsafe {
JsUnknown::from_raw_unchecked(env, rejected_value) JsUnknown::from_raw_unchecked(env, rejected_value)

View file

@ -211,6 +211,8 @@ Generated by [AVA](https://avajs.dev).
export function acceptThreadsafeFunction(func: (err: Error | null, value: number) => any): void␊ export function acceptThreadsafeFunction(func: (err: Error | null, value: number) => any): void␊
export function acceptThreadsafeFunctionFatal(func: (value: number) => any): void␊ export function acceptThreadsafeFunctionFatal(func: (value: number) => any): void␊
export function acceptThreadsafeFunctionTupleArgs(func: (err: Error | null, arg0: number, arg1: boolean, arg2: string) => any): void␊ export function acceptThreadsafeFunctionTupleArgs(func: (err: Error | null, arg0: number, arg1: boolean, arg2: string) => any): void␊
export function tsfnReturnPromise(func: (err: Error | null, value: number) => any): Promise<number>
export function tsfnReturnPromiseTimeout(func: (err: Error | null, value: number) => any): Promise<number>
export function getBuffer(): Buffer␊ export function getBuffer(): Buffer␊
export function appendBuffer(buf: Buffer): Buffer␊ export function appendBuffer(buf: Buffer): Buffer␊
export function getEmptyBuffer(): Buffer␊ export function getEmptyBuffer(): Buffer␊

View file

@ -123,6 +123,8 @@ import {
acceptThreadsafeFunctionTupleArgs, acceptThreadsafeFunctionTupleArgs,
promiseInEither, promiseInEither,
runScript, runScript,
tsfnReturnPromise,
tsfnReturnPromiseTimeout,
} from '../' } from '../'
test('export const', (t) => { test('export const', (t) => {
@ -849,6 +851,34 @@ Napi4Test('accept ThreadsafeFunction tuple args', async (t) => {
}) })
}) })
test('threadsafe function return Promise and await in Rust', async (t) => {
const value = await tsfnReturnPromise((err, value) => {
if (err) {
throw err
}
return Promise.resolve(value + 2)
})
t.is(value, 5)
await t.throwsAsync(
() =>
tsfnReturnPromiseTimeout((err, value) => {
if (err) {
throw err
}
return new Promise((resolve) => {
setTimeout(() => {
resolve(value + 2)
}, 300)
})
}),
{
message: 'Timeout',
},
)
// trigger Promise.then in Rust after `Promise` is dropped
await new Promise((resolve) => setTimeout(resolve, 400))
})
Napi4Test('object only from js', (t) => { Napi4Test('object only from js', (t) => {
return new Promise((resolve, reject) => { return new Promise((resolve, reject) => {
receiveObjectOnlyFromJs({ receiveObjectOnlyFromJs({

View file

@ -201,6 +201,8 @@ export function tsfnAsyncCall(func: (...args: any[]) => any): Promise<void>
export function acceptThreadsafeFunction(func: (err: Error | null, value: number) => any): void export function acceptThreadsafeFunction(func: (err: Error | null, value: number) => any): void
export function acceptThreadsafeFunctionFatal(func: (value: number) => any): void export function acceptThreadsafeFunctionFatal(func: (value: number) => any): void
export function acceptThreadsafeFunctionTupleArgs(func: (err: Error | null, arg0: number, arg1: boolean, arg2: string) => any): void export function acceptThreadsafeFunctionTupleArgs(func: (err: Error | null, arg0: number, arg1: boolean, arg2: string) => any): void
export function tsfnReturnPromise(func: (err: Error | null, value: number) => any): Promise<number>
export function tsfnReturnPromiseTimeout(func: (err: Error | null, value: number) => any): Promise<number>
export function getBuffer(): Buffer export function getBuffer(): Buffer
export function appendBuffer(buf: Buffer): Buffer export function appendBuffer(buf: Buffer): Buffer
export function getEmptyBuffer(): Buffer export function getEmptyBuffer(): Buffer

View file

@ -126,3 +126,24 @@ pub fn accept_threadsafe_function_tuple_args(func: ThreadsafeFunction<(u32, bool
); );
}); });
} }
#[napi]
pub async fn tsfn_return_promise(func: ThreadsafeFunction<u32>) -> Result<u32> {
let val = func.call_async::<Promise<u32>>(Ok(1)).await?.await?;
Ok(val + 2)
}
#[napi]
pub async fn tsfn_return_promise_timeout(func: ThreadsafeFunction<u32>) -> Result<u32> {
use tokio::time::{self, Duration};
let promise = func.call_async::<Promise<u32>>(Ok(1)).await?;
let sleep = time::sleep(Duration::from_millis(200));
tokio::select! {
_ = sleep => {
return Err(Error::new(Status::GenericFailure, "Timeout".to_owned()));
}
value = promise => {
return Ok(value? + 2);
}
}
}