diff --git a/crates/napi/src/env.rs b/crates/napi/src/env.rs index 0a54e929..6b957871 100644 --- a/crates/napi/src/env.rs +++ b/crates/napi/src/env.rs @@ -1,10 +1,14 @@ use std::any::TypeId; use std::convert::TryInto; use std::ffi::CString; +#[cfg(all(feature = "tokio_rt", feature = "napi4"))] +use std::future::Future; use std::mem; use std::os::raw::{c_char, c_void}; use std::ptr; +#[cfg(all(feature = "tokio_rt", feature = "napi4"))] +use crate::bindgen_runtime::ToNapiValue; use crate::{ async_work::{self, AsyncWorkPromise}, check_status, @@ -28,8 +32,6 @@ use crate::JsError; use serde::de::DeserializeOwned; #[cfg(all(feature = "serde-json"))] use serde::Serialize; -#[cfg(all(feature = "tokio_rt", feature = "napi4"))] -use std::future::Future; pub type Callback = unsafe extern "C" fn(sys::napi_env, sys::napi_callback_info) -> sys::napi_value; @@ -1038,7 +1040,7 @@ impl Env { #[cfg(feature = "napi4")] pub fn create_threadsafe_function< T: Send, - V: NapiRaw, + V: ToNapiValue, R: 'static + Send + FnMut(ThreadSafeCallContext) -> Result>, >( &self, @@ -1052,7 +1054,7 @@ impl Env { #[cfg(all(feature = "tokio_rt", feature = "napi4"))] pub fn execute_tokio_future< T: 'static + Send, - V: 'static + NapiValue, + V: 'static + ToNapiValue, F: 'static + Send + Future>, R: 'static + Send + Sync + FnOnce(&mut Env, T) -> Result, >( @@ -1063,7 +1065,7 @@ impl Env { use crate::tokio_runtime; let promise = tokio_runtime::execute_tokio_future(self.0, fut, |env, val| unsafe { - resolver(&mut Env::from_raw(env), val).map(|v| v.raw()) + resolver(&mut Env::from_raw(env), val).and_then(|v| ToNapiValue::to_napi_value(env, v)) })?; Ok(unsafe { JsObject::from_raw_unchecked(self.0, promise) }) diff --git a/crates/napi/src/js_values/function.rs b/crates/napi/src/js_values/function.rs index 12c4f4b9..f6dfa5f6 100644 --- a/crates/napi/src/js_values/function.rs +++ b/crates/napi/src/js_values/function.rs @@ -3,7 +3,10 @@ use std::ptr; use super::Value; use crate::bindgen_runtime::TypeName; #[cfg(feature = "napi4")] -use crate::threadsafe_function::{ThreadSafeCallContext, ThreadsafeFunction}; +use crate::{ + bindgen_runtime::ToNapiValue, + threadsafe_function::{ThreadSafeCallContext, ThreadsafeFunction}, +}; use crate::{check_status, ValueType}; use crate::{sys, Env, Error, JsObject, JsUnknown, NapiRaw, NapiValue, Result, Status}; @@ -127,7 +130,7 @@ impl JsFunction { ) -> Result> where T: 'static, - V: NapiRaw, + V: ToNapiValue, F: 'static + Send + FnMut(ThreadSafeCallContext) -> Result>, ES: crate::threadsafe_function::ErrorStrategy::T, { diff --git a/crates/napi/src/threadsafe_function.rs b/crates/napi/src/threadsafe_function.rs index 025028ad..3f0dbb9a 100644 --- a/crates/napi/src/threadsafe_function.rs +++ b/crates/napi/src/threadsafe_function.rs @@ -8,7 +8,8 @@ use std::ptr; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::Arc; -use crate::{check_status, sys, Env, Error, JsError, NapiRaw, Result, Status}; +use crate::bindgen_runtime::ToNapiValue; +use crate::{check_status, sys, Env, Error, JsError, Result, Status}; /// ThreadSafeFunction Context object /// the `value` is the value passed to `call` method @@ -176,7 +177,7 @@ impl ThreadsafeFunction { /// See [napi_create_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_create_threadsafe_function) /// for more information. pub(crate) fn create< - V: NapiRaw, + V: ToNapiValue, R: 'static + Send + FnMut(ThreadSafeCallContext) -> Result>, >( env: sys::napi_env, @@ -330,7 +331,7 @@ unsafe extern "C" fn cleanup_cb(cleanup_data: *mut c_void) { aborted.store(true, Ordering::SeqCst); } -unsafe extern "C" fn thread_finalize_cb( +unsafe extern "C" fn thread_finalize_cb( _raw_env: sys::napi_env, finalize_data: *mut c_void, _finalize_hint: *mut c_void, @@ -341,7 +342,7 @@ unsafe extern "C" fn thread_finalize_cb( drop(unsafe { Box::::from_raw(finalize_data.cast()) }); } -unsafe extern "C" fn call_js_cb( +unsafe extern "C" fn call_js_cb( raw_env: sys::napi_env, js_callback: sys::napi_value, context: *mut c_void, @@ -378,23 +379,42 @@ unsafe extern "C" fn call_js_cb( // If the Result is an error, pass that as the first argument. let status = match ret { Ok(values) => { - let values = values.iter().map(|v| unsafe { v.raw() }); - let args: Vec = if ES::VALUE == ErrorStrategy::CalleeHandled::VALUE { + let values = values + .into_iter() + .map(|v| unsafe { ToNapiValue::to_napi_value(raw_env, v) }); + let args: Result> = if ES::VALUE == ErrorStrategy::CalleeHandled::VALUE { let mut js_null = ptr::null_mut(); unsafe { sys::napi_get_null(raw_env, &mut js_null) }; - ::core::iter::once(js_null).chain(values).collect() + ::core::iter::once(Ok(js_null)).chain(values).collect() } else { values.collect() }; - unsafe { - sys::napi_call_function( - raw_env, - recv, - js_callback, - args.len(), - args.as_ptr(), - ptr::null_mut(), - ) + match args { + Ok(args) => unsafe { + sys::napi_call_function( + raw_env, + recv, + js_callback, + args.len(), + args.as_ptr(), + ptr::null_mut(), + ) + }, + Err(e) => match ES::VALUE { + ErrorStrategy::Fatal::VALUE => unsafe { + sys::napi_fatal_exception(raw_env, JsError::from(e).into_value(raw_env)) + }, + ErrorStrategy::CalleeHandled::VALUE => unsafe { + sys::napi_call_function( + raw_env, + recv, + js_callback, + 1, + [JsError::from(e).into_value(raw_env)].as_mut_ptr(), + ptr::null_mut(), + ) + }, + }, } } Err(e) if ES::VALUE == ErrorStrategy::Fatal::VALUE => unsafe { diff --git a/examples/napi/src/callback.rs b/examples/napi/src/callback.rs index 52f1269b..14793f61 100644 --- a/examples/napi/src/callback.rs +++ b/examples/napi/src/callback.rs @@ -1,6 +1,10 @@ use std::env; -use napi::bindgen_prelude::*; +use napi::{ + bindgen_prelude::*, + threadsafe_function::{ThreadSafeCallContext, ThreadsafeFunction, ThreadsafeFunctionCallMode}, + JsUnknown, +}; #[napi] fn get_cwd Result<()>>(callback: T) { @@ -46,3 +50,32 @@ fn read_file_content() -> Result { fn return_js_function(env: Env) -> Result { get_js_function(&env, read_file_js_function) } + +#[napi( + ts_generic_types = "T", + ts_args_type = "functionInput: () => T | Promise, callback: (err: Error | null, result: T) => void" +)] +fn callback_return_promise Result>( + env: Env, + fn_in: T, + fn_out: JsFunction, +) -> Result { + let ret = fn_in()?; + if ret.is_promise()? { + let p = Promise::::from_unknown(ret)?; + let fn_out_tsfn: ThreadsafeFunction = fn_out + .create_threadsafe_function(0, |ctx: ThreadSafeCallContext| Ok(vec![ctx.value]))?; + env + .execute_tokio_future( + async move { + let s = p.await; + fn_out_tsfn.call(s, ThreadsafeFunctionCallMode::NonBlocking); + Ok::<(), Error>(()) + }, + |env, _| env.get_undefined(), + ) + .map(|v| v.into_unknown()) + } else { + Ok(ret) + } +}