From eaa96f7eb2656d8dd95651a22a880f323e3a2ecb Mon Sep 17 00:00:00 2001 From: LongYinan Date: Sat, 13 Nov 2021 20:51:14 +0800 Subject: [PATCH] feat(napi): await Promise in async fn --- crates/backend/src/typegen.rs | 2 + crates/napi/src/bindgen_runtime/js_values.rs | 2 + .../src/bindgen_runtime/js_values/promise.rs | 170 ++++++++++++++++++ crates/napi/src/call_context.rs | 16 +- crates/napi/src/env.rs | 72 ++++---- crates/napi/src/error.rs | 30 +++- crates/napi/src/promise.rs | 41 +++-- examples/napi/__test__/typegen.spec.ts.md | 1 + examples/napi/__test__/typegen.spec.ts.snap | Bin 998 -> 1013 bytes examples/napi/__test__/values.spec.ts | 24 ++- examples/napi/index.d.ts | 1 + examples/napi/src/lib.rs | 1 + examples/napi/src/promise.rs | 7 + 13 files changed, 307 insertions(+), 60 deletions(-) create mode 100644 crates/napi/src/bindgen_runtime/js_values/promise.rs create mode 100644 examples/napi/src/promise.rs diff --git a/crates/backend/src/typegen.rs b/crates/backend/src/typegen.rs index 9a8a5232..9cb61de4 100644 --- a/crates/backend/src/typegen.rs +++ b/crates/backend/src/typegen.rs @@ -155,6 +155,8 @@ pub fn ty_to_ts_type(ty: &Type, is_return_ty: bool) -> String { .with(|c| c.borrow_mut().get(rust_ty.as_str()).cloned()) { ts_ty = Some(t); + } else if rust_ty == "Promise" { + ts_ty = Some(format!("Promise<{}>", args.first().unwrap())); } else { // there should be runtime registered type in else ts_ty = Some(rust_ty); diff --git a/crates/napi/src/bindgen_runtime/js_values.rs b/crates/napi/src/bindgen_runtime/js_values.rs index 30efd91e..ec880d07 100644 --- a/crates/napi/src/bindgen_runtime/js_values.rs +++ b/crates/napi/src/bindgen_runtime/js_values.rs @@ -13,6 +13,7 @@ mod map; mod nil; mod number; mod object; +mod promise; #[cfg(feature = "serde-json")] mod serde; mod string; @@ -27,6 +28,7 @@ pub use either::*; pub use function::*; pub use nil::*; pub use object::*; +pub use promise::*; pub use string::*; pub use task::*; diff --git a/crates/napi/src/bindgen_runtime/js_values/promise.rs b/crates/napi/src/bindgen_runtime/js_values/promise.rs new file mode 100644 index 00000000..c1604b35 --- /dev/null +++ b/crates/napi/src/bindgen_runtime/js_values/promise.rs @@ -0,0 +1,170 @@ +use std::ffi::{c_void, CString}; +use std::future; +use std::pin::Pin; +use std::ptr; +use std::task::{Context, Poll}; + +use tokio::sync::oneshot::{channel, Receiver, Sender}; + +use crate::{check_status, Error, Result, Status}; + +use super::FromNapiValue; + +pub struct Promise { + value: Pin>>>, +} + +unsafe impl Send for Promise {} +unsafe impl Sync for Promise {} + +impl FromNapiValue for Promise { + unsafe fn from_napi_value( + env: napi_sys::napi_env, + napi_val: napi_sys::napi_value, + ) -> crate::Result { + let mut then = ptr::null_mut(); + let then_c_string = CString::new("then")?; + check_status!( + napi_sys::napi_get_named_property(env, napi_val, then_c_string.as_ptr(), &mut then,), + "Failed to get then function" + )?; + let mut promise_after_then = ptr::null_mut(); + let mut then_js_cb = ptr::null_mut(); + let (tx, rx) = channel(); + let tx_ptr = Box::into_raw(Box::new(tx)); + check_status!( + napi_sys::napi_create_function( + env, + then_c_string.as_ptr(), + 4, + Some(then_callback::), + tx_ptr as *mut _, + &mut then_js_cb, + ), + "Failed to create then callback" + )?; + check_status!( + napi_sys::napi_call_function( + env, + napi_val, + then, + 1, + [then_js_cb].as_ptr(), + &mut promise_after_then, + ), + "Failed to call then method" + )?; + let mut catch = ptr::null_mut(); + let catch_c_string = CString::new("catch")?; + check_status!( + napi_sys::napi_get_named_property( + env, + promise_after_then, + catch_c_string.as_ptr(), + &mut catch + ), + "Failed to get then function" + )?; + let mut catch_js_cb = ptr::null_mut(); + check_status!( + napi_sys::napi_create_function( + env, + catch_c_string.as_ptr(), + 5, + Some(catch_callback::), + tx_ptr as *mut c_void, + &mut catch_js_cb + ), + "Failed to create catch callback" + )?; + check_status!( + napi_sys::napi_call_function( + env, + promise_after_then, + catch, + 1, + [catch_js_cb].as_ptr(), + ptr::null_mut() + ), + "Failed to call catch method" + )?; + Ok(Promise { + value: Box::pin(rx), + }) + } +} + +impl future::Future for Promise { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.value.as_mut().poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(v) => Poll::Ready( + v.map_err(|e| Error::new(Status::GenericFailure, format!("{}", e))) + .and_then(|v| unsafe { *Box::from_raw(v) }.map_err(Error::from)), + ), + } + } +} + +unsafe extern "C" fn then_callback( + env: napi_sys::napi_env, + info: napi_sys::napi_callback_info, +) -> napi_sys::napi_value { + let mut data = ptr::null_mut(); + let mut resolved_value: [napi_sys::napi_value; 1] = [ptr::null_mut()]; + let mut this = ptr::null_mut(); + let get_cb_status = napi_sys::napi_get_cb_info( + env, + info, + &mut 1, + resolved_value.as_mut_ptr(), + &mut this, + &mut data, + ); + debug_assert!( + get_cb_status == napi_sys::Status::napi_ok, + "Get callback info from Promise::then failed" + ); + let resolve_value_t = Box::new(T::from_napi_value(env, resolved_value[0])); + let sender = Box::from_raw(data as *mut Sender<*mut Result>); + sender + .send(Box::into_raw(resolve_value_t)) + .expect("Send Promise resolved value error"); + this +} + +unsafe extern "C" fn catch_callback( + env: napi_sys::napi_env, + info: napi_sys::napi_callback_info, +) -> napi_sys::napi_value { + let mut data = ptr::null_mut(); + let mut rejected_value: [napi_sys::napi_value; 1] = [ptr::null_mut()]; + let mut this = ptr::null_mut(); + let mut argc = 1; + let get_cb_status = napi_sys::napi_get_cb_info( + env, + info, + &mut argc, + rejected_value.as_mut_ptr(), + &mut this, + &mut data, + ); + debug_assert!( + get_cb_status == napi_sys::Status::napi_ok, + "Get callback info from Promise::catch failed" + ); + let rejected_value = rejected_value[0]; + let mut error_ref = ptr::null_mut(); + let create_ref_status = napi_sys::napi_create_reference(env, rejected_value, 1, &mut error_ref); + debug_assert!( + create_ref_status == napi_sys::Status::napi_ok, + "Create Error reference failed" + ); + let sender = Box::from_raw(data as *mut Sender<*mut Result>); + sender + .send(Box::into_raw(Box::new(Err(Error::from(error_ref))))) + .expect("Send Promise resolved value error"); + this +} diff --git a/crates/napi/src/call_context.rs b/crates/napi/src/call_context.rs index a2b34bd8..7d846546 100644 --- a/crates/napi/src/call_context.rs +++ b/crates/napi/src/call_context.rs @@ -47,10 +47,10 @@ impl<'env> CallContext<'env> { pub fn get(&self, index: usize) -> Result { if index >= self.arg_len() { - Err(Error { - status: Status::GenericFailure, - reason: "Arguments index out of range".to_owned(), - }) + Err(Error::new( + Status::GenericFailure, + "Arguments index out of range".to_owned(), + )) } else { Ok(unsafe { ArgType::from_raw_unchecked(self.env.0, self.args[index]) }) } @@ -58,10 +58,10 @@ impl<'env> CallContext<'env> { pub fn try_get(&self, index: usize) -> Result> { if index >= self.arg_len() { - Err(Error { - status: Status::GenericFailure, - reason: "Arguments index out of range".to_owned(), - }) + Err(Error::new( + Status::GenericFailure, + "Arguments index out of range".to_owned(), + )) } else if index < self.length { unsafe { ArgType::from_raw(self.env.0, self.args[index]) }.map(Either::A) } else { diff --git a/crates/napi/src/env.rs b/crates/napi/src/env.rs index aa466f71..f76c8764 100644 --- a/crates/napi/src/env.rs +++ b/crates/napi/src/env.rs @@ -756,15 +756,17 @@ impl Env { let type_id = unknown_tagged_object as *const TypeId; if *type_id == TypeId::of::() { let tagged_object = unknown_tagged_object as *mut TaggedObject; - (*tagged_object).object.as_mut().ok_or(Error { - status: Status::InvalidArg, - reason: "Invalid argument, nothing attach to js_object".to_owned(), + (*tagged_object).object.as_mut().ok_or_else(|| { + Error::new( + Status::InvalidArg, + "Invalid argument, nothing attach to js_object".to_owned(), + ) }) } else { - Err(Error { - status: Status::InvalidArg, - reason: "Invalid argument, T on unrwap is not the type of wrapped object".to_owned(), - }) + Err(Error::new( + Status::InvalidArg, + "Invalid argument, T on unrwap is not the type of wrapped object".to_owned(), + )) } } } @@ -781,15 +783,17 @@ impl Env { let type_id = unknown_tagged_object as *const TypeId; if *type_id == TypeId::of::() { let tagged_object = unknown_tagged_object as *mut TaggedObject; - (*tagged_object).object.as_mut().ok_or(Error { - status: Status::InvalidArg, - reason: "Invalid argument, nothing attach to js_object".to_owned(), + (*tagged_object).object.as_mut().ok_or_else(|| { + Error::new( + Status::InvalidArg, + "Invalid argument, nothing attach to js_object".to_owned(), + ) }) } else { - Err(Error { - status: Status::InvalidArg, - reason: "Invalid argument, T on unrwap is not the type of wrapped object".to_owned(), - }) + Err(Error::new( + Status::InvalidArg, + "Invalid argument, T on unrwap is not the type of wrapped object".to_owned(), + )) } } } @@ -807,10 +811,10 @@ impl Env { Box::from_raw(unknown_tagged_object as *mut TaggedObject); Ok(()) } else { - Err(Error { - status: Status::InvalidArg, - reason: "Invalid argument, T on unrwap is not the type of wrapped object".to_owned(), - }) + Err(Error::new( + Status::InvalidArg, + "Invalid argument, T on unrwap is not the type of wrapped object".to_owned(), + )) } } } @@ -905,15 +909,17 @@ impl Env { let type_id = unknown_tagged_object as *const TypeId; if *type_id == TypeId::of::() { let tagged_object = unknown_tagged_object as *mut TaggedObject; - (*tagged_object).object.as_mut().ok_or(Error { - status: Status::InvalidArg, - reason: "nothing attach to js_external".to_owned(), + (*tagged_object).object.as_mut().ok_or_else(|| { + Error::new( + Status::InvalidArg, + "nothing attach to js_external".to_owned(), + ) }) } else { - Err(Error { - status: Status::InvalidArg, - reason: "T on get_value_external is not the type of wrapped object".to_owned(), - }) + Err(Error::new( + Status::InvalidArg, + "T on get_value_external is not the type of wrapped object".to_owned(), + )) } } } @@ -1103,15 +1109,17 @@ impl Env { } if *type_id == TypeId::of::() { let tagged_object = unknown_tagged_object as *mut TaggedObject; - (*tagged_object).object.as_mut().map(Some).ok_or(Error { - status: Status::InvalidArg, - reason: "Invalid argument, nothing attach to js_object".to_owned(), + (*tagged_object).object.as_mut().map(Some).ok_or_else(|| { + Error::new( + Status::InvalidArg, + "Invalid argument, nothing attach to js_object".to_owned(), + ) }) } else { - Err(Error { - status: Status::InvalidArg, - reason: "Invalid argument, T on unrwap is not the type of wrapped object".to_owned(), - }) + Err(Error::new( + Status::InvalidArg, + "Invalid argument, T on unrwap is not the type of wrapped object".to_owned(), + )) } } } diff --git a/crates/napi/src/error.rs b/crates/napi/src/error.rs index ec45ba2d..714b2465 100644 --- a/crates/napi/src/error.rs +++ b/crates/napi/src/error.rs @@ -23,8 +23,14 @@ pub type Result = std::result::Result; pub struct Error { pub status: Status, pub reason: String, + // Convert raw `JsError` into Error + // Only be used in `async fn(p: Promise)` scenario + pub(crate) maybe_raw: sys::napi_ref, } +unsafe impl Send for Error {} +unsafe impl Sync for Error {} + impl error::Error for Error {} #[cfg(feature = "serde-json")] @@ -48,6 +54,16 @@ impl From for Error { } } +impl From for Error { + fn from(value: sys::napi_ref) -> Self { + Self { + status: Status::InvalidArg, + reason: "".to_string(), + maybe_raw: value, + } + } +} + impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { if !self.reason.is_empty() { @@ -60,13 +76,18 @@ impl fmt::Display for Error { impl Error { pub fn new(status: Status, reason: String) -> Self { - Error { status, reason } + Error { + status, + reason, + maybe_raw: ptr::null_mut(), + } } pub fn from_status(status: Status) -> Self { Error { status, reason: "".to_owned(), + maybe_raw: ptr::null_mut(), } } @@ -74,6 +95,7 @@ impl Error { Error { status: Status::GenericFailure, reason, + maybe_raw: ptr::null_mut(), } } } @@ -83,6 +105,7 @@ impl From for Error { Error { status: Status::GenericFailure, reason: format!("{}", error), + maybe_raw: ptr::null_mut(), } } } @@ -92,6 +115,7 @@ impl From for Error { Error { status: Status::GenericFailure, reason: format!("{}", error), + maybe_raw: ptr::null_mut(), } } } @@ -163,8 +187,10 @@ macro_rules! impl_object_methods { pub unsafe fn throw_into(self, env: sys::napi_env) { #[cfg(debug_assertions)] let reason = self.0.reason.clone(); - #[cfg(debug_assertions)] let status = self.0.status; + if status == Status::PendingException { + return; + } let js_error = self.into_value(env); #[cfg(debug_assertions)] let throw_status = sys::napi_throw(env, js_error); diff --git a/crates/napi/src/promise.rs b/crates/napi/src/promise.rs index 2d09d60e..661ed1ce 100644 --- a/crates/napi/src/promise.rs +++ b/crates/napi/src/promise.rs @@ -1,6 +1,7 @@ +use std::ffi::CString; use std::future::Future; use std::marker::PhantomData; -use std::os::raw::{c_char, c_void}; +use std::os::raw::c_void; use std::ptr; use crate::{check_status, sys, JsError, Result}; @@ -12,7 +13,6 @@ pub struct FuturePromise Result, - _value: PhantomData, } unsafe impl Result> Send @@ -23,26 +23,20 @@ unsafe impl Result> Send impl Result> FuturePromise { - pub fn new(env: sys::napi_env, dererred: sys::napi_deferred, resolver: Resolver) -> Result { + pub fn new(env: sys::napi_env, deferred: sys::napi_deferred, resolver: Resolver) -> Result { let mut async_resource_name = ptr::null_mut(); - let s = "napi_resolve_promise_from_future"; + let s = CString::new("napi_resolve_promise_from_future")?; check_status!(unsafe { - sys::napi_create_string_utf8( - env, - s.as_ptr() as *const c_char, - s.len(), - &mut async_resource_name, - ) + sys::napi_create_string_utf8(env, s.as_ptr(), 32, &mut async_resource_name) })?; Ok(FuturePromise { - deferred: dererred, + deferred, resolver, env, tsfn: ptr::null_mut(), async_resource_name, _data: PhantomData, - _value: PhantomData, }) } @@ -83,7 +77,7 @@ pub(crate) async fn resolve_from_future { - let status = sys::napi_reject_deferred(env, deferred, JsError::from(e).into_value(env)); + let status = sys::napi_reject_deferred( + env, + deferred, + if e.maybe_raw.is_null() { + JsError::from(e).into_value(env) + } else { + let mut err = ptr::null_mut(); + let get_err_status = sys::napi_get_reference_value(env, e.maybe_raw, &mut err); + debug_assert!( + get_err_status == sys::Status::napi_ok, + "Get Error from Reference failed" + ); + let delete_reference_status = sys::napi_delete_reference(env, e.maybe_raw); + debug_assert!( + delete_reference_status == sys::Status::napi_ok, + "Delete Error Reference failed" + ); + err + }, + ); debug_assert!(status == sys::Status::napi_ok, "Reject promise failed"); } }; diff --git a/examples/napi/__test__/typegen.spec.ts.md b/examples/napi/__test__/typegen.spec.ts.md index 3142411b..f8b7e4d6 100644 --- a/examples/napi/__test__/typegen.spec.ts.md +++ b/examples/napi/__test__/typegen.spec.ts.md @@ -34,6 +34,7 @@ Generated by [AVA](https://avajs.dev). export function fibonacci(n: number): number␊ export function listObjKeys(obj: object): Array␊ export function createObj(): object␊ + export function asyncPlus100(p: Promise): Promise␊ interface PackageJson {␊ name: string␊ version: string␊ diff --git a/examples/napi/__test__/typegen.spec.ts.snap b/examples/napi/__test__/typegen.spec.ts.snap index 3fde6368903c6d9afcf9b3799747e5aac4a93387..31e68ac416b8426170c562dc603077c7db5bb4a1 100644 GIT binary patch literal 1013 zcmVRzV`M7}ANLuQi>$G*%0lo^C=8 zio6DtMILkA1ezlEKs+~VSkx`0hLYwqxe^)go0()}LhUGm(i7CV)hZW;zdS33!%Y^c zQ~I*Vthk>lLUmf1sp+Di7c{g zs{tYw|TWERdsVe1#&KgUL2S+VNQHYhwIF(EDc(Ys& zDdTVxzk{cJfQP1;=N3c9TMeOvyp3o%F5!)k2|U4tO5qiBLpY}vg2gMl^G;u8$50k5-Mh-<}<5HTC60+Nva{lbeZ?$OFxyWyXAot*x7VTy|jz@?E&C zo`xb;k}@WU+z!r!v3Or^_{kiA{-Qye9#u>BZ!&oMHsvJ6LGOON0+|}pRz|Jx{=7(pmslk`Y z$G7GrpOG=2MRQ<`>rS+))k2BLl5g$Ul^B1tm92W>))I6vr7tEz(?3M?Tv$0&B(wjm$<33N-OAh1h=i7YhIYd#LnE literal 998 zcmVe)E0gaApWW5Nric-hT1t$G?Ai_uH3W{qxs{ z_rGri?)S%b5IhM!^Y+%K!RMd+x{Y^Zr7iRe$*oWl254U?otm%{Lsx4yeQK;0a&Wu~ z*(>rIP!@TDx(PHz?t!#!mayboS`87~Da2=TU%X$8f7eN(An&HGoZmbFGzzH&DnF`$A$G0WIuR%(Nnl zEL$sp!VXbid!H-a^WhAI92a(hX%_a@Yqo`!=box^p8G7ZwCv!ZMU;$0sSKHXl&70T z-De!(DtQg}+W@ysnTHmk!?i+W5pUzPNJDrbWD56irUvi~b|W}tmW;g!&I&#vYabpK z@TY~b$`#JJ8{NZ|#1c;_?cih#HM+Tj2rg8i$vTYSTHKLw6v4R|44uN)5%h%j{f1mn zI+9&jaYmB4QI`k%A;yd1?YC!#T1|X9g>6}qv+>o~S>Y|q;>U>j6k8fM>$t4K66C9J zT|V_iq9o&7gmOJN6UNegy}+pnRq{Fp`XIM!J6~DDcjK#g73Uoo3T>TG9gW9C`I`^5D{4FIcRH5eO5!O7&x=kXF^j$x62XuN88^! z+8q8JvqJ7|f}c<$P=)O#G=x~ImmZ9)}+5d-PU$D^F`CXG> zsbPJ*m`&Fogq@uox&}HpCC<;k_?AyrTP@oL*Ue)C;K~2yqx&ykaQ_N=tHSV%83SD@ za+dija>`0(Qy5VnjeR$zS-L;Ese)VDz~e%+P;(u^8B15peZU={hIPg6RR~4<@t$F) zR-LArT;`k%70bmUwv-_rYV%&GGu!IJ;=xWbHp2ar=giQ%;?K9VNt}iwoW?U?LgP-n z=+#0=$dYeq*`*lX*-EP(xTOG{G0Vb~KDBa~c@r>4s(`e1>SZ1GeMK}E%#lx5wZNKK UuMyd}K|vD#0*y4AjMWMN06#$Q`2YX_ diff --git a/examples/napi/__test__/values.spec.ts b/examples/napi/__test__/values.spec.ts index f5f40870..c14d46b0 100644 --- a/examples/napi/__test__/values.spec.ts +++ b/examples/napi/__test__/values.spec.ts @@ -40,6 +40,7 @@ import { createBigIntI64, callThreadsafeFunction, threadsafeFunctionThrowError, + asyncPlus100, } from '../' test('number', (t) => { @@ -238,10 +239,9 @@ BigIntTest('create BigInt i64', (t) => { t.is(createBigIntI64(), BigInt(100)) }) -const ThreadsafeFunctionTest = - Number(process.versions.napi) >= 4 ? test : test.skip +const Napi4Test = Number(process.versions.napi) >= 4 ? test : test.skip -ThreadsafeFunctionTest('call thread safe function', (t) => { +Napi4Test('call thread safe function', (t) => { let i = 0 let value = 0 return new Promise((resolve) => { @@ -260,10 +260,26 @@ ThreadsafeFunctionTest('call thread safe function', (t) => { }) }) -ThreadsafeFunctionTest('throw error from thread safe function', async (t) => { +Napi4Test('throw error from thread safe function', async (t) => { const throwPromise = new Promise((_, reject) => { threadsafeFunctionThrowError(reject) }) const err = await t.throwsAsync(throwPromise) t.is(err.message, 'ThrowFromNative') }) + +Napi4Test('await Promise in rust', async (t) => { + const fx = 20 + const result = await asyncPlus100( + new Promise((resolve) => { + setTimeout(() => resolve(fx), 50) + }), + ) + t.is(result, fx + 100) +}) + +Napi4Test('Promise should reject raw error in rust', async (t) => { + const fxError = new Error('What is Happy Planet') + const err = await t.throwsAsync(() => asyncPlus100(Promise.reject(fxError))) + t.is(err, fxError) +}) diff --git a/examples/napi/index.d.ts b/examples/napi/index.d.ts index 5dd08dfa..c1fda3e2 100644 --- a/examples/napi/index.d.ts +++ b/examples/napi/index.d.ts @@ -24,6 +24,7 @@ export function add(a: number, b: number): number export function fibonacci(n: number): number export function listObjKeys(obj: object): Array export function createObj(): object +export function asyncPlus100(p: Promise): Promise interface PackageJson { name: string version: string diff --git a/examples/napi/src/lib.rs b/examples/napi/src/lib.rs index 15f0d80d..8c75aa37 100644 --- a/examples/napi/src/lib.rs +++ b/examples/napi/src/lib.rs @@ -15,6 +15,7 @@ mod error; mod nullable; mod number; mod object; +mod promise; mod serde; mod string; mod task; diff --git a/examples/napi/src/promise.rs b/examples/napi/src/promise.rs new file mode 100644 index 00000000..15ce4350 --- /dev/null +++ b/examples/napi/src/promise.rs @@ -0,0 +1,7 @@ +use napi::bindgen_prelude::*; + +#[napi] +pub async fn async_plus_100(p: Promise) -> Result { + let v = p.await?; + Ok(v + 100) +}