From c5f2b6699d0f435150ca9fbd0c6969d3d1f8e9e5 Mon Sep 17 00:00:00 2001 From: LongYinan Date: Sun, 4 Oct 2020 17:10:52 +0800 Subject: [PATCH] refactor(napi): thread safe function redesign --- napi/src/env.rs | 16 +++ napi/src/promise.rs | 2 - napi/src/threadsafe_function.rs | 241 +++++++++++++++++--------------- test_module/src/napi4/mod.rs | 2 +- test_module/src/napi4/tsfn.rs | 104 ++++++-------- 5 files changed, 184 insertions(+), 181 deletions(-) diff --git a/napi/src/env.rs b/napi/src/env.rs index e8f14862..8311dc78 100644 --- a/napi/src/env.rs +++ b/napi/src/env.rs @@ -16,6 +16,8 @@ use crate::{sys, Error, NodeVersion, Result, Status}; use crate::js_values::{De, Ser}; #[cfg(all(any(feature = "libuv", feature = "tokio_rt"), napi4))] use crate::promise; +#[cfg(napi4)] +use crate::threadsafe_function::{ThreadSafeCallContext, ThreadsafeFunction}; #[cfg(all(feature = "tokio_rt", napi4))] use crate::tokio_rt::{get_tokio_sender, Message}; #[cfg(all(feature = "libuv", napi4))] @@ -554,6 +556,20 @@ impl Env { }) } + #[cfg(napi4)] + pub fn create_threadsafe_function< + T: Send, + V: NapiValue, + R: 'static + Send + FnMut(ThreadSafeCallContext) -> Result>, + >( + &self, + func: JsFunction, + max_queue_size: u64, + callback: R, + ) -> Result> { + ThreadsafeFunction::create(self.0, func, max_queue_size, callback) + } + #[cfg(all(feature = "libuv", napi4))] pub fn execute< T: 'static + Send, diff --git a/napi/src/promise.rs b/napi/src/promise.rs index 0ea2abb7..34fa4a65 100644 --- a/napi/src/promise.rs +++ b/napi/src/promise.rs @@ -71,8 +71,6 @@ pub(crate) struct TSFNValue(sys::napi_threadsafe_function); unsafe impl Send for TSFNValue {} -unsafe impl Sync for TSFNValue {} - #[inline] pub(crate) async fn resolve_from_future>>( tsfn_value: TSFNValue, diff --git a/napi/src/threadsafe_function.rs b/napi/src/threadsafe_function.rs index 5dd3f741..7be56e1b 100644 --- a/napi/src/threadsafe_function.rs +++ b/napi/src/threadsafe_function.rs @@ -1,13 +1,19 @@ use std::convert::Into; +use std::marker::PhantomData; use std::os::raw::{c_char, c_void}; use std::ptr; use crate::error::check_status; -use crate::{sys, Env, JsFunction, JsUnknown, Result}; +use crate::{sys, Env, JsFunction, NapiValue, Result}; use sys::napi_threadsafe_function_call_mode; use sys::napi_threadsafe_function_release_mode; +pub struct ThreadSafeCallContext { + pub env: Env, + pub value: T, +} + #[repr(u8)] pub enum ThreadsafeFunctionCallMode { NonBlocking, @@ -46,12 +52,6 @@ impl Into for ThreadsafeFunctionReleaseMo } } -pub trait ToJs: Copy + Clone { - type Output; - - fn resolve(&self, env: &mut Env, output: Self::Output) -> Result>; -} - /// Communicate with the addon's main thread by invoking a JavaScript function from other threads. /// /// ## Example @@ -62,71 +62,71 @@ pub trait ToJs: Copy + Clone { /// extern crate napi_derive; /// /// use std::thread; +/// /// use napi::{ -/// Number, Result, Env, CallContext, JsUndefined, JsFunction, +/// threadsafe_function::{ +/// ThreadSafeCallContext, ThreadsafeFunctionCallMode, ThreadsafeFunctionReleaseMode, +/// }, +/// CallContext, Error, JsFunction, JsNumber, JsUndefined, Result, Status, /// }; -/// use napi::threadsafe_function::{ -/// ToJs, ThreadsafeFunction, ThreadsafeFunctionCallMode, ThreadsafeFunctionReleaseMode, -/// }; -/// -/// // Define a struct for handling the data passed from `ThreadsafeFunction::call` -/// // and return the data to be used for the js callback. -/// #[derive(Clone, Copy)] -/// struct HandleNumber; -/// -/// impl ToJs for HandleNumber { -/// type Output = u8; -/// -/// fn resolve(&self, env: &mut Env, output: Self::Output) -> Result> { -/// let value = env.create_uint32(output as u32)?.into_unknown()?; -/// // The first argument in the NodeJS callback will be either a null or an error -/// // depending on the result returned by this function. -/// // If this Result is Ok, the first argument will be null. -/// // If this Result is Err, the first argument will be the error. -/// Ok(vec![value]) -/// } -/// } -/// + /// #[js_function(1)] -/// fn test_threadsafe_function(ctx: CallContext) -> Result { -/// // The callback function from js which will be called in `ThreadsafeFunction::call`. +/// pub fn test_threadsafe_function(ctx: CallContext) -> Result { /// let func = ctx.get::(0)?; -/// -/// let to_js = HandleNumber; -/// let tsfn = ThreadsafeFunction::create(ctx.env, func, to_js, 0)?; -/// + +/// let tsfn = +/// ctx +/// .env +/// .create_threadsafe_function(func, 0, |ctx: ThreadSafeCallContext>| { +/// ctx +/// .value +/// .iter() +/// .map(|v| ctx.env.create_uint32(*v)) +/// .collect::>>() +/// })?; + /// thread::spawn(move || { -/// let output: u8 = 42; -/// // It's okay to call a threadsafe function multiple times. -/// tsfn.call(Ok(output), ThreadsafeFunctionCallMode::Blocking).unwrap(); -/// tsfn.call(Ok(output), ThreadsafeFunctionCallMode::Blocking).unwrap(); -/// // We should call `ThreadsafeFunction::release` manually when we don't -/// // need the instance anymore, or it will prevent Node.js from exiting -/// // automatically and possibly cause memory leaks. -/// tsfn.release(ThreadsafeFunctionReleaseMode::Release).unwrap(); +/// let output: Vec = vec![42, 1, 2, 3]; +/// /// It's okay to call a threadsafe function multiple times. +/// tsfn.call(Ok(output.clone()), ThreadsafeFunctionCallMode::Blocking); +/// tsfn.call(Ok(output.clone()), ThreadsafeFunctionCallMode::NonBlocking); +/// tsfn.release(ThreadsafeFunctionReleaseMode::Release); /// }); -/// + /// ctx.env.get_undefined() /// } /// ``` -#[derive(Debug, Clone, Copy)] -pub struct ThreadsafeFunction { - raw_value: sys::napi_threadsafe_function, - to_js: T, +pub struct ThreadsafeFunction { + raw_tsfn: sys::napi_threadsafe_function, + _phantom: PhantomData, } -unsafe impl Send for ThreadsafeFunction {} -unsafe impl Sync for ThreadsafeFunction {} +unsafe impl Send for ThreadsafeFunction {} +unsafe impl Sync for ThreadsafeFunction {} -impl ThreadsafeFunction { +#[repr(transparent)] +struct ThreadSafeContext( + Box) -> Result>>, +); + +impl ThreadsafeFunction { /// See [napi_create_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_create_threadsafe_function) /// for more information. - pub fn create(env: &Env, func: JsFunction, to_js: T, max_queue_size: u64) -> Result { + #[inline(always)] + pub fn create< + V: NapiValue, + R: 'static + Send + FnMut(ThreadSafeCallContext) -> Result>, + >( + env: sys::napi_env, + func: JsFunction, + max_queue_size: u64, + callback: R, + ) -> Result { let mut async_resource_name = ptr::null_mut(); let s = "napi_rs_threadsafe_function"; check_status(unsafe { sys::napi_create_string_utf8( - env.0, + env, s.as_ptr() as *const c_char, s.len() as u64, &mut async_resource_name, @@ -134,59 +134,65 @@ impl ThreadsafeFunction { })?; let initial_thread_count: u64 = 1; - let mut result = ptr::null_mut(); - let tsfn = ThreadsafeFunction { - to_js, - raw_value: result, - }; - - let ptr = Box::into_raw(Box::from(tsfn)) as *mut _ as *mut c_void; - - let status = unsafe { + let mut raw_tsfn = ptr::null_mut(); + let context = ThreadSafeContext(Box::from(callback)); + let ptr = Box::into_raw(Box::new(context)) as *mut _; + check_status(unsafe { sys::napi_create_threadsafe_function( - env.0, + env, func.0.value, ptr::null_mut(), async_resource_name, max_queue_size, initial_thread_count, ptr, - Some(thread_finalize_cb::), + Some(thread_finalize_cb::), ptr, - Some(call_js_cb::), - &mut result, + Some(call_js_cb::), + &mut raw_tsfn, ) - }; - check_status(status)?; + })?; Ok(ThreadsafeFunction { - to_js, - raw_value: result, + raw_tsfn, + _phantom: PhantomData, }) } /// See [napi_call_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_call_threadsafe_function) /// for more information. - pub fn call(&self, value: Result, mode: ThreadsafeFunctionCallMode) -> Result<()> { - check_status(unsafe { + pub fn call(&self, value: Result, mode: ThreadsafeFunctionCallMode) { + let status = unsafe { sys::napi_call_threadsafe_function( - self.raw_value, - Box::into_raw(Box::from(value)) as *mut _ as *mut c_void, + self.raw_tsfn, + Box::into_raw(Box::new(value)) as *mut _, mode.into(), ) - }) + }; + debug_assert!( + status == sys::napi_status::napi_ok, + "Threadsafe Function call failed" + ); } /// See [napi_acquire_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_acquire_threadsafe_function) /// for more information. - pub fn acquire(&self) -> Result<()> { - check_status(unsafe { sys::napi_acquire_threadsafe_function(self.raw_value) }) + pub fn acquire(&self) { + let status = unsafe { sys::napi_acquire_threadsafe_function(self.raw_tsfn) }; + debug_assert!( + status == sys::napi_status::napi_ok, + "Threadsafe Function acquire failed" + ); } /// See [napi_release_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_release_threadsafe_function) /// for more information. - pub fn release(&self, mode: ThreadsafeFunctionReleaseMode) -> Result<()> { - check_status(unsafe { sys::napi_release_threadsafe_function(self.raw_value, mode.into()) }) + pub fn release(self, mode: ThreadsafeFunctionReleaseMode) { + let status = unsafe { sys::napi_release_threadsafe_function(self.raw_tsfn, mode.into()) }; + debug_assert!( + status == sys::napi_status::napi_ok, + "Threadsafe Function call failed" + ); } /// See [napi_ref_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_ref_threadsafe_function) @@ -194,73 +200,76 @@ impl ThreadsafeFunction { /// /// "ref" is a keyword so that we use "refer" here. pub fn refer(&self, env: &Env) -> Result<()> { - check_status(unsafe { sys::napi_ref_threadsafe_function(env.0, self.raw_value) }) + check_status(unsafe { sys::napi_ref_threadsafe_function(env.0, self.raw_tsfn) }) } /// See [napi_unref_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_unref_threadsafe_function) /// for more information. pub fn unref(&self, env: &Env) -> Result<()> { - check_status(unsafe { sys::napi_unref_threadsafe_function(env.0, self.raw_value) }) + check_status(unsafe { sys::napi_unref_threadsafe_function(env.0, self.raw_tsfn) }) } } -unsafe extern "C" fn thread_finalize_cb( +unsafe extern "C" fn thread_finalize_cb( _raw_env: sys::napi_env, finalize_data: *mut c_void, _finalize_hint: *mut c_void, ) { // cleanup - Box::from_raw(finalize_data as *mut ThreadsafeFunction); + Box::from_raw(finalize_data as *mut ThreadSafeContext); } -unsafe extern "C" fn call_js_cb( +unsafe extern "C" fn call_js_cb( raw_env: sys::napi_env, js_callback: sys::napi_value, context: *mut c_void, data: *mut c_void, ) { - let mut env = Env::from_raw(raw_env); let mut recv = ptr::null_mut(); sys::napi_get_undefined(raw_env, &mut recv); - let tsfn = Box::leak(Box::from_raw(context as *mut ThreadsafeFunction)); - let val = Box::from_raw(data as *mut Result); + let ctx = Box::leak(Box::from_raw(context as *mut ThreadSafeContext)); + let val = Box::from_raw(data as *mut Result); - let ret = val.and_then(|v| tsfn.to_js.resolve(&mut env, v)); + let ret = val.and_then(|v| { + (ctx.0)(ThreadSafeCallContext { + env: Env::from_raw(raw_env), + value: v, + }) + }); let status; // Follow async callback conventions: https://nodejs.org/en/knowledge/errors/what-are-the-error-conventions/ // Check if the Result is okay, if so, pass a null as the first (error) argument automatically. // If the Result is an error, pass that as the first argument. - if ret.is_ok() { - let values = ret.unwrap(); - let js_null = env.get_null().unwrap(); - let mut raw_values: Vec = vec![]; - raw_values.push(js_null.0.value); - for item in values.iter() { - raw_values.push(item.0.value) + match ret { + Ok(values) => { + let mut js_null = ptr::null_mut(); + sys::napi_get_null(raw_env, &mut js_null); + let args_length = values.len() + 1; + let mut args: Vec = Vec::with_capacity(args_length); + args.push(js_null); + args.extend(values.iter().map(|v| v.raw())); + status = sys::napi_call_function( + raw_env, + recv, + js_callback, + args_length as _, + args.as_ptr(), + ptr::null_mut(), + ); + } + Err(e) => { + status = sys::napi_call_function( + raw_env, + recv, + js_callback, + 1, + [e.into_raw(raw_env)].as_mut_ptr(), + ptr::null_mut(), + ); } - - status = sys::napi_call_function( - raw_env, - recv, - js_callback, - (values.len() + 1) as u64, - raw_values.as_ptr(), - ptr::null_mut(), - ); - } else { - let mut err = env.create_error(ret.err().unwrap()).unwrap(); - status = sys::napi_call_function( - raw_env, - recv, - js_callback, - 1, - &mut err.0.value, - ptr::null_mut(), - ); } - debug_assert!(status == sys::napi_status::napi_ok, "CallJsCB failed"); } diff --git a/test_module/src/napi4/mod.rs b/test_module/src/napi4/mod.rs index 52296cae..fd78fc40 100644 --- a/test_module/src/napi4/mod.rs +++ b/test_module/src/napi4/mod.rs @@ -5,8 +5,8 @@ mod tsfn; use tsfn::*; pub fn register_js(module: &mut Module) -> Result<()> { - module.create_named_method("testTsfnError", test_tsfn_error)?; module.create_named_method("testThreadsafeFunction", test_threadsafe_function)?; + module.create_named_method("testTsfnError", test_tsfn_error)?; module.create_named_method("testTokioReadfile", test_tokio_readfile)?; Ok(()) } diff --git a/test_module/src/napi4/tsfn.rs b/test_module/src/napi4/tsfn.rs index 214363c9..a4a01608 100644 --- a/test_module/src/napi4/tsfn.rs +++ b/test_module/src/napi4/tsfn.rs @@ -1,47 +1,35 @@ use std::path::Path; use std::thread; -use napi::threadsafe_function::{ - ThreadsafeFunction, ThreadsafeFunctionCallMode, ThreadsafeFunctionReleaseMode, ToJs, +use napi::{ + threadsafe_function::{ + ThreadSafeCallContext, ThreadsafeFunctionCallMode, ThreadsafeFunctionReleaseMode, + }, + CallContext, Error, JsFunction, JsNumber, JsString, JsUndefined, Result, Status, }; -use napi::{CallContext, Env, Error, JsFunction, JsString, JsUndefined, JsUnknown, Result, Status}; use tokio; -#[derive(Clone, Copy)] -struct HandleNumber; - -impl ToJs for HandleNumber { - type Output = Vec; - - fn resolve(&self, env: &mut Env, output: Self::Output) -> Result> { - let mut items: Vec = vec![]; - for item in output.iter() { - let value = env.create_uint32((*item) as u32)?.into_unknown(); - items.push(value); - } - Ok(items) - } -} - #[js_function(1)] pub fn test_threadsafe_function(ctx: CallContext) -> Result { let func = ctx.get::(0)?; - let to_js = HandleNumber; - let tsfn = ThreadsafeFunction::create(ctx.env, func, to_js, 0)?; + let tsfn = + ctx + .env + .create_threadsafe_function(func, 0, |ctx: ThreadSafeCallContext>| { + ctx + .value + .iter() + .map(|v| ctx.env.create_uint32(*v)) + .collect::>>() + })?; thread::spawn(move || { - let output: Vec = vec![42, 1, 2, 3]; + let output: Vec = vec![42, 1, 2, 3]; // It's okay to call a threadsafe function multiple times. - tsfn - .call(Ok(output.clone()), ThreadsafeFunctionCallMode::Blocking) - .unwrap(); - tsfn - .call(Ok(output.clone()), ThreadsafeFunctionCallMode::Blocking) - .unwrap(); - tsfn - .release(ThreadsafeFunctionReleaseMode::Release) - .unwrap(); + tsfn.call(Ok(output.clone()), ThreadsafeFunctionCallMode::Blocking); + tsfn.call(Ok(output.clone()), ThreadsafeFunctionCallMode::NonBlocking); + tsfn.release(ThreadsafeFunctionReleaseMode::Release); }); ctx.env.get_undefined() @@ -50,36 +38,22 @@ pub fn test_threadsafe_function(ctx: CallContext) -> Result { #[js_function(1)] pub fn test_tsfn_error(ctx: CallContext) -> Result { let func = ctx.get::(0)?; - let to_js = HandleNumber; - let tsfn = ThreadsafeFunction::create(ctx.env, func, to_js, 0)?; - + let tsfn = ctx + .env + .create_threadsafe_function(func, 0, |ctx: ThreadSafeCallContext<()>| { + ctx.env.get_undefined().map(|v| vec![v]) + })?; thread::spawn(move || { - tsfn - .call( - Err(Error::new(Status::Unknown, "invalid".to_owned())), - ThreadsafeFunctionCallMode::Blocking, - ) - .unwrap(); - tsfn - .release(ThreadsafeFunctionReleaseMode::Release) - .unwrap(); + tsfn.call( + Err(Error::new(Status::Unknown, "invalid".to_owned())), + ThreadsafeFunctionCallMode::Blocking, + ); + tsfn.release(ThreadsafeFunctionReleaseMode::Release); }); ctx.env.get_undefined() } -#[derive(Copy, Clone)] -struct HandleBuffer; - -impl ToJs for HandleBuffer { - type Output = Vec; - - fn resolve(&self, env: &mut Env, output: Self::Output) -> Result> { - let value = env.create_buffer_with_data(output.to_vec())?.into_unknown(); - Ok(vec![value]) - } -} - async fn read_file_content(filepath: &Path) -> Result> { tokio::fs::read(filepath) .await @@ -92,16 +66,22 @@ pub fn test_tokio_readfile(ctx: CallContext) -> Result { let js_func = ctx.get::(1)?; let path_str = js_filepath.into_utf8()?.to_owned()?; - let to_js = HandleBuffer; - let tsfn = ThreadsafeFunction::create(ctx.env, js_func, to_js, 0)?; - let mut rt = tokio::runtime::Runtime::new().unwrap(); + let tsfn = + ctx + .env + .create_threadsafe_function(js_func, 0, |ctx: ThreadSafeCallContext>| { + ctx + .env + .create_buffer_with_data(ctx.value) + .map(|v| vec![v.into_raw()]) + })?; + let mut rt = tokio::runtime::Runtime::new() + .map_err(|e| Error::from_reason(format!("Create tokio runtime failed {}", e)))?; rt.block_on(async move { let ret = read_file_content(&Path::new(&path_str)).await; - let _ = tsfn.call(ret, ThreadsafeFunctionCallMode::Blocking); - tsfn - .release(ThreadsafeFunctionReleaseMode::Release) - .unwrap(); + tsfn.call(ret, ThreadsafeFunctionCallMode::Blocking); + tsfn.release(ThreadsafeFunctionReleaseMode::Release); }); ctx.env.get_undefined()