Merge pull request #849 from napi-rs/promise

feat(napi): await Promise<T> in async fn
This commit is contained in:
LongYinan 2021-11-15 17:05:56 +08:00 committed by GitHub
commit 1ef7de4dd6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 357 additions and 84 deletions

View file

@ -109,7 +109,7 @@ impl NapiEnum {
define_properties.push(quote! {
{
let name = CString::new(#name_lit).unwrap();
let name = CString::new(#name_lit)?;
napi::bindgen_prelude::check_status!(
napi::bindgen_prelude::sys::napi_set_named_property(env, obj_ptr, name.as_ptr(), i32::to_napi_value(env, #val_lit)?),
"Failed to defined enum `{}`",

View file

@ -155,6 +155,8 @@ pub fn ty_to_ts_type(ty: &Type, is_return_ty: bool) -> String {
.with(|c| c.borrow_mut().get(rust_ty.as_str()).cloned())
{
ts_ty = Some(t);
} else if rust_ty == "Promise" {
ts_ty = Some(format!("Promise<{}>", args.first().unwrap()));
} else {
// there should be runtime registered type in else
ts_ty = Some(rust_ty);

View file

@ -59,11 +59,12 @@ pub fn run<T: Task>(
napi_async_work: ptr::null_mut(),
status: task_status.clone(),
}));
let async_work_name = CString::new("napi_rs_async_work")?;
check_status!(unsafe {
sys::napi_create_async_work(
env,
raw_resource,
CString::new("napi_rs_async_work")?.as_ptr() as *mut _,
async_work_name.as_ptr() as *mut _,
Some(execute::<T> as unsafe extern "C" fn(env: sys::napi_env, data: *mut c_void)),
Some(
complete::<T>

View file

@ -13,6 +13,7 @@ mod map;
mod nil;
mod number;
mod object;
mod promise;
#[cfg(feature = "serde-json")]
mod serde;
mod string;
@ -27,6 +28,7 @@ pub use either::*;
pub use function::*;
pub use nil::*;
pub use object::*;
pub use promise::*;
pub use string::*;
pub use task::*;

View file

@ -0,0 +1,170 @@
use std::ffi::{c_void, CString};
use std::future;
use std::pin::Pin;
use std::ptr;
use std::task::{Context, Poll};
use tokio::sync::oneshot::{channel, Receiver, Sender};
use crate::{check_status, Error, Result, Status};
use super::FromNapiValue;
pub struct Promise<T: FromNapiValue> {
value: Pin<Box<Receiver<*mut Result<T>>>>,
}
unsafe impl<T: FromNapiValue> Send for Promise<T> {}
unsafe impl<T: FromNapiValue> Sync for Promise<T> {}
impl<T: FromNapiValue> FromNapiValue for Promise<T> {
unsafe fn from_napi_value(
env: napi_sys::napi_env,
napi_val: napi_sys::napi_value,
) -> crate::Result<Self> {
let mut then = ptr::null_mut();
let then_c_string = CString::new("then")?;
check_status!(
napi_sys::napi_get_named_property(env, napi_val, then_c_string.as_ptr(), &mut then,),
"Failed to get then function"
)?;
let mut promise_after_then = ptr::null_mut();
let mut then_js_cb = ptr::null_mut();
let (tx, rx) = channel();
let tx_ptr = Box::into_raw(Box::new(tx));
check_status!(
napi_sys::napi_create_function(
env,
then_c_string.as_ptr(),
4,
Some(then_callback::<T>),
tx_ptr as *mut _,
&mut then_js_cb,
),
"Failed to create then callback"
)?;
check_status!(
napi_sys::napi_call_function(
env,
napi_val,
then,
1,
[then_js_cb].as_ptr(),
&mut promise_after_then,
),
"Failed to call then method"
)?;
let mut catch = ptr::null_mut();
let catch_c_string = CString::new("catch")?;
check_status!(
napi_sys::napi_get_named_property(
env,
promise_after_then,
catch_c_string.as_ptr(),
&mut catch
),
"Failed to get then function"
)?;
let mut catch_js_cb = ptr::null_mut();
check_status!(
napi_sys::napi_create_function(
env,
catch_c_string.as_ptr(),
5,
Some(catch_callback::<T>),
tx_ptr as *mut c_void,
&mut catch_js_cb
),
"Failed to create catch callback"
)?;
check_status!(
napi_sys::napi_call_function(
env,
promise_after_then,
catch,
1,
[catch_js_cb].as_ptr(),
ptr::null_mut()
),
"Failed to call catch method"
)?;
Ok(Promise {
value: Box::pin(rx),
})
}
}
impl<T: FromNapiValue> future::Future for Promise<T> {
type Output = Result<T>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.value.as_mut().poll(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(v) => Poll::Ready(
v.map_err(|e| Error::new(Status::GenericFailure, format!("{}", e)))
.and_then(|v| unsafe { *Box::from_raw(v) }.map_err(Error::from)),
),
}
}
}
unsafe extern "C" fn then_callback<T: FromNapiValue>(
env: napi_sys::napi_env,
info: napi_sys::napi_callback_info,
) -> napi_sys::napi_value {
let mut data = ptr::null_mut();
let mut resolved_value: [napi_sys::napi_value; 1] = [ptr::null_mut()];
let mut this = ptr::null_mut();
let get_cb_status = napi_sys::napi_get_cb_info(
env,
info,
&mut 1,
resolved_value.as_mut_ptr(),
&mut this,
&mut data,
);
debug_assert!(
get_cb_status == napi_sys::Status::napi_ok,
"Get callback info from Promise::then failed"
);
let resolve_value_t = Box::new(T::from_napi_value(env, resolved_value[0]));
let sender = Box::from_raw(data as *mut Sender<*mut Result<T>>);
sender
.send(Box::into_raw(resolve_value_t))
.expect("Send Promise resolved value error");
this
}
unsafe extern "C" fn catch_callback<T: FromNapiValue>(
env: napi_sys::napi_env,
info: napi_sys::napi_callback_info,
) -> napi_sys::napi_value {
let mut data = ptr::null_mut();
let mut rejected_value: [napi_sys::napi_value; 1] = [ptr::null_mut()];
let mut this = ptr::null_mut();
let mut argc = 1;
let get_cb_status = napi_sys::napi_get_cb_info(
env,
info,
&mut argc,
rejected_value.as_mut_ptr(),
&mut this,
&mut data,
);
debug_assert!(
get_cb_status == napi_sys::Status::napi_ok,
"Get callback info from Promise::catch failed"
);
let rejected_value = rejected_value[0];
let mut error_ref = ptr::null_mut();
let create_ref_status = napi_sys::napi_create_reference(env, rejected_value, 1, &mut error_ref);
debug_assert!(
create_ref_status == napi_sys::Status::napi_ok,
"Create Error reference failed"
);
let sender = Box::from_raw(data as *mut Sender<*mut Result<T>>);
sender
.send(Box::into_raw(Box::new(Err(Error::from(error_ref)))))
.expect("Send Promise resolved value error");
this
}

View file

@ -47,10 +47,10 @@ impl<'env> CallContext<'env> {
pub fn get<ArgType: NapiValue>(&self, index: usize) -> Result<ArgType> {
if index >= self.arg_len() {
Err(Error {
status: Status::GenericFailure,
reason: "Arguments index out of range".to_owned(),
})
Err(Error::new(
Status::GenericFailure,
"Arguments index out of range".to_owned(),
))
} else {
Ok(unsafe { ArgType::from_raw_unchecked(self.env.0, self.args[index]) })
}
@ -58,10 +58,10 @@ impl<'env> CallContext<'env> {
pub fn try_get<ArgType: NapiValue>(&self, index: usize) -> Result<Either<ArgType, JsUndefined>> {
if index >= self.arg_len() {
Err(Error {
status: Status::GenericFailure,
reason: "Arguments index out of range".to_owned(),
})
Err(Error::new(
Status::GenericFailure,
"Arguments index out of range".to_owned(),
))
} else if index < self.length {
unsafe { ArgType::from_raw(self.env.0, self.args[index]) }.map(Either::A)
} else {

View file

@ -629,42 +629,39 @@ impl Env {
/// This API throws a JavaScript Error with the text provided.
pub fn throw_error(&self, msg: &str, code: Option<&str>) -> Result<()> {
let code = code.and_then(|s| CString::new(s).ok());
let msg = CString::new(msg)?;
check_status!(unsafe {
sys::napi_throw_error(
self.0,
match code {
Some(s) => CString::new(s)?.as_ptr(),
None => ptr::null_mut(),
},
CString::new(msg)?.as_ptr(),
code.map(|s| s.as_ptr()).unwrap_or(ptr::null_mut()),
msg.as_ptr(),
)
})
}
/// This API throws a JavaScript RangeError with the text provided.
pub fn throw_range_error(&self, msg: &str, code: Option<&str>) -> Result<()> {
let code = code.and_then(|s| CString::new(s).ok());
let msg = CString::new(msg)?;
check_status!(unsafe {
sys::napi_throw_range_error(
self.0,
match code {
Some(s) => CString::new(s)?.as_ptr(),
None => ptr::null_mut(),
},
CString::new(msg)?.as_ptr(),
code.map(|s| s.as_ptr()).unwrap_or(ptr::null_mut()),
msg.as_ptr(),
)
})
}
/// This API throws a JavaScript TypeError with the text provided.
pub fn throw_type_error(&self, msg: &str, code: Option<&str>) -> Result<()> {
let code = code.and_then(|s| CString::new(s).ok());
let msg = CString::new(msg)?;
check_status!(unsafe {
sys::napi_throw_type_error(
self.0,
match code {
Some(s) => CString::new(s)?.as_ptr(),
None => ptr::null_mut(),
},
CString::new(msg)?.as_ptr(),
code.map(|s| s.as_ptr()).unwrap_or(ptr::null_mut()),
msg.as_ptr(),
)
})
}
@ -714,11 +711,11 @@ impl Env {
.iter()
.map(|prop| prop.raw())
.collect::<Vec<sys::napi_property_descriptor>>();
let c_name = CString::new(name)?;
check_status!(unsafe {
sys::napi_define_class(
self.0,
name.as_ptr() as *const c_char,
c_name.as_ptr() as *const c_char,
name.len(),
Some(constructor_cb),
ptr::null_mut(),
@ -756,15 +753,17 @@ impl Env {
let type_id = unknown_tagged_object as *const TypeId;
if *type_id == TypeId::of::<T>() {
let tagged_object = unknown_tagged_object as *mut TaggedObject<T>;
(*tagged_object).object.as_mut().ok_or(Error {
status: Status::InvalidArg,
reason: "Invalid argument, nothing attach to js_object".to_owned(),
(*tagged_object).object.as_mut().ok_or_else(|| {
Error::new(
Status::InvalidArg,
"Invalid argument, nothing attach to js_object".to_owned(),
)
})
} else {
Err(Error {
status: Status::InvalidArg,
reason: "Invalid argument, T on unrwap is not the type of wrapped object".to_owned(),
})
Err(Error::new(
Status::InvalidArg,
"Invalid argument, T on unrwap is not the type of wrapped object".to_owned(),
))
}
}
}
@ -781,15 +780,17 @@ impl Env {
let type_id = unknown_tagged_object as *const TypeId;
if *type_id == TypeId::of::<T>() {
let tagged_object = unknown_tagged_object as *mut TaggedObject<T>;
(*tagged_object).object.as_mut().ok_or(Error {
status: Status::InvalidArg,
reason: "Invalid argument, nothing attach to js_object".to_owned(),
(*tagged_object).object.as_mut().ok_or_else(|| {
Error::new(
Status::InvalidArg,
"Invalid argument, nothing attach to js_object".to_owned(),
)
})
} else {
Err(Error {
status: Status::InvalidArg,
reason: "Invalid argument, T on unrwap is not the type of wrapped object".to_owned(),
})
Err(Error::new(
Status::InvalidArg,
"Invalid argument, T on unrwap is not the type of wrapped object".to_owned(),
))
}
}
}
@ -807,10 +808,10 @@ impl Env {
Box::from_raw(unknown_tagged_object as *mut TaggedObject<T>);
Ok(())
} else {
Err(Error {
status: Status::InvalidArg,
reason: "Invalid argument, T on unrwap is not the type of wrapped object".to_owned(),
})
Err(Error::new(
Status::InvalidArg,
"Invalid argument, T on unrwap is not the type of wrapped object".to_owned(),
))
}
}
}
@ -905,15 +906,17 @@ impl Env {
let type_id = unknown_tagged_object as *const TypeId;
if *type_id == TypeId::of::<T>() {
let tagged_object = unknown_tagged_object as *mut TaggedObject<T>;
(*tagged_object).object.as_mut().ok_or(Error {
status: Status::InvalidArg,
reason: "nothing attach to js_external".to_owned(),
(*tagged_object).object.as_mut().ok_or_else(|| {
Error::new(
Status::InvalidArg,
"nothing attach to js_external".to_owned(),
)
})
} else {
Err(Error {
status: Status::InvalidArg,
reason: "T on get_value_external is not the type of wrapped object".to_owned(),
})
Err(Error::new(
Status::InvalidArg,
"T on get_value_external is not the type of wrapped object".to_owned(),
))
}
}
}
@ -1103,15 +1106,17 @@ impl Env {
}
if *type_id == TypeId::of::<T>() {
let tagged_object = unknown_tagged_object as *mut TaggedObject<T>;
(*tagged_object).object.as_mut().map(Some).ok_or(Error {
status: Status::InvalidArg,
reason: "Invalid argument, nothing attach to js_object".to_owned(),
(*tagged_object).object.as_mut().map(Some).ok_or_else(|| {
Error::new(
Status::InvalidArg,
"Invalid argument, nothing attach to js_object".to_owned(),
)
})
} else {
Err(Error {
status: Status::InvalidArg,
reason: "Invalid argument, T on unrwap is not the type of wrapped object".to_owned(),
})
Err(Error::new(
Status::InvalidArg,
"Invalid argument, T on unrwap is not the type of wrapped object".to_owned(),
))
}
}
}

View file

@ -23,8 +23,14 @@ pub type Result<T> = std::result::Result<T, Error>;
pub struct Error {
pub status: Status,
pub reason: String,
// Convert raw `JsError` into Error
// Only be used in `async fn(p: Promise<T>)` scenario
pub(crate) maybe_raw: sys::napi_ref,
}
unsafe impl Send for Error {}
unsafe impl Sync for Error {}
impl error::Error for Error {}
#[cfg(feature = "serde-json")]
@ -48,6 +54,16 @@ impl From<SerdeJSONError> for Error {
}
}
impl From<sys::napi_ref> for Error {
fn from(value: sys::napi_ref) -> Self {
Self {
status: Status::InvalidArg,
reason: "".to_string(),
maybe_raw: value,
}
}
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if !self.reason.is_empty() {
@ -60,13 +76,18 @@ impl fmt::Display for Error {
impl Error {
pub fn new(status: Status, reason: String) -> Self {
Error { status, reason }
Error {
status,
reason,
maybe_raw: ptr::null_mut(),
}
}
pub fn from_status(status: Status) -> Self {
Error {
status,
reason: "".to_owned(),
maybe_raw: ptr::null_mut(),
}
}
@ -74,6 +95,7 @@ impl Error {
Error {
status: Status::GenericFailure,
reason,
maybe_raw: ptr::null_mut(),
}
}
}
@ -83,6 +105,7 @@ impl From<std::ffi::NulError> for Error {
Error {
status: Status::GenericFailure,
reason: format!("{}", error),
maybe_raw: ptr::null_mut(),
}
}
}
@ -92,6 +115,7 @@ impl From<std::io::Error> for Error {
Error {
status: Status::GenericFailure,
reason: format!("{}", error),
maybe_raw: ptr::null_mut(),
}
}
}
@ -163,8 +187,10 @@ macro_rules! impl_object_methods {
pub unsafe fn throw_into(self, env: sys::napi_env) {
#[cfg(debug_assertions)]
let reason = self.0.reason.clone();
#[cfg(debug_assertions)]
let status = self.0.status;
if status == Status::PendingException {
return;
}
let js_error = self.into_value(env);
#[cfg(debug_assertions)]
let throw_status = sys::napi_throw(env, js_error);

View file

@ -306,7 +306,7 @@ macro_rules! impl_object_methods {
pub fn create_named_method(&mut self, name: &str, function: Callback) -> Result<()> {
let mut js_function = ptr::null_mut();
let len = name.len();
let name = CString::new(name.as_bytes())?;
let name = CString::new(name)?;
check_status!(unsafe {
sys::napi_create_function(
self.0.env,

View file

@ -155,6 +155,35 @@ macro_rules! assert_type_of {
};
}
#[allow(dead_code)]
#[cfg(debug_assertions)]
pub(crate) unsafe fn log_js_value<V: AsRef<[sys::napi_value]>>(
// `info`, `log`, `warning` or `error`
method: &str,
env: sys::napi_env,
values: V,
) {
use std::ffi::CString;
use std::ptr;
let mut g = ptr::null_mut();
sys::napi_get_global(env, &mut g);
let mut console = ptr::null_mut();
let console_c_string = CString::new("console").unwrap();
let method_c_string = CString::new(method).unwrap();
sys::napi_get_named_property(env, g, console_c_string.as_ptr(), &mut console);
let mut method_js_fn = ptr::null_mut();
sys::napi_get_named_property(env, console, method_c_string.as_ptr(), &mut method_js_fn);
sys::napi_call_function(
env,
console,
method_js_fn,
values.as_ref().len(),
values.as_ref().as_ptr(),
ptr::null_mut(),
);
}
pub mod bindgen_prelude {
#[cfg(feature = "compat-mode")]
pub use crate::bindgen_runtime::register_module_exports;

View file

@ -1,6 +1,7 @@
use std::ffi::CString;
use std::future::Future;
use std::marker::PhantomData;
use std::os::raw::{c_char, c_void};
use std::os::raw::c_void;
use std::ptr;
use crate::{check_status, sys, JsError, Result};
@ -12,7 +13,6 @@ pub struct FuturePromise<Data, Resolver: FnOnce(sys::napi_env, Data) -> Result<s
async_resource_name: sys::napi_value,
resolver: Resolver,
_data: PhantomData<Data>,
_value: PhantomData<sys::napi_value>,
}
unsafe impl<T, F: FnOnce(sys::napi_env, T) -> Result<sys::napi_value>> Send
@ -23,26 +23,20 @@ unsafe impl<T, F: FnOnce(sys::napi_env, T) -> Result<sys::napi_value>> Send
impl<Data, Resolver: FnOnce(sys::napi_env, Data) -> Result<sys::napi_value>>
FuturePromise<Data, Resolver>
{
pub fn new(env: sys::napi_env, dererred: sys::napi_deferred, resolver: Resolver) -> Result<Self> {
pub fn new(env: sys::napi_env, deferred: sys::napi_deferred, resolver: Resolver) -> Result<Self> {
let mut async_resource_name = ptr::null_mut();
let s = "napi_resolve_promise_from_future";
let s = CString::new("napi_resolve_promise_from_future")?;
check_status!(unsafe {
sys::napi_create_string_utf8(
env,
s.as_ptr() as *const c_char,
s.len(),
&mut async_resource_name,
)
sys::napi_create_string_utf8(env, s.as_ptr(), 32, &mut async_resource_name)
})?;
Ok(FuturePromise {
deferred: dererred,
deferred,
resolver,
env,
tsfn: ptr::null_mut(),
async_resource_name,
_data: PhantomData,
_value: PhantomData,
})
}
@ -83,7 +77,7 @@ pub(crate) async fn resolve_from_future<Data: Send, Fut: Future<Output = Result<
check_status!(unsafe {
sys::napi_call_threadsafe_function(
tsfn_value.0,
Box::into_raw(Box::from(val)) as *mut Data as *mut c_void,
Box::into_raw(Box::from(val)) as *mut c_void,
sys::napi_threadsafe_function_call_mode::napi_tsfn_nonblocking,
)
})
@ -117,7 +111,26 @@ unsafe extern "C" fn call_js_cb<
debug_assert!(status == sys::Status::napi_ok, "Resolve promise failed");
}
Err(e) => {
let status = sys::napi_reject_deferred(env, deferred, JsError::from(e).into_value(env));
let status = sys::napi_reject_deferred(
env,
deferred,
if e.maybe_raw.is_null() {
JsError::from(e).into_value(env)
} else {
let mut err = ptr::null_mut();
let get_err_status = sys::napi_get_reference_value(env, e.maybe_raw, &mut err);
debug_assert!(
get_err_status == sys::Status::napi_ok,
"Get Error from Reference failed"
);
let delete_reference_status = sys::napi_delete_reference(env, e.maybe_raw);
debug_assert!(
delete_reference_status == sys::Status::napi_ok,
"Delete Error Reference failed"
);
err
},
);
debug_assert!(status == sys::Status::napi_ok, "Reject promise failed");
}
};

View file

@ -19,10 +19,9 @@ impl TryFrom<sys::napi_node_version> for NodeVersion {
minor: value.minor,
patch: value.patch,
release: unsafe {
CStr::from_ptr(value.release).to_str().map_err(|_| Error {
status: Status::StringExpected,
reason: "Invalid release name".to_owned(),
})?
CStr::from_ptr(value.release)
.to_str()
.map_err(|_| Error::new(Status::StringExpected, "Invalid release name".to_owned()))?
},
})
}

View file

@ -34,6 +34,7 @@ Generated by [AVA](https://avajs.dev).
export function fibonacci(n: number): number␊
export function listObjKeys(obj: object): Array<string>
export function createObj(): object␊
export function asyncPlus100(p: Promise<number>): Promise<number>
interface PackageJson {␊
name: string␊
version: string␊

View file

@ -40,6 +40,7 @@ import {
createBigIntI64,
callThreadsafeFunction,
threadsafeFunctionThrowError,
asyncPlus100,
} from '../'
test('number', (t) => {
@ -238,10 +239,9 @@ BigIntTest('create BigInt i64', (t) => {
t.is(createBigIntI64(), BigInt(100))
})
const ThreadsafeFunctionTest =
Number(process.versions.napi) >= 4 ? test : test.skip
const Napi4Test = Number(process.versions.napi) >= 4 ? test : test.skip
ThreadsafeFunctionTest('call thread safe function', (t) => {
Napi4Test('call thread safe function', (t) => {
let i = 0
let value = 0
return new Promise((resolve) => {
@ -260,10 +260,26 @@ ThreadsafeFunctionTest('call thread safe function', (t) => {
})
})
ThreadsafeFunctionTest('throw error from thread safe function', async (t) => {
Napi4Test('throw error from thread safe function', async (t) => {
const throwPromise = new Promise((_, reject) => {
threadsafeFunctionThrowError(reject)
})
const err = await t.throwsAsync(throwPromise)
t.is(err.message, 'ThrowFromNative')
})
Napi4Test('await Promise in rust', async (t) => {
const fx = 20
const result = await asyncPlus100(
new Promise((resolve) => {
setTimeout(() => resolve(fx), 50)
}),
)
t.is(result, fx + 100)
})
Napi4Test('Promise should reject raw error in rust', async (t) => {
const fxError = new Error('What is Happy Planet')
const err = await t.throwsAsync(() => asyncPlus100(Promise.reject(fxError)))
t.is(err, fxError)
})

View file

@ -24,6 +24,7 @@ export function add(a: number, b: number): number
export function fibonacci(n: number): number
export function listObjKeys(obj: object): Array<string>
export function createObj(): object
export function asyncPlus100(p: Promise<number>): Promise<number>
interface PackageJson {
name: string
version: string

View file

@ -15,6 +15,7 @@ mod error;
mod nullable;
mod number;
mod object;
mod promise;
mod serde;
mod string;
mod task;

View file

@ -0,0 +1,7 @@
use napi::bindgen_prelude::*;
#[napi]
pub async fn async_plus_100(p: Promise<u32>) -> Result<u32> {
let v = p.await?;
Ok(v + 100)
}