diff --git a/crates/napi/src/threadsafe_function.rs b/crates/napi/src/threadsafe_function.rs index 0ae37f3c..a38e7e0e 100644 --- a/crates/napi/src/threadsafe_function.rs +++ b/crates/napi/src/threadsafe_function.rs @@ -5,8 +5,8 @@ use std::ffi::CString; use std::marker::PhantomData; use std::os::raw::c_void; use std::ptr; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::{Arc, Mutex}; use crate::bindgen_runtime::ToNapiValue; use crate::{check_status, sys, Env, Error, JsError, Result, Status}; @@ -146,23 +146,32 @@ type_level_enum! { /// ``` pub struct ThreadsafeFunction { raw_tsfn: sys::napi_threadsafe_function, - aborted: Arc, + aborted: Arc>, + ref_count: Arc, _phantom: PhantomData<(T, ES)>, } impl Clone for ThreadsafeFunction { fn clone(&self) -> Self { - if !self.aborted.load(Ordering::Acquire) { + let is_aborted = self.aborted.lock().unwrap(); + if !*is_aborted { let acquire_status = unsafe { sys::napi_acquire_threadsafe_function(self.raw_tsfn) }; debug_assert!( acquire_status == sys::Status::napi_ok, "Acquire threadsafe function failed in clone" ); + } else { + panic!("ThreadsafeFunction was aborted, can not clone it"); } + self.ref_count.fetch_add(1, Ordering::AcqRel); + + drop(is_aborted); + Self { raw_tsfn: self.raw_tsfn, aborted: Arc::clone(&self.aborted), + ref_count: Arc::clone(&self.ref_count), _phantom: PhantomData, } } @@ -194,13 +203,8 @@ impl ThreadsafeFunction { let initial_thread_count = 1usize; let mut raw_tsfn = ptr::null_mut(); let ptr = Box::into_raw(Box::new(callback)) as *mut c_void; - let aborted = Arc::new(AtomicBool::new(false)); - let aborted_ptr = Arc::into_raw(aborted.clone()) as *mut c_void; - // `aborted_ptr` is passed into both `finalize_callback` and `env_cleanup_callback`. - // So increase strong count here to prevent it from being dropped twice. - unsafe { - Arc::increment_strong_count(aborted_ptr); - } + let aborted = Arc::new(Mutex::new(false)); + let aborted_ptr = Arc::into_raw(aborted.clone()); check_status!(unsafe { sys::napi_create_threadsafe_function( env, @@ -211,17 +215,16 @@ impl ThreadsafeFunction { initial_thread_count, ptr, Some(thread_finalize_cb::), - aborted_ptr, + aborted_ptr as *mut c_void, Some(call_js_cb::), &mut raw_tsfn, ) })?; - check_status!(unsafe { sys::napi_add_env_cleanup_hook(env, Some(cleanup_cb), aborted_ptr) })?; - Ok(ThreadsafeFunction { raw_tsfn, aborted, + ref_count: Arc::new(AtomicUsize::new(initial_thread_count)), _phantom: PhantomData, }) } @@ -231,39 +234,48 @@ impl ThreadsafeFunction { /// /// "ref" is a keyword so that we use "refer" here. pub fn refer(&mut self, env: &Env) -> Result<()> { - if self.aborted.load(Ordering::Acquire) { + let is_aborted = self.aborted.lock().unwrap(); + if *is_aborted { return Err(Error::new( Status::Closing, "Can not ref, Thread safe function already aborted".to_string(), )); } + drop(is_aborted); + self.ref_count.fetch_add(1, Ordering::AcqRel); check_status!(unsafe { sys::napi_ref_threadsafe_function(env.0, self.raw_tsfn) }) } /// See [napi_unref_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_unref_threadsafe_function) /// for more information. pub fn unref(&mut self, env: &Env) -> Result<()> { - if self.aborted.load(Ordering::Acquire) { + let is_aborted = self.aborted.lock().unwrap(); + if *is_aborted { return Err(Error::new( Status::Closing, "Can not unref, Thread safe function already aborted".to_string(), )); } + self.ref_count.fetch_sub(1, Ordering::AcqRel); check_status!(unsafe { sys::napi_unref_threadsafe_function(env.0, self.raw_tsfn) }) } pub fn aborted(&self) -> bool { - self.aborted.load(Ordering::Acquire) + let is_aborted = self.aborted.lock().unwrap(); + *is_aborted } pub fn abort(self) -> Result<()> { - check_status!(unsafe { - sys::napi_release_threadsafe_function( - self.raw_tsfn, - sys::ThreadsafeFunctionReleaseMode::abort, - ) - })?; - self.aborted.store(true, Ordering::Release); + let mut is_aborted = self.aborted.lock().unwrap(); + if !*is_aborted { + check_status!(unsafe { + sys::napi_release_threadsafe_function( + self.raw_tsfn, + sys::ThreadsafeFunctionReleaseMode::abort, + ) + })?; + } + *is_aborted = true; Ok(()) } @@ -277,21 +289,18 @@ impl ThreadsafeFunction { /// See [napi_call_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_call_threadsafe_function) /// for more information. pub fn call(&self, value: Result, mode: ThreadsafeFunctionCallMode) -> Status { - if self.aborted.load(Ordering::Acquire) { + let is_aborted = self.aborted.lock().unwrap(); + if *is_aborted { return Status::Closing; } - let status = unsafe { + unsafe { sys::napi_call_threadsafe_function( self.raw_tsfn, - Box::into_raw(Box::new(value)) as *mut _, + Box::into_raw(Box::new(value)) as *mut c_void, mode.into(), ) } - .into(); - if status == Status::Closing { - self.aborted.store(true, Ordering::Release); - } - status + .into() } } @@ -299,27 +308,25 @@ impl ThreadsafeFunction { /// See [napi_call_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_call_threadsafe_function) /// for more information. pub fn call(&self, value: T, mode: ThreadsafeFunctionCallMode) -> Status { - if self.aborted.load(Ordering::Acquire) { + let is_aborted = self.aborted.lock().unwrap(); + if *is_aborted { return Status::Closing; } - let status = unsafe { + unsafe { sys::napi_call_threadsafe_function( self.raw_tsfn, - Box::into_raw(Box::new(value)) as *mut _, + Box::into_raw(Box::new(value)) as *mut c_void, mode.into(), ) } - .into(); - if status == Status::Closing { - self.aborted.store(true, Ordering::Release); - } - status + .into() } } impl Drop for ThreadsafeFunction { fn drop(&mut self) { - if !self.aborted.load(Ordering::Acquire) { + let mut is_aborted = self.aborted.lock().unwrap(); + if !*is_aborted && self.ref_count.load(Ordering::Acquire) <= 1 { let release_status = unsafe { sys::napi_release_threadsafe_function( self.raw_tsfn, @@ -328,17 +335,17 @@ impl Drop for ThreadsafeFunction { }; assert!( release_status == sys::Status::napi_ok, - "Threadsafe Function release failed" + "Threadsafe Function release failed {:?}", + Status::from(release_status) ); + *is_aborted = true; + } else { + self.ref_count.fetch_sub(1, Ordering::Release); } + drop(is_aborted); } } -unsafe extern "C" fn cleanup_cb(cleanup_data: *mut c_void) { - let aborted = unsafe { Arc::::from_raw(cleanup_data.cast()) }; - aborted.store(true, Ordering::Release); -} - unsafe extern "C" fn thread_finalize_cb( _raw_env: sys::napi_env, finalize_data: *mut c_void, @@ -348,8 +355,9 @@ unsafe extern "C" fn thread_finalize_cb( { // cleanup drop(unsafe { Box::::from_raw(finalize_data.cast()) }); - let aborted = unsafe { Arc::::from_raw(finalize_hint.cast()) }; - aborted.store(true, Ordering::Release); + let aborted = unsafe { Arc::>::from_raw(finalize_hint.cast()) }; + let mut is_aborted = aborted.lock().unwrap(); + *is_aborted = true; } unsafe extern "C" fn call_js_cb( diff --git a/memory-testing/index.mjs b/memory-testing/index.mjs index 8f6c4fd1..6dddeb9f 100644 --- a/memory-testing/index.mjs +++ b/memory-testing/index.mjs @@ -3,3 +3,4 @@ import { createSuite } from './test-util.mjs' await createSuite('reference') await createSuite('tokio-future') await createSuite('serde') +await createSuite('tsfn') diff --git a/memory-testing/src/lib.rs b/memory-testing/src/lib.rs index e9ac5814..9720747f 100644 --- a/memory-testing/src/lib.rs +++ b/memory-testing/src/lib.rs @@ -1,4 +1,9 @@ -use napi::{bindgen_prelude::*, Env}; +use std::thread::spawn; + +use napi::{ + bindgen_prelude::*, + threadsafe_function::{ThreadSafeCallContext, ThreadsafeFunction, ThreadsafeFunctionCallMode}, +}; #[macro_use] extern crate napi_derive; @@ -111,10 +116,28 @@ impl MemoryHolder { pub struct ChildReference(SharedReference); #[napi] - impl ChildReference { #[napi] pub fn count(&self) -> u32 { self.0.count() as u32 } } + +#[napi] +pub fn leaking_func(env: Env, func: JsFunction) -> napi::Result<()> { + let mut tsfn: ThreadsafeFunction = + func.create_threadsafe_function(0, |mut ctx: ThreadSafeCallContext| { + ctx.env.adjust_external_memory(ctx.value.len() as i64)?; + ctx + .env + .create_string_from_std(ctx.value) + .map(|js_string| vec![js_string]) + })?; + + tsfn.unref(&env)?; + spawn(move || { + tsfn.call(Ok("foo".into()), ThreadsafeFunctionCallMode::Blocking); + }); + + Ok(()) +} diff --git a/memory-testing/tsfn.mjs b/memory-testing/tsfn.mjs new file mode 100644 index 00000000..9362aafb --- /dev/null +++ b/memory-testing/tsfn.mjs @@ -0,0 +1,22 @@ +import { createRequire } from 'module' +import { setTimeout } from 'timers/promises' + +import { displayMemoryUsageFromNode } from './util.mjs' + +const initialMemoryUsage = process.memoryUsage() + +const require = createRequire(import.meta.url) + +const api = require(`./index.node`) + +let i = 1 +// eslint-disable-next-line no-constant-condition +while (true) { + api.leakingFunc(() => {}) + if (i % 100000 === 0) { + await setTimeout(100) + global.gc?.() + displayMemoryUsageFromNode(initialMemoryUsage) + } + i++ +}