feat(napi): await Promise<T> in async fn

This commit is contained in:
LongYinan 2021-11-13 20:51:14 +08:00
parent 0d469ed8db
commit eaa96f7eb2
No known key found for this signature in database
GPG key ID: C3666B7FC82ADAD7
13 changed files with 307 additions and 60 deletions

View file

@ -155,6 +155,8 @@ pub fn ty_to_ts_type(ty: &Type, is_return_ty: bool) -> String {
.with(|c| c.borrow_mut().get(rust_ty.as_str()).cloned()) .with(|c| c.borrow_mut().get(rust_ty.as_str()).cloned())
{ {
ts_ty = Some(t); ts_ty = Some(t);
} else if rust_ty == "Promise" {
ts_ty = Some(format!("Promise<{}>", args.first().unwrap()));
} else { } else {
// there should be runtime registered type in else // there should be runtime registered type in else
ts_ty = Some(rust_ty); ts_ty = Some(rust_ty);

View file

@ -13,6 +13,7 @@ mod map;
mod nil; mod nil;
mod number; mod number;
mod object; mod object;
mod promise;
#[cfg(feature = "serde-json")] #[cfg(feature = "serde-json")]
mod serde; mod serde;
mod string; mod string;
@ -27,6 +28,7 @@ pub use either::*;
pub use function::*; pub use function::*;
pub use nil::*; pub use nil::*;
pub use object::*; pub use object::*;
pub use promise::*;
pub use string::*; pub use string::*;
pub use task::*; pub use task::*;

View file

@ -0,0 +1,170 @@
use std::ffi::{c_void, CString};
use std::future;
use std::pin::Pin;
use std::ptr;
use std::task::{Context, Poll};
use tokio::sync::oneshot::{channel, Receiver, Sender};
use crate::{check_status, Error, Result, Status};
use super::FromNapiValue;
pub struct Promise<T: FromNapiValue> {
value: Pin<Box<Receiver<*mut Result<T>>>>,
}
unsafe impl<T: FromNapiValue> Send for Promise<T> {}
unsafe impl<T: FromNapiValue> Sync for Promise<T> {}
impl<T: FromNapiValue> FromNapiValue for Promise<T> {
unsafe fn from_napi_value(
env: napi_sys::napi_env,
napi_val: napi_sys::napi_value,
) -> crate::Result<Self> {
let mut then = ptr::null_mut();
let then_c_string = CString::new("then")?;
check_status!(
napi_sys::napi_get_named_property(env, napi_val, then_c_string.as_ptr(), &mut then,),
"Failed to get then function"
)?;
let mut promise_after_then = ptr::null_mut();
let mut then_js_cb = ptr::null_mut();
let (tx, rx) = channel();
let tx_ptr = Box::into_raw(Box::new(tx));
check_status!(
napi_sys::napi_create_function(
env,
then_c_string.as_ptr(),
4,
Some(then_callback::<T>),
tx_ptr as *mut _,
&mut then_js_cb,
),
"Failed to create then callback"
)?;
check_status!(
napi_sys::napi_call_function(
env,
napi_val,
then,
1,
[then_js_cb].as_ptr(),
&mut promise_after_then,
),
"Failed to call then method"
)?;
let mut catch = ptr::null_mut();
let catch_c_string = CString::new("catch")?;
check_status!(
napi_sys::napi_get_named_property(
env,
promise_after_then,
catch_c_string.as_ptr(),
&mut catch
),
"Failed to get then function"
)?;
let mut catch_js_cb = ptr::null_mut();
check_status!(
napi_sys::napi_create_function(
env,
catch_c_string.as_ptr(),
5,
Some(catch_callback::<T>),
tx_ptr as *mut c_void,
&mut catch_js_cb
),
"Failed to create catch callback"
)?;
check_status!(
napi_sys::napi_call_function(
env,
promise_after_then,
catch,
1,
[catch_js_cb].as_ptr(),
ptr::null_mut()
),
"Failed to call catch method"
)?;
Ok(Promise {
value: Box::pin(rx),
})
}
}
impl<T: FromNapiValue> future::Future for Promise<T> {
type Output = Result<T>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.value.as_mut().poll(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(v) => Poll::Ready(
v.map_err(|e| Error::new(Status::GenericFailure, format!("{}", e)))
.and_then(|v| unsafe { *Box::from_raw(v) }.map_err(Error::from)),
),
}
}
}
unsafe extern "C" fn then_callback<T: FromNapiValue>(
env: napi_sys::napi_env,
info: napi_sys::napi_callback_info,
) -> napi_sys::napi_value {
let mut data = ptr::null_mut();
let mut resolved_value: [napi_sys::napi_value; 1] = [ptr::null_mut()];
let mut this = ptr::null_mut();
let get_cb_status = napi_sys::napi_get_cb_info(
env,
info,
&mut 1,
resolved_value.as_mut_ptr(),
&mut this,
&mut data,
);
debug_assert!(
get_cb_status == napi_sys::Status::napi_ok,
"Get callback info from Promise::then failed"
);
let resolve_value_t = Box::new(T::from_napi_value(env, resolved_value[0]));
let sender = Box::from_raw(data as *mut Sender<*mut Result<T>>);
sender
.send(Box::into_raw(resolve_value_t))
.expect("Send Promise resolved value error");
this
}
unsafe extern "C" fn catch_callback<T: FromNapiValue>(
env: napi_sys::napi_env,
info: napi_sys::napi_callback_info,
) -> napi_sys::napi_value {
let mut data = ptr::null_mut();
let mut rejected_value: [napi_sys::napi_value; 1] = [ptr::null_mut()];
let mut this = ptr::null_mut();
let mut argc = 1;
let get_cb_status = napi_sys::napi_get_cb_info(
env,
info,
&mut argc,
rejected_value.as_mut_ptr(),
&mut this,
&mut data,
);
debug_assert!(
get_cb_status == napi_sys::Status::napi_ok,
"Get callback info from Promise::catch failed"
);
let rejected_value = rejected_value[0];
let mut error_ref = ptr::null_mut();
let create_ref_status = napi_sys::napi_create_reference(env, rejected_value, 1, &mut error_ref);
debug_assert!(
create_ref_status == napi_sys::Status::napi_ok,
"Create Error reference failed"
);
let sender = Box::from_raw(data as *mut Sender<*mut Result<T>>);
sender
.send(Box::into_raw(Box::new(Err(Error::from(error_ref)))))
.expect("Send Promise resolved value error");
this
}

View file

@ -47,10 +47,10 @@ impl<'env> CallContext<'env> {
pub fn get<ArgType: NapiValue>(&self, index: usize) -> Result<ArgType> { pub fn get<ArgType: NapiValue>(&self, index: usize) -> Result<ArgType> {
if index >= self.arg_len() { if index >= self.arg_len() {
Err(Error { Err(Error::new(
status: Status::GenericFailure, Status::GenericFailure,
reason: "Arguments index out of range".to_owned(), "Arguments index out of range".to_owned(),
}) ))
} else { } else {
Ok(unsafe { ArgType::from_raw_unchecked(self.env.0, self.args[index]) }) Ok(unsafe { ArgType::from_raw_unchecked(self.env.0, self.args[index]) })
} }
@ -58,10 +58,10 @@ impl<'env> CallContext<'env> {
pub fn try_get<ArgType: NapiValue>(&self, index: usize) -> Result<Either<ArgType, JsUndefined>> { pub fn try_get<ArgType: NapiValue>(&self, index: usize) -> Result<Either<ArgType, JsUndefined>> {
if index >= self.arg_len() { if index >= self.arg_len() {
Err(Error { Err(Error::new(
status: Status::GenericFailure, Status::GenericFailure,
reason: "Arguments index out of range".to_owned(), "Arguments index out of range".to_owned(),
}) ))
} else if index < self.length { } else if index < self.length {
unsafe { ArgType::from_raw(self.env.0, self.args[index]) }.map(Either::A) unsafe { ArgType::from_raw(self.env.0, self.args[index]) }.map(Either::A)
} else { } else {

View file

@ -756,15 +756,17 @@ impl Env {
let type_id = unknown_tagged_object as *const TypeId; let type_id = unknown_tagged_object as *const TypeId;
if *type_id == TypeId::of::<T>() { if *type_id == TypeId::of::<T>() {
let tagged_object = unknown_tagged_object as *mut TaggedObject<T>; let tagged_object = unknown_tagged_object as *mut TaggedObject<T>;
(*tagged_object).object.as_mut().ok_or(Error { (*tagged_object).object.as_mut().ok_or_else(|| {
status: Status::InvalidArg, Error::new(
reason: "Invalid argument, nothing attach to js_object".to_owned(), Status::InvalidArg,
"Invalid argument, nothing attach to js_object".to_owned(),
)
}) })
} else { } else {
Err(Error { Err(Error::new(
status: Status::InvalidArg, Status::InvalidArg,
reason: "Invalid argument, T on unrwap is not the type of wrapped object".to_owned(), "Invalid argument, T on unrwap is not the type of wrapped object".to_owned(),
}) ))
} }
} }
} }
@ -781,15 +783,17 @@ impl Env {
let type_id = unknown_tagged_object as *const TypeId; let type_id = unknown_tagged_object as *const TypeId;
if *type_id == TypeId::of::<T>() { if *type_id == TypeId::of::<T>() {
let tagged_object = unknown_tagged_object as *mut TaggedObject<T>; let tagged_object = unknown_tagged_object as *mut TaggedObject<T>;
(*tagged_object).object.as_mut().ok_or(Error { (*tagged_object).object.as_mut().ok_or_else(|| {
status: Status::InvalidArg, Error::new(
reason: "Invalid argument, nothing attach to js_object".to_owned(), Status::InvalidArg,
"Invalid argument, nothing attach to js_object".to_owned(),
)
}) })
} else { } else {
Err(Error { Err(Error::new(
status: Status::InvalidArg, Status::InvalidArg,
reason: "Invalid argument, T on unrwap is not the type of wrapped object".to_owned(), "Invalid argument, T on unrwap is not the type of wrapped object".to_owned(),
}) ))
} }
} }
} }
@ -807,10 +811,10 @@ impl Env {
Box::from_raw(unknown_tagged_object as *mut TaggedObject<T>); Box::from_raw(unknown_tagged_object as *mut TaggedObject<T>);
Ok(()) Ok(())
} else { } else {
Err(Error { Err(Error::new(
status: Status::InvalidArg, Status::InvalidArg,
reason: "Invalid argument, T on unrwap is not the type of wrapped object".to_owned(), "Invalid argument, T on unrwap is not the type of wrapped object".to_owned(),
}) ))
} }
} }
} }
@ -905,15 +909,17 @@ impl Env {
let type_id = unknown_tagged_object as *const TypeId; let type_id = unknown_tagged_object as *const TypeId;
if *type_id == TypeId::of::<T>() { if *type_id == TypeId::of::<T>() {
let tagged_object = unknown_tagged_object as *mut TaggedObject<T>; let tagged_object = unknown_tagged_object as *mut TaggedObject<T>;
(*tagged_object).object.as_mut().ok_or(Error { (*tagged_object).object.as_mut().ok_or_else(|| {
status: Status::InvalidArg, Error::new(
reason: "nothing attach to js_external".to_owned(), Status::InvalidArg,
"nothing attach to js_external".to_owned(),
)
}) })
} else { } else {
Err(Error { Err(Error::new(
status: Status::InvalidArg, Status::InvalidArg,
reason: "T on get_value_external is not the type of wrapped object".to_owned(), "T on get_value_external is not the type of wrapped object".to_owned(),
}) ))
} }
} }
} }
@ -1103,15 +1109,17 @@ impl Env {
} }
if *type_id == TypeId::of::<T>() { if *type_id == TypeId::of::<T>() {
let tagged_object = unknown_tagged_object as *mut TaggedObject<T>; let tagged_object = unknown_tagged_object as *mut TaggedObject<T>;
(*tagged_object).object.as_mut().map(Some).ok_or(Error { (*tagged_object).object.as_mut().map(Some).ok_or_else(|| {
status: Status::InvalidArg, Error::new(
reason: "Invalid argument, nothing attach to js_object".to_owned(), Status::InvalidArg,
"Invalid argument, nothing attach to js_object".to_owned(),
)
}) })
} else { } else {
Err(Error { Err(Error::new(
status: Status::InvalidArg, Status::InvalidArg,
reason: "Invalid argument, T on unrwap is not the type of wrapped object".to_owned(), "Invalid argument, T on unrwap is not the type of wrapped object".to_owned(),
}) ))
} }
} }
} }

View file

@ -23,8 +23,14 @@ pub type Result<T> = std::result::Result<T, Error>;
pub struct Error { pub struct Error {
pub status: Status, pub status: Status,
pub reason: String, pub reason: String,
// Convert raw `JsError` into Error
// Only be used in `async fn(p: Promise<T>)` scenario
pub(crate) maybe_raw: sys::napi_ref,
} }
unsafe impl Send for Error {}
unsafe impl Sync for Error {}
impl error::Error for Error {} impl error::Error for Error {}
#[cfg(feature = "serde-json")] #[cfg(feature = "serde-json")]
@ -48,6 +54,16 @@ impl From<SerdeJSONError> for Error {
} }
} }
impl From<sys::napi_ref> for Error {
fn from(value: sys::napi_ref) -> Self {
Self {
status: Status::InvalidArg,
reason: "".to_string(),
maybe_raw: value,
}
}
}
impl fmt::Display for Error { impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if !self.reason.is_empty() { if !self.reason.is_empty() {
@ -60,13 +76,18 @@ impl fmt::Display for Error {
impl Error { impl Error {
pub fn new(status: Status, reason: String) -> Self { pub fn new(status: Status, reason: String) -> Self {
Error { status, reason } Error {
status,
reason,
maybe_raw: ptr::null_mut(),
}
} }
pub fn from_status(status: Status) -> Self { pub fn from_status(status: Status) -> Self {
Error { Error {
status, status,
reason: "".to_owned(), reason: "".to_owned(),
maybe_raw: ptr::null_mut(),
} }
} }
@ -74,6 +95,7 @@ impl Error {
Error { Error {
status: Status::GenericFailure, status: Status::GenericFailure,
reason, reason,
maybe_raw: ptr::null_mut(),
} }
} }
} }
@ -83,6 +105,7 @@ impl From<std::ffi::NulError> for Error {
Error { Error {
status: Status::GenericFailure, status: Status::GenericFailure,
reason: format!("{}", error), reason: format!("{}", error),
maybe_raw: ptr::null_mut(),
} }
} }
} }
@ -92,6 +115,7 @@ impl From<std::io::Error> for Error {
Error { Error {
status: Status::GenericFailure, status: Status::GenericFailure,
reason: format!("{}", error), reason: format!("{}", error),
maybe_raw: ptr::null_mut(),
} }
} }
} }
@ -163,8 +187,10 @@ macro_rules! impl_object_methods {
pub unsafe fn throw_into(self, env: sys::napi_env) { pub unsafe fn throw_into(self, env: sys::napi_env) {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
let reason = self.0.reason.clone(); let reason = self.0.reason.clone();
#[cfg(debug_assertions)]
let status = self.0.status; let status = self.0.status;
if status == Status::PendingException {
return;
}
let js_error = self.into_value(env); let js_error = self.into_value(env);
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
let throw_status = sys::napi_throw(env, js_error); let throw_status = sys::napi_throw(env, js_error);

View file

@ -1,6 +1,7 @@
use std::ffi::CString;
use std::future::Future; use std::future::Future;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::os::raw::{c_char, c_void}; use std::os::raw::c_void;
use std::ptr; use std::ptr;
use crate::{check_status, sys, JsError, Result}; use crate::{check_status, sys, JsError, Result};
@ -12,7 +13,6 @@ pub struct FuturePromise<Data, Resolver: FnOnce(sys::napi_env, Data) -> Result<s
async_resource_name: sys::napi_value, async_resource_name: sys::napi_value,
resolver: Resolver, resolver: Resolver,
_data: PhantomData<Data>, _data: PhantomData<Data>,
_value: PhantomData<sys::napi_value>,
} }
unsafe impl<T, F: FnOnce(sys::napi_env, T) -> Result<sys::napi_value>> Send unsafe impl<T, F: FnOnce(sys::napi_env, T) -> Result<sys::napi_value>> Send
@ -23,26 +23,20 @@ unsafe impl<T, F: FnOnce(sys::napi_env, T) -> Result<sys::napi_value>> Send
impl<Data, Resolver: FnOnce(sys::napi_env, Data) -> Result<sys::napi_value>> impl<Data, Resolver: FnOnce(sys::napi_env, Data) -> Result<sys::napi_value>>
FuturePromise<Data, Resolver> FuturePromise<Data, Resolver>
{ {
pub fn new(env: sys::napi_env, dererred: sys::napi_deferred, resolver: Resolver) -> Result<Self> { pub fn new(env: sys::napi_env, deferred: sys::napi_deferred, resolver: Resolver) -> Result<Self> {
let mut async_resource_name = ptr::null_mut(); let mut async_resource_name = ptr::null_mut();
let s = "napi_resolve_promise_from_future"; let s = CString::new("napi_resolve_promise_from_future")?;
check_status!(unsafe { check_status!(unsafe {
sys::napi_create_string_utf8( sys::napi_create_string_utf8(env, s.as_ptr(), 32, &mut async_resource_name)
env,
s.as_ptr() as *const c_char,
s.len(),
&mut async_resource_name,
)
})?; })?;
Ok(FuturePromise { Ok(FuturePromise {
deferred: dererred, deferred,
resolver, resolver,
env, env,
tsfn: ptr::null_mut(), tsfn: ptr::null_mut(),
async_resource_name, async_resource_name,
_data: PhantomData, _data: PhantomData,
_value: PhantomData,
}) })
} }
@ -83,7 +77,7 @@ pub(crate) async fn resolve_from_future<Data: Send, Fut: Future<Output = Result<
check_status!(unsafe { check_status!(unsafe {
sys::napi_call_threadsafe_function( sys::napi_call_threadsafe_function(
tsfn_value.0, tsfn_value.0,
Box::into_raw(Box::from(val)) as *mut Data as *mut c_void, Box::into_raw(Box::from(val)) as *mut c_void,
sys::napi_threadsafe_function_call_mode::napi_tsfn_nonblocking, sys::napi_threadsafe_function_call_mode::napi_tsfn_nonblocking,
) )
}) })
@ -117,7 +111,26 @@ unsafe extern "C" fn call_js_cb<
debug_assert!(status == sys::Status::napi_ok, "Resolve promise failed"); debug_assert!(status == sys::Status::napi_ok, "Resolve promise failed");
} }
Err(e) => { Err(e) => {
let status = sys::napi_reject_deferred(env, deferred, JsError::from(e).into_value(env)); let status = sys::napi_reject_deferred(
env,
deferred,
if e.maybe_raw.is_null() {
JsError::from(e).into_value(env)
} else {
let mut err = ptr::null_mut();
let get_err_status = sys::napi_get_reference_value(env, e.maybe_raw, &mut err);
debug_assert!(
get_err_status == sys::Status::napi_ok,
"Get Error from Reference failed"
);
let delete_reference_status = sys::napi_delete_reference(env, e.maybe_raw);
debug_assert!(
delete_reference_status == sys::Status::napi_ok,
"Delete Error Reference failed"
);
err
},
);
debug_assert!(status == sys::Status::napi_ok, "Reject promise failed"); debug_assert!(status == sys::Status::napi_ok, "Reject promise failed");
} }
}; };

View file

@ -34,6 +34,7 @@ Generated by [AVA](https://avajs.dev).
export function fibonacci(n: number): number␊ export function fibonacci(n: number): number␊
export function listObjKeys(obj: object): Array<string> export function listObjKeys(obj: object): Array<string>
export function createObj(): object␊ export function createObj(): object␊
export function asyncPlus100(p: Promise<number>): Promise<number>
interface PackageJson {␊ interface PackageJson {␊
name: string␊ name: string␊
version: string␊ version: string␊

View file

@ -40,6 +40,7 @@ import {
createBigIntI64, createBigIntI64,
callThreadsafeFunction, callThreadsafeFunction,
threadsafeFunctionThrowError, threadsafeFunctionThrowError,
asyncPlus100,
} from '../' } from '../'
test('number', (t) => { test('number', (t) => {
@ -238,10 +239,9 @@ BigIntTest('create BigInt i64', (t) => {
t.is(createBigIntI64(), BigInt(100)) t.is(createBigIntI64(), BigInt(100))
}) })
const ThreadsafeFunctionTest = const Napi4Test = Number(process.versions.napi) >= 4 ? test : test.skip
Number(process.versions.napi) >= 4 ? test : test.skip
ThreadsafeFunctionTest('call thread safe function', (t) => { Napi4Test('call thread safe function', (t) => {
let i = 0 let i = 0
let value = 0 let value = 0
return new Promise((resolve) => { return new Promise((resolve) => {
@ -260,10 +260,26 @@ ThreadsafeFunctionTest('call thread safe function', (t) => {
}) })
}) })
ThreadsafeFunctionTest('throw error from thread safe function', async (t) => { Napi4Test('throw error from thread safe function', async (t) => {
const throwPromise = new Promise((_, reject) => { const throwPromise = new Promise((_, reject) => {
threadsafeFunctionThrowError(reject) threadsafeFunctionThrowError(reject)
}) })
const err = await t.throwsAsync(throwPromise) const err = await t.throwsAsync(throwPromise)
t.is(err.message, 'ThrowFromNative') t.is(err.message, 'ThrowFromNative')
}) })
Napi4Test('await Promise in rust', async (t) => {
const fx = 20
const result = await asyncPlus100(
new Promise((resolve) => {
setTimeout(() => resolve(fx), 50)
}),
)
t.is(result, fx + 100)
})
Napi4Test('Promise should reject raw error in rust', async (t) => {
const fxError = new Error('What is Happy Planet')
const err = await t.throwsAsync(() => asyncPlus100(Promise.reject(fxError)))
t.is(err, fxError)
})

View file

@ -24,6 +24,7 @@ export function add(a: number, b: number): number
export function fibonacci(n: number): number export function fibonacci(n: number): number
export function listObjKeys(obj: object): Array<string> export function listObjKeys(obj: object): Array<string>
export function createObj(): object export function createObj(): object
export function asyncPlus100(p: Promise<number>): Promise<number>
interface PackageJson { interface PackageJson {
name: string name: string
version: string version: string

View file

@ -15,6 +15,7 @@ mod error;
mod nullable; mod nullable;
mod number; mod number;
mod object; mod object;
mod promise;
mod serde; mod serde;
mod string; mod string;
mod task; mod task;

View file

@ -0,0 +1,7 @@
use napi::bindgen_prelude::*;
#[napi]
pub async fn async_plus_100(p: Promise<u32>) -> Result<u32> {
let v = p.await?;
Ok(v + 100)
}