Co-authored-by: Alberto Pose <albepose@amazon.com>
This commit is contained in:
parent
894f082429
commit
ffc4980d52
6 changed files with 82 additions and 6 deletions
|
@ -1,7 +1,11 @@
|
|||
use std::ffi::{c_void, CStr};
|
||||
use std::ffi::CStr;
|
||||
use std::future;
|
||||
use std::pin::Pin;
|
||||
use std::ptr;
|
||||
use std::sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
};
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
use tokio::sync::oneshot::{channel, Receiver, Sender};
|
||||
|
@ -12,6 +16,13 @@ use super::{FromNapiValue, TypeName, ValidateNapiValue};
|
|||
|
||||
pub struct Promise<T: FromNapiValue> {
|
||||
value: Pin<Box<Receiver<*mut Result<T>>>>,
|
||||
aborted: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl<T: FromNapiValue> Drop for Promise<T> {
|
||||
fn drop(&mut self) {
|
||||
self.aborted.store(true, Ordering::SeqCst);
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: FromNapiValue> TypeName for Promise<T> {
|
||||
|
@ -90,7 +101,8 @@ impl<T: FromNapiValue> FromNapiValue for Promise<T> {
|
|||
let mut promise_after_then = ptr::null_mut();
|
||||
let mut then_js_cb = ptr::null_mut();
|
||||
let (tx, rx) = channel();
|
||||
let tx_ptr = Box::into_raw(Box::new(tx));
|
||||
let aborted = Arc::new(AtomicBool::new(false));
|
||||
let tx_ptr = Box::into_raw(Box::new((tx, aborted.clone())));
|
||||
check_status!(
|
||||
unsafe {
|
||||
sys::napi_create_function(
|
||||
|
@ -98,7 +110,7 @@ impl<T: FromNapiValue> FromNapiValue for Promise<T> {
|
|||
then_c_string.as_ptr(),
|
||||
4,
|
||||
Some(then_callback::<T>),
|
||||
tx_ptr as *mut _,
|
||||
tx_ptr.cast(),
|
||||
&mut then_js_cb,
|
||||
)
|
||||
},
|
||||
|
@ -133,7 +145,7 @@ impl<T: FromNapiValue> FromNapiValue for Promise<T> {
|
|||
catch_c_string.as_ptr(),
|
||||
5,
|
||||
Some(catch_callback::<T>),
|
||||
tx_ptr as *mut c_void,
|
||||
tx_ptr.cast(),
|
||||
&mut catch_js_cb,
|
||||
)
|
||||
},
|
||||
|
@ -154,6 +166,7 @@ impl<T: FromNapiValue> FromNapiValue for Promise<T> {
|
|||
)?;
|
||||
Ok(Promise {
|
||||
value: Box::pin(rx),
|
||||
aborted,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -193,8 +206,12 @@ unsafe extern "C" fn then_callback<T: FromNapiValue>(
|
|||
get_cb_status == sys::Status::napi_ok,
|
||||
"Get callback info from Promise::then failed"
|
||||
);
|
||||
let (sender, aborted) =
|
||||
*unsafe { Box::from_raw(data as *mut (Sender<*mut Result<T>>, Arc<AtomicBool>)) };
|
||||
if aborted.load(Ordering::SeqCst) {
|
||||
return this;
|
||||
}
|
||||
let resolve_value_t = Box::new(unsafe { T::from_napi_value(env, resolved_value[0]) });
|
||||
let sender = unsafe { Box::from_raw(data as *mut Sender<*mut Result<T>>) };
|
||||
sender
|
||||
.send(Box::into_raw(resolve_value_t))
|
||||
.expect("Send Promise resolved value error");
|
||||
|
@ -224,7 +241,11 @@ unsafe extern "C" fn catch_callback<T: FromNapiValue>(
|
|||
"Get callback info from Promise::catch failed"
|
||||
);
|
||||
let rejected_value = rejected_value[0];
|
||||
let sender = unsafe { Box::from_raw(data as *mut Sender<*mut Result<T>>) };
|
||||
let (sender, aborted) =
|
||||
*unsafe { Box::from_raw(data as *mut (Sender<*mut Result<T>>, Arc<AtomicBool>)) };
|
||||
if aborted.load(Ordering::SeqCst) {
|
||||
return this;
|
||||
}
|
||||
sender
|
||||
.send(Box::into_raw(Box::new(Err(Error::from(unsafe {
|
||||
JsUnknown::from_raw_unchecked(env, rejected_value)
|
||||
|
|
|
@ -211,6 +211,8 @@ Generated by [AVA](https://avajs.dev).
|
|||
export function acceptThreadsafeFunction(func: (err: Error | null, value: number) => any): void␊
|
||||
export function acceptThreadsafeFunctionFatal(func: (value: number) => any): void␊
|
||||
export function acceptThreadsafeFunctionTupleArgs(func: (err: Error | null, arg0: number, arg1: boolean, arg2: string) => any): void␊
|
||||
export function tsfnReturnPromise(func: (err: Error | null, value: number) => any): Promise<number>␊
|
||||
export function tsfnReturnPromiseTimeout(func: (err: Error | null, value: number) => any): Promise<number>␊
|
||||
export function getBuffer(): Buffer␊
|
||||
export function appendBuffer(buf: Buffer): Buffer␊
|
||||
export function getEmptyBuffer(): Buffer␊
|
||||
|
|
Binary file not shown.
|
@ -123,6 +123,8 @@ import {
|
|||
acceptThreadsafeFunctionTupleArgs,
|
||||
promiseInEither,
|
||||
runScript,
|
||||
tsfnReturnPromise,
|
||||
tsfnReturnPromiseTimeout,
|
||||
} from '../'
|
||||
|
||||
test('export const', (t) => {
|
||||
|
@ -849,6 +851,34 @@ Napi4Test('accept ThreadsafeFunction tuple args', async (t) => {
|
|||
})
|
||||
})
|
||||
|
||||
test('threadsafe function return Promise and await in Rust', async (t) => {
|
||||
const value = await tsfnReturnPromise((err, value) => {
|
||||
if (err) {
|
||||
throw err
|
||||
}
|
||||
return Promise.resolve(value + 2)
|
||||
})
|
||||
t.is(value, 5)
|
||||
await t.throwsAsync(
|
||||
() =>
|
||||
tsfnReturnPromiseTimeout((err, value) => {
|
||||
if (err) {
|
||||
throw err
|
||||
}
|
||||
return new Promise((resolve) => {
|
||||
setTimeout(() => {
|
||||
resolve(value + 2)
|
||||
}, 300)
|
||||
})
|
||||
}),
|
||||
{
|
||||
message: 'Timeout',
|
||||
},
|
||||
)
|
||||
// trigger Promise.then in Rust after `Promise` is dropped
|
||||
await new Promise((resolve) => setTimeout(resolve, 400))
|
||||
})
|
||||
|
||||
Napi4Test('object only from js', (t) => {
|
||||
return new Promise((resolve, reject) => {
|
||||
receiveObjectOnlyFromJs({
|
||||
|
|
2
examples/napi/index.d.ts
vendored
2
examples/napi/index.d.ts
vendored
|
@ -201,6 +201,8 @@ export function tsfnAsyncCall(func: (...args: any[]) => any): Promise<void>
|
|||
export function acceptThreadsafeFunction(func: (err: Error | null, value: number) => any): void
|
||||
export function acceptThreadsafeFunctionFatal(func: (value: number) => any): void
|
||||
export function acceptThreadsafeFunctionTupleArgs(func: (err: Error | null, arg0: number, arg1: boolean, arg2: string) => any): void
|
||||
export function tsfnReturnPromise(func: (err: Error | null, value: number) => any): Promise<number>
|
||||
export function tsfnReturnPromiseTimeout(func: (err: Error | null, value: number) => any): Promise<number>
|
||||
export function getBuffer(): Buffer
|
||||
export function appendBuffer(buf: Buffer): Buffer
|
||||
export function getEmptyBuffer(): Buffer
|
||||
|
|
|
@ -126,3 +126,24 @@ pub fn accept_threadsafe_function_tuple_args(func: ThreadsafeFunction<(u32, bool
|
|||
);
|
||||
});
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub async fn tsfn_return_promise(func: ThreadsafeFunction<u32>) -> Result<u32> {
|
||||
let val = func.call_async::<Promise<u32>>(Ok(1)).await?.await?;
|
||||
Ok(val + 2)
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub async fn tsfn_return_promise_timeout(func: ThreadsafeFunction<u32>) -> Result<u32> {
|
||||
use tokio::time::{self, Duration};
|
||||
let promise = func.call_async::<Promise<u32>>(Ok(1)).await?;
|
||||
let sleep = time::sleep(Duration::from_millis(200));
|
||||
tokio::select! {
|
||||
_ = sleep => {
|
||||
return Err(Error::new(Status::GenericFailure, "Timeout".to_owned()));
|
||||
}
|
||||
value = promise => {
|
||||
return Ok(value? + 2);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue