diff --git a/napi/src/env.rs b/napi/src/env.rs index cde2109b..9438f35a 100644 --- a/napi/src/env.rs +++ b/napi/src/env.rs @@ -597,7 +597,7 @@ impl Env { R: 'static + Send + FnMut(ThreadSafeCallContext) -> Result>, >( &self, - func: JsFunction, + func: &JsFunction, max_queue_size: usize, callback: R, ) -> Result> { diff --git a/napi/src/threadsafe_function.rs b/napi/src/threadsafe_function.rs index 5e394424..0de8bc22 100644 --- a/napi/src/threadsafe_function.rs +++ b/napi/src/threadsafe_function.rs @@ -2,12 +2,16 @@ use std::convert::Into; use std::marker::PhantomData; use std::os::raw::{c_char, c_void}; use std::ptr; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; use crate::error::check_status; -use crate::{sys, Env, JsFunction, NapiValue, Result}; +use crate::{sys, Env, Error, JsFunction, NapiValue, Result, Status}; use sys::napi_threadsafe_function_call_mode; +/// ThreadSafeFunction Context object +/// the `value` is the value passed to `call` method pub struct ThreadSafeCallContext { pub env: Env, pub value: T, @@ -44,40 +48,38 @@ impl Into for ThreadsafeFunctionCallMode { /// use std::thread; /// /// use napi::{ -/// threadsafe_function::{ -/// ThreadSafeCallContext, ThreadsafeFunctionCallMode, ThreadsafeFunctionReleaseMode, -/// }, -/// CallContext, Error, JsFunction, JsNumber, JsUndefined, Result, Status, +/// threadsafe_function::{ +/// ThreadSafeCallContext, ThreadsafeFunctionCallMode, ThreadsafeFunctionReleaseMode, +/// }, +/// CallContext, Error, JsFunction, JsNumber, JsUndefined, Result, Status, /// }; /// #[js_function(1)] /// pub fn test_threadsafe_function(ctx: CallContext) -> Result { -/// let func = ctx.get::(0)?; +/// let func = ctx.get::(0)?; -/// let tsfn = -/// ctx -/// .env -/// .create_threadsafe_function(func, 0, |ctx: ThreadSafeCallContext>| { -/// ctx -/// .value -/// .iter() -/// .map(|v| ctx.env.create_uint32(*v)) -/// .collect::>>() -/// })?; +/// let tsfn = ctx.env +/// .create_threadsafe_function(func, 0, |ctx: ThreadSafeCallContext>| { +/// ctx.value +/// .iter() +/// .map(|v| ctx.env.create_uint32(*v))] +/// .collect::>>() +/// })?; -/// thread::spawn(move || { -/// let output: Vec = vec![42, 1, 2, 3]; -/// /// It's okay to call a threadsafe function multiple times. -/// tsfn.call(Ok(output.clone()), ThreadsafeFunctionCallMode::Blocking); -/// tsfn.call(Ok(output.clone()), ThreadsafeFunctionCallMode::NonBlocking); -/// tsfn.release(ThreadsafeFunctionReleaseMode::Release); -/// }); +/// thread::spawn(move || { +/// let output: Vec = vec![42, 1, 2, 3]; +/// // It's okay to call a threadsafe function multiple times. +/// tsfn.call(Ok(output.clone()), ThreadsafeFunctionCallMode::Blocking); +/// tsfn.call(Ok(output.clone()), ThreadsafeFunctionCallMode::NonBlocking); +/// tsfn.release(ThreadsafeFunctionReleaseMode::Release); +/// }); -/// ctx.env.get_undefined() +/// ctx.env.get_undefined() /// } /// ``` pub struct ThreadsafeFunction { raw_tsfn: sys::napi_threadsafe_function, + aborted: Arc, _phantom: PhantomData, } @@ -98,7 +100,7 @@ impl ThreadsafeFunction { R: 'static + Send + FnMut(ThreadSafeCallContext) -> Result>, >( env: sys::napi_env, - func: JsFunction, + func: &JsFunction, max_queue_size: usize, callback: R, ) -> Result { @@ -135,59 +137,92 @@ impl ThreadsafeFunction { Ok(ThreadsafeFunction { raw_tsfn, + aborted: Arc::new(AtomicBool::new(false)), _phantom: PhantomData, }) } /// See [napi_call_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_call_threadsafe_function) /// for more information. - pub fn call(&self, value: Result, mode: ThreadsafeFunctionCallMode) { - let status = unsafe { + pub fn call(&self, value: Result, mode: ThreadsafeFunctionCallMode) -> Status { + if self.aborted.load(Ordering::Acquire) { + return Status::Closing; + } + unsafe { sys::napi_call_threadsafe_function( self.raw_tsfn, Box::into_raw(Box::new(value)) as *mut _, mode.into(), ) - }; - debug_assert!( - status == sys::napi_status::napi_ok, - "Threadsafe Function call failed" - ); - } - - /// See [napi_acquire_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_acquire_threadsafe_function) - /// for more information. - pub fn acquire(&self) { - let status = unsafe { sys::napi_acquire_threadsafe_function(self.raw_tsfn) }; - debug_assert!( - status == sys::napi_status::napi_ok, - "Threadsafe Function acquire failed" - ); - } - - /// See [napi_release_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_release_threadsafe_function) - /// for more information. - pub fn release(self, mode: ThreadsafeFunctionReleaseMode) { - let status = unsafe { sys::napi_release_threadsafe_function(self.raw_tsfn, mode.into()) }; - debug_assert!( - status == sys::napi_status::napi_ok, - "Threadsafe Function call failed" - ); + } + .into() } /// See [napi_ref_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_ref_threadsafe_function) /// for more information. /// /// "ref" is a keyword so that we use "refer" here. - pub fn refer(&self, env: &Env) -> Result<()> { + pub fn refer(&mut self, env: &Env) -> Result<()> { check_status(unsafe { sys::napi_ref_threadsafe_function(env.0, self.raw_tsfn) }) } /// See [napi_unref_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_unref_threadsafe_function) /// for more information. - pub fn unref(&self, env: &Env) -> Result<()> { + pub fn unref(&mut self, env: &Env) -> Result<()> { check_status(unsafe { sys::napi_unref_threadsafe_function(env.0, self.raw_tsfn) }) } + + pub fn aborted(&self) -> bool { + self.aborted.load(Ordering::Acquire) + } + + pub fn abort(self) -> Result<()> { + check_status(unsafe { + sys::napi_release_threadsafe_function( + self.raw_tsfn, + sys::napi_threadsafe_function_release_mode::napi_tsfn_abort, + ) + })?; + self.aborted.store(true, Ordering::Release); + Ok(()) + } + + pub fn try_clone(&self) -> Result { + if self.aborted.load(Ordering::Acquire) { + return Err(Error::new( + Status::Closing, + format!("Thread safe function already aborted"), + )); + } + check_status(unsafe { sys::napi_acquire_threadsafe_function(self.raw_tsfn) })?; + Ok(Self { + raw_tsfn: self.raw_tsfn, + aborted: Arc::clone(&self.aborted), + _phantom: PhantomData, + }) + } + + /// Get the raw `ThreadSafeFunction` pointer + pub fn raw(&self) -> sys::napi_threadsafe_function { + self.raw_tsfn + } +} + +impl Drop for ThreadsafeFunction { + fn drop(&mut self) { + if !self.aborted.load(Ordering::Acquire) { + let release_status = unsafe { + sys::napi_release_threadsafe_function( + self.raw_tsfn, + sys::napi_threadsafe_function_release_mode::napi_tsfn_release, + ) + }; + assert!( + release_status == sys::Status::napi_ok, + "Threadsafe Function release failed" + ); + } + } } unsafe extern "C" fn thread_finalize_cb( diff --git a/test_module/__test__/napi4/threadsafe_function.spec.ts b/test_module/__test__/napi4/threadsafe_function.spec.ts index da44da31..6f9c2858 100644 --- a/test_module/__test__/napi4/threadsafe_function.spec.ts +++ b/test_module/__test__/napi4/threadsafe_function.spec.ts @@ -16,7 +16,11 @@ test('should get js function called from a thread', async (t) => { bindings.testThreadsafeFunction((...args: any[]) => { called += 1 try { - t.deepEqual(args, [null, 42, 1, 2, 3]) + if (args[1] === 0) { + t.deepEqual(args, [null, 0, 1, 2, 3]) + } else { + t.deepEqual(args, [null, 3, 2, 1, 0]) + } } catch (err) { reject(err) } @@ -27,3 +31,27 @@ test('should get js function called from a thread', async (t) => { }) }) }) + +test('should be able to abort tsfn', (t) => { + if (napiVersion < 4) { + t.is(bindings.testAbortThreadsafeFunction, undefined) + return + } + t.true(bindings.testAbortThreadsafeFunction(() => {})) +}) + +test('should be able to abort independent tsfn', (t) => { + if (napiVersion < 4) { + t.is(bindings.testAbortIndependentThreadsafeFunction, undefined) + return + } + t.false(bindings.testAbortIndependentThreadsafeFunction(() => {})) +}) + +test('should return Closing while calling aborted tsfn', (t) => { + if (napiVersion < 4) { + t.is(bindings.testCallAbortedThreadsafeFunction, undefined) + return + } + t.notThrows(() => bindings.testCallAbortedThreadsafeFunction(() => {})) +}) diff --git a/test_module/src/napi4/mod.rs b/test_module/src/napi4/mod.rs index fd78fc40..5a451c97 100644 --- a/test_module/src/napi4/mod.rs +++ b/test_module/src/napi4/mod.rs @@ -8,5 +8,17 @@ pub fn register_js(module: &mut Module) -> Result<()> { module.create_named_method("testThreadsafeFunction", test_threadsafe_function)?; module.create_named_method("testTsfnError", test_tsfn_error)?; module.create_named_method("testTokioReadfile", test_tokio_readfile)?; + module.create_named_method( + "testAbortThreadsafeFunction", + test_abort_threadsafe_function, + )?; + module.create_named_method( + "testAbortIndependentThreadsafeFunction", + test_abort_independent_threadsafe_function, + )?; + module.create_named_method( + "testCallAbortedThreadsafeFunction", + test_call_aborted_threadsafe_function, + )?; Ok(()) } diff --git a/test_module/src/napi4/tsfn.rs b/test_module/src/napi4/tsfn.rs index 601517de..5508f0fd 100644 --- a/test_module/src/napi4/tsfn.rs +++ b/test_module/src/napi4/tsfn.rs @@ -2,10 +2,8 @@ use std::path::Path; use std::thread; use napi::{ - threadsafe_function::{ - ThreadSafeCallContext, ThreadsafeFunctionCallMode, ThreadsafeFunctionReleaseMode, - }, - CallContext, Error, JsFunction, JsNumber, JsString, JsUndefined, Result, Status, + threadsafe_function::{ThreadSafeCallContext, ThreadsafeFunctionCallMode}, + CallContext, Error, JsBoolean, JsFunction, JsNumber, JsString, JsUndefined, Result, Status, }; use tokio; @@ -16,7 +14,7 @@ pub fn test_threadsafe_function(ctx: CallContext) -> Result { let tsfn = ctx .env - .create_threadsafe_function(func, 0, |ctx: ThreadSafeCallContext>| { + .create_threadsafe_function(&func, 0, |ctx: ThreadSafeCallContext>| { ctx .value .iter() @@ -24,14 +22,80 @@ pub fn test_threadsafe_function(ctx: CallContext) -> Result { .collect::>>() })?; + let tsfn_cloned = tsfn.try_clone()?; + thread::spawn(move || { - let output: Vec = vec![42, 1, 2, 3]; + let output: Vec = vec![0, 1, 2, 3]; // It's okay to call a threadsafe function multiple times. tsfn.call(Ok(output.clone()), ThreadsafeFunctionCallMode::Blocking); - tsfn.call(Ok(output.clone()), ThreadsafeFunctionCallMode::NonBlocking); - tsfn.release(ThreadsafeFunctionReleaseMode::Release); }); + thread::spawn(move || { + let output: Vec = vec![3, 2, 1, 0]; + // It's okay to call a threadsafe function multiple times. + tsfn_cloned.call(Ok(output.clone()), ThreadsafeFunctionCallMode::NonBlocking); + }); + + ctx.env.get_undefined() +} + +#[js_function(1)] +pub fn test_abort_threadsafe_function(ctx: CallContext) -> Result { + let func = ctx.get::(0)?; + + let tsfn = + ctx + .env + .create_threadsafe_function(&func, 0, |ctx: ThreadSafeCallContext>| { + ctx + .value + .iter() + .map(|v| ctx.env.create_uint32(*v)) + .collect::>>() + })?; + + let tsfn_cloned = tsfn.try_clone()?; + + tsfn_cloned.abort()?; + ctx.env.get_boolean(tsfn.aborted()) +} + +#[js_function(1)] +pub fn test_abort_independent_threadsafe_function(ctx: CallContext) -> Result { + let func = ctx.get::(0)?; + + let tsfn = ctx + .env + .create_threadsafe_function(&func, 0, |ctx: ThreadSafeCallContext| { + ctx.env.create_uint32(ctx.value).map(|v| vec![v]) + })?; + + let tsfn_other = + ctx + .env + .create_threadsafe_function(&func, 0, |ctx: ThreadSafeCallContext| { + ctx.env.create_uint32(ctx.value).map(|v| vec![v]) + })?; + + tsfn_other.abort()?; + ctx.env.get_boolean(tsfn.aborted()) +} + +#[js_function(1)] +pub fn test_call_aborted_threadsafe_function(ctx: CallContext) -> Result { + let func = ctx.get::(0)?; + + let tsfn = ctx + .env + .create_threadsafe_function(&func, 0, |ctx: ThreadSafeCallContext| { + ctx.env.create_uint32(ctx.value).map(|v| vec![v]) + })?; + + let tsfn_clone = tsfn.try_clone()?; + tsfn_clone.abort()?; + + let call_status = tsfn.call(Ok(1), ThreadsafeFunctionCallMode::NonBlocking); + assert!(call_status == Status::Closing); ctx.env.get_undefined() } @@ -40,7 +104,7 @@ pub fn test_tsfn_error(ctx: CallContext) -> Result { let func = ctx.get::(0)?; let tsfn = ctx .env - .create_threadsafe_function(func, 0, |ctx: ThreadSafeCallContext<()>| { + .create_threadsafe_function(&func, 0, |ctx: ThreadSafeCallContext<()>| { ctx.env.get_undefined().map(|v| vec![v]) })?; thread::spawn(move || { @@ -48,7 +112,6 @@ pub fn test_tsfn_error(ctx: CallContext) -> Result { Err(Error::new(Status::GenericFailure, "invalid".to_owned())), ThreadsafeFunctionCallMode::Blocking, ); - tsfn.release(ThreadsafeFunctionReleaseMode::Release); }); ctx.env.get_undefined() @@ -69,7 +132,7 @@ pub fn test_tokio_readfile(ctx: CallContext) -> Result { let tsfn = ctx .env - .create_threadsafe_function(js_func, 0, |ctx: ThreadSafeCallContext>| { + .create_threadsafe_function(&js_func, 0, |ctx: ThreadSafeCallContext>| { ctx .env .create_buffer_with_data(ctx.value) @@ -81,7 +144,6 @@ pub fn test_tokio_readfile(ctx: CallContext) -> Result { rt.block_on(async move { let ret = read_file_content(&Path::new(&path_str)).await; tsfn.call(ret, ThreadsafeFunctionCallMode::Blocking); - tsfn.release(ThreadsafeFunctionReleaseMode::Release); }); ctx.env.get_undefined()