Co-authored-by: Alberto Pose <albepose@amazon.com>
This commit is contained in:
parent
894f082429
commit
ffc4980d52
6 changed files with 82 additions and 6 deletions
|
@ -1,7 +1,11 @@
|
||||||
use std::ffi::{c_void, CStr};
|
use std::ffi::CStr;
|
||||||
use std::future;
|
use std::future;
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
use std::ptr;
|
use std::ptr;
|
||||||
|
use std::sync::{
|
||||||
|
atomic::{AtomicBool, Ordering},
|
||||||
|
Arc,
|
||||||
|
};
|
||||||
use std::task::{Context, Poll};
|
use std::task::{Context, Poll};
|
||||||
|
|
||||||
use tokio::sync::oneshot::{channel, Receiver, Sender};
|
use tokio::sync::oneshot::{channel, Receiver, Sender};
|
||||||
|
@ -12,6 +16,13 @@ use super::{FromNapiValue, TypeName, ValidateNapiValue};
|
||||||
|
|
||||||
pub struct Promise<T: FromNapiValue> {
|
pub struct Promise<T: FromNapiValue> {
|
||||||
value: Pin<Box<Receiver<*mut Result<T>>>>,
|
value: Pin<Box<Receiver<*mut Result<T>>>>,
|
||||||
|
aborted: Arc<AtomicBool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: FromNapiValue> Drop for Promise<T> {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.aborted.store(true, Ordering::SeqCst);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FromNapiValue> TypeName for Promise<T> {
|
impl<T: FromNapiValue> TypeName for Promise<T> {
|
||||||
|
@ -90,7 +101,8 @@ impl<T: FromNapiValue> FromNapiValue for Promise<T> {
|
||||||
let mut promise_after_then = ptr::null_mut();
|
let mut promise_after_then = ptr::null_mut();
|
||||||
let mut then_js_cb = ptr::null_mut();
|
let mut then_js_cb = ptr::null_mut();
|
||||||
let (tx, rx) = channel();
|
let (tx, rx) = channel();
|
||||||
let tx_ptr = Box::into_raw(Box::new(tx));
|
let aborted = Arc::new(AtomicBool::new(false));
|
||||||
|
let tx_ptr = Box::into_raw(Box::new((tx, aborted.clone())));
|
||||||
check_status!(
|
check_status!(
|
||||||
unsafe {
|
unsafe {
|
||||||
sys::napi_create_function(
|
sys::napi_create_function(
|
||||||
|
@ -98,7 +110,7 @@ impl<T: FromNapiValue> FromNapiValue for Promise<T> {
|
||||||
then_c_string.as_ptr(),
|
then_c_string.as_ptr(),
|
||||||
4,
|
4,
|
||||||
Some(then_callback::<T>),
|
Some(then_callback::<T>),
|
||||||
tx_ptr as *mut _,
|
tx_ptr.cast(),
|
||||||
&mut then_js_cb,
|
&mut then_js_cb,
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
|
@ -133,7 +145,7 @@ impl<T: FromNapiValue> FromNapiValue for Promise<T> {
|
||||||
catch_c_string.as_ptr(),
|
catch_c_string.as_ptr(),
|
||||||
5,
|
5,
|
||||||
Some(catch_callback::<T>),
|
Some(catch_callback::<T>),
|
||||||
tx_ptr as *mut c_void,
|
tx_ptr.cast(),
|
||||||
&mut catch_js_cb,
|
&mut catch_js_cb,
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
|
@ -154,6 +166,7 @@ impl<T: FromNapiValue> FromNapiValue for Promise<T> {
|
||||||
)?;
|
)?;
|
||||||
Ok(Promise {
|
Ok(Promise {
|
||||||
value: Box::pin(rx),
|
value: Box::pin(rx),
|
||||||
|
aborted,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -193,8 +206,12 @@ unsafe extern "C" fn then_callback<T: FromNapiValue>(
|
||||||
get_cb_status == sys::Status::napi_ok,
|
get_cb_status == sys::Status::napi_ok,
|
||||||
"Get callback info from Promise::then failed"
|
"Get callback info from Promise::then failed"
|
||||||
);
|
);
|
||||||
|
let (sender, aborted) =
|
||||||
|
*unsafe { Box::from_raw(data as *mut (Sender<*mut Result<T>>, Arc<AtomicBool>)) };
|
||||||
|
if aborted.load(Ordering::SeqCst) {
|
||||||
|
return this;
|
||||||
|
}
|
||||||
let resolve_value_t = Box::new(unsafe { T::from_napi_value(env, resolved_value[0]) });
|
let resolve_value_t = Box::new(unsafe { T::from_napi_value(env, resolved_value[0]) });
|
||||||
let sender = unsafe { Box::from_raw(data as *mut Sender<*mut Result<T>>) };
|
|
||||||
sender
|
sender
|
||||||
.send(Box::into_raw(resolve_value_t))
|
.send(Box::into_raw(resolve_value_t))
|
||||||
.expect("Send Promise resolved value error");
|
.expect("Send Promise resolved value error");
|
||||||
|
@ -224,7 +241,11 @@ unsafe extern "C" fn catch_callback<T: FromNapiValue>(
|
||||||
"Get callback info from Promise::catch failed"
|
"Get callback info from Promise::catch failed"
|
||||||
);
|
);
|
||||||
let rejected_value = rejected_value[0];
|
let rejected_value = rejected_value[0];
|
||||||
let sender = unsafe { Box::from_raw(data as *mut Sender<*mut Result<T>>) };
|
let (sender, aborted) =
|
||||||
|
*unsafe { Box::from_raw(data as *mut (Sender<*mut Result<T>>, Arc<AtomicBool>)) };
|
||||||
|
if aborted.load(Ordering::SeqCst) {
|
||||||
|
return this;
|
||||||
|
}
|
||||||
sender
|
sender
|
||||||
.send(Box::into_raw(Box::new(Err(Error::from(unsafe {
|
.send(Box::into_raw(Box::new(Err(Error::from(unsafe {
|
||||||
JsUnknown::from_raw_unchecked(env, rejected_value)
|
JsUnknown::from_raw_unchecked(env, rejected_value)
|
||||||
|
|
|
@ -211,6 +211,8 @@ Generated by [AVA](https://avajs.dev).
|
||||||
export function acceptThreadsafeFunction(func: (err: Error | null, value: number) => any): void␊
|
export function acceptThreadsafeFunction(func: (err: Error | null, value: number) => any): void␊
|
||||||
export function acceptThreadsafeFunctionFatal(func: (value: number) => any): void␊
|
export function acceptThreadsafeFunctionFatal(func: (value: number) => any): void␊
|
||||||
export function acceptThreadsafeFunctionTupleArgs(func: (err: Error | null, arg0: number, arg1: boolean, arg2: string) => any): void␊
|
export function acceptThreadsafeFunctionTupleArgs(func: (err: Error | null, arg0: number, arg1: boolean, arg2: string) => any): void␊
|
||||||
|
export function tsfnReturnPromise(func: (err: Error | null, value: number) => any): Promise<number>␊
|
||||||
|
export function tsfnReturnPromiseTimeout(func: (err: Error | null, value: number) => any): Promise<number>␊
|
||||||
export function getBuffer(): Buffer␊
|
export function getBuffer(): Buffer␊
|
||||||
export function appendBuffer(buf: Buffer): Buffer␊
|
export function appendBuffer(buf: Buffer): Buffer␊
|
||||||
export function getEmptyBuffer(): Buffer␊
|
export function getEmptyBuffer(): Buffer␊
|
||||||
|
|
Binary file not shown.
|
@ -123,6 +123,8 @@ import {
|
||||||
acceptThreadsafeFunctionTupleArgs,
|
acceptThreadsafeFunctionTupleArgs,
|
||||||
promiseInEither,
|
promiseInEither,
|
||||||
runScript,
|
runScript,
|
||||||
|
tsfnReturnPromise,
|
||||||
|
tsfnReturnPromiseTimeout,
|
||||||
} from '../'
|
} from '../'
|
||||||
|
|
||||||
test('export const', (t) => {
|
test('export const', (t) => {
|
||||||
|
@ -849,6 +851,34 @@ Napi4Test('accept ThreadsafeFunction tuple args', async (t) => {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
test('threadsafe function return Promise and await in Rust', async (t) => {
|
||||||
|
const value = await tsfnReturnPromise((err, value) => {
|
||||||
|
if (err) {
|
||||||
|
throw err
|
||||||
|
}
|
||||||
|
return Promise.resolve(value + 2)
|
||||||
|
})
|
||||||
|
t.is(value, 5)
|
||||||
|
await t.throwsAsync(
|
||||||
|
() =>
|
||||||
|
tsfnReturnPromiseTimeout((err, value) => {
|
||||||
|
if (err) {
|
||||||
|
throw err
|
||||||
|
}
|
||||||
|
return new Promise((resolve) => {
|
||||||
|
setTimeout(() => {
|
||||||
|
resolve(value + 2)
|
||||||
|
}, 300)
|
||||||
|
})
|
||||||
|
}),
|
||||||
|
{
|
||||||
|
message: 'Timeout',
|
||||||
|
},
|
||||||
|
)
|
||||||
|
// trigger Promise.then in Rust after `Promise` is dropped
|
||||||
|
await new Promise((resolve) => setTimeout(resolve, 400))
|
||||||
|
})
|
||||||
|
|
||||||
Napi4Test('object only from js', (t) => {
|
Napi4Test('object only from js', (t) => {
|
||||||
return new Promise((resolve, reject) => {
|
return new Promise((resolve, reject) => {
|
||||||
receiveObjectOnlyFromJs({
|
receiveObjectOnlyFromJs({
|
||||||
|
|
2
examples/napi/index.d.ts
vendored
2
examples/napi/index.d.ts
vendored
|
@ -201,6 +201,8 @@ export function tsfnAsyncCall(func: (...args: any[]) => any): Promise<void>
|
||||||
export function acceptThreadsafeFunction(func: (err: Error | null, value: number) => any): void
|
export function acceptThreadsafeFunction(func: (err: Error | null, value: number) => any): void
|
||||||
export function acceptThreadsafeFunctionFatal(func: (value: number) => any): void
|
export function acceptThreadsafeFunctionFatal(func: (value: number) => any): void
|
||||||
export function acceptThreadsafeFunctionTupleArgs(func: (err: Error | null, arg0: number, arg1: boolean, arg2: string) => any): void
|
export function acceptThreadsafeFunctionTupleArgs(func: (err: Error | null, arg0: number, arg1: boolean, arg2: string) => any): void
|
||||||
|
export function tsfnReturnPromise(func: (err: Error | null, value: number) => any): Promise<number>
|
||||||
|
export function tsfnReturnPromiseTimeout(func: (err: Error | null, value: number) => any): Promise<number>
|
||||||
export function getBuffer(): Buffer
|
export function getBuffer(): Buffer
|
||||||
export function appendBuffer(buf: Buffer): Buffer
|
export function appendBuffer(buf: Buffer): Buffer
|
||||||
export function getEmptyBuffer(): Buffer
|
export function getEmptyBuffer(): Buffer
|
||||||
|
|
|
@ -126,3 +126,24 @@ pub fn accept_threadsafe_function_tuple_args(func: ThreadsafeFunction<(u32, bool
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[napi]
|
||||||
|
pub async fn tsfn_return_promise(func: ThreadsafeFunction<u32>) -> Result<u32> {
|
||||||
|
let val = func.call_async::<Promise<u32>>(Ok(1)).await?.await?;
|
||||||
|
Ok(val + 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[napi]
|
||||||
|
pub async fn tsfn_return_promise_timeout(func: ThreadsafeFunction<u32>) -> Result<u32> {
|
||||||
|
use tokio::time::{self, Duration};
|
||||||
|
let promise = func.call_async::<Promise<u32>>(Ok(1)).await?;
|
||||||
|
let sleep = time::sleep(Duration::from_millis(200));
|
||||||
|
tokio::select! {
|
||||||
|
_ = sleep => {
|
||||||
|
return Err(Error::new(Status::GenericFailure, "Timeout".to_owned()));
|
||||||
|
}
|
||||||
|
value = promise => {
|
||||||
|
return Ok(value? + 2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue