Merge pull request #1234 from napi-rs/mutex-in-tsfn

This commit is contained in:
LongYinan 2022-07-10 10:34:55 +08:00 committed by GitHub
commit 2f59c6ae91
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 105 additions and 51 deletions

View file

@ -5,8 +5,8 @@ use std::ffi::CString;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::os::raw::c_void; use std::os::raw::c_void;
use std::ptr; use std::ptr;
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc; use std::sync::{Arc, Mutex};
use crate::bindgen_runtime::ToNapiValue; use crate::bindgen_runtime::ToNapiValue;
use crate::{check_status, sys, Env, Error, JsError, Result, Status}; use crate::{check_status, sys, Env, Error, JsError, Result, Status};
@ -146,23 +146,32 @@ type_level_enum! {
/// ``` /// ```
pub struct ThreadsafeFunction<T: 'static, ES: ErrorStrategy::T = ErrorStrategy::CalleeHandled> { pub struct ThreadsafeFunction<T: 'static, ES: ErrorStrategy::T = ErrorStrategy::CalleeHandled> {
raw_tsfn: sys::napi_threadsafe_function, raw_tsfn: sys::napi_threadsafe_function,
aborted: Arc<AtomicBool>, aborted: Arc<Mutex<bool>>,
ref_count: Arc<AtomicUsize>,
_phantom: PhantomData<(T, ES)>, _phantom: PhantomData<(T, ES)>,
} }
impl<T: 'static, ES: ErrorStrategy::T> Clone for ThreadsafeFunction<T, ES> { impl<T: 'static, ES: ErrorStrategy::T> Clone for ThreadsafeFunction<T, ES> {
fn clone(&self) -> Self { fn clone(&self) -> Self {
if !self.aborted.load(Ordering::Acquire) { let is_aborted = self.aborted.lock().unwrap();
if !*is_aborted {
let acquire_status = unsafe { sys::napi_acquire_threadsafe_function(self.raw_tsfn) }; let acquire_status = unsafe { sys::napi_acquire_threadsafe_function(self.raw_tsfn) };
debug_assert!( debug_assert!(
acquire_status == sys::Status::napi_ok, acquire_status == sys::Status::napi_ok,
"Acquire threadsafe function failed in clone" "Acquire threadsafe function failed in clone"
); );
} else {
panic!("ThreadsafeFunction was aborted, can not clone it");
} }
self.ref_count.fetch_add(1, Ordering::AcqRel);
drop(is_aborted);
Self { Self {
raw_tsfn: self.raw_tsfn, raw_tsfn: self.raw_tsfn,
aborted: Arc::clone(&self.aborted), aborted: Arc::clone(&self.aborted),
ref_count: Arc::clone(&self.ref_count),
_phantom: PhantomData, _phantom: PhantomData,
} }
} }
@ -194,13 +203,8 @@ impl<T: 'static, ES: ErrorStrategy::T> ThreadsafeFunction<T, ES> {
let initial_thread_count = 1usize; let initial_thread_count = 1usize;
let mut raw_tsfn = ptr::null_mut(); let mut raw_tsfn = ptr::null_mut();
let ptr = Box::into_raw(Box::new(callback)) as *mut c_void; let ptr = Box::into_raw(Box::new(callback)) as *mut c_void;
let aborted = Arc::new(AtomicBool::new(false)); let aborted = Arc::new(Mutex::new(false));
let aborted_ptr = Arc::into_raw(aborted.clone()) as *mut c_void; let aborted_ptr = Arc::into_raw(aborted.clone());
// `aborted_ptr` is passed into both `finalize_callback` and `env_cleanup_callback`.
// So increase strong count here to prevent it from being dropped twice.
unsafe {
Arc::increment_strong_count(aborted_ptr);
}
check_status!(unsafe { check_status!(unsafe {
sys::napi_create_threadsafe_function( sys::napi_create_threadsafe_function(
env, env,
@ -211,17 +215,16 @@ impl<T: 'static, ES: ErrorStrategy::T> ThreadsafeFunction<T, ES> {
initial_thread_count, initial_thread_count,
ptr, ptr,
Some(thread_finalize_cb::<T, V, R>), Some(thread_finalize_cb::<T, V, R>),
aborted_ptr, aborted_ptr as *mut c_void,
Some(call_js_cb::<T, V, R, ES>), Some(call_js_cb::<T, V, R, ES>),
&mut raw_tsfn, &mut raw_tsfn,
) )
})?; })?;
check_status!(unsafe { sys::napi_add_env_cleanup_hook(env, Some(cleanup_cb), aborted_ptr) })?;
Ok(ThreadsafeFunction { Ok(ThreadsafeFunction {
raw_tsfn, raw_tsfn,
aborted, aborted,
ref_count: Arc::new(AtomicUsize::new(initial_thread_count)),
_phantom: PhantomData, _phantom: PhantomData,
}) })
} }
@ -231,39 +234,48 @@ impl<T: 'static, ES: ErrorStrategy::T> ThreadsafeFunction<T, ES> {
/// ///
/// "ref" is a keyword so that we use "refer" here. /// "ref" is a keyword so that we use "refer" here.
pub fn refer(&mut self, env: &Env) -> Result<()> { pub fn refer(&mut self, env: &Env) -> Result<()> {
if self.aborted.load(Ordering::Acquire) { let is_aborted = self.aborted.lock().unwrap();
if *is_aborted {
return Err(Error::new( return Err(Error::new(
Status::Closing, Status::Closing,
"Can not ref, Thread safe function already aborted".to_string(), "Can not ref, Thread safe function already aborted".to_string(),
)); ));
} }
drop(is_aborted);
self.ref_count.fetch_add(1, Ordering::AcqRel);
check_status!(unsafe { sys::napi_ref_threadsafe_function(env.0, self.raw_tsfn) }) check_status!(unsafe { sys::napi_ref_threadsafe_function(env.0, self.raw_tsfn) })
} }
/// See [napi_unref_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_unref_threadsafe_function) /// See [napi_unref_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_unref_threadsafe_function)
/// for more information. /// for more information.
pub fn unref(&mut self, env: &Env) -> Result<()> { pub fn unref(&mut self, env: &Env) -> Result<()> {
if self.aborted.load(Ordering::Acquire) { let is_aborted = self.aborted.lock().unwrap();
if *is_aborted {
return Err(Error::new( return Err(Error::new(
Status::Closing, Status::Closing,
"Can not unref, Thread safe function already aborted".to_string(), "Can not unref, Thread safe function already aborted".to_string(),
)); ));
} }
self.ref_count.fetch_sub(1, Ordering::AcqRel);
check_status!(unsafe { sys::napi_unref_threadsafe_function(env.0, self.raw_tsfn) }) check_status!(unsafe { sys::napi_unref_threadsafe_function(env.0, self.raw_tsfn) })
} }
pub fn aborted(&self) -> bool { pub fn aborted(&self) -> bool {
self.aborted.load(Ordering::Acquire) let is_aborted = self.aborted.lock().unwrap();
*is_aborted
} }
pub fn abort(self) -> Result<()> { pub fn abort(self) -> Result<()> {
let mut is_aborted = self.aborted.lock().unwrap();
if !*is_aborted {
check_status!(unsafe { check_status!(unsafe {
sys::napi_release_threadsafe_function( sys::napi_release_threadsafe_function(
self.raw_tsfn, self.raw_tsfn,
sys::ThreadsafeFunctionReleaseMode::abort, sys::ThreadsafeFunctionReleaseMode::abort,
) )
})?; })?;
self.aborted.store(true, Ordering::Release); }
*is_aborted = true;
Ok(()) Ok(())
} }
@ -277,21 +289,18 @@ impl<T: 'static> ThreadsafeFunction<T, ErrorStrategy::CalleeHandled> {
/// See [napi_call_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_call_threadsafe_function) /// See [napi_call_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_call_threadsafe_function)
/// for more information. /// for more information.
pub fn call(&self, value: Result<T>, mode: ThreadsafeFunctionCallMode) -> Status { pub fn call(&self, value: Result<T>, mode: ThreadsafeFunctionCallMode) -> Status {
if self.aborted.load(Ordering::Acquire) { let is_aborted = self.aborted.lock().unwrap();
if *is_aborted {
return Status::Closing; return Status::Closing;
} }
let status = unsafe { unsafe {
sys::napi_call_threadsafe_function( sys::napi_call_threadsafe_function(
self.raw_tsfn, self.raw_tsfn,
Box::into_raw(Box::new(value)) as *mut _, Box::into_raw(Box::new(value)) as *mut c_void,
mode.into(), mode.into(),
) )
} }
.into(); .into()
if status == Status::Closing {
self.aborted.store(true, Ordering::Release);
}
status
} }
} }
@ -299,27 +308,25 @@ impl<T: 'static> ThreadsafeFunction<T, ErrorStrategy::Fatal> {
/// See [napi_call_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_call_threadsafe_function) /// See [napi_call_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_call_threadsafe_function)
/// for more information. /// for more information.
pub fn call(&self, value: T, mode: ThreadsafeFunctionCallMode) -> Status { pub fn call(&self, value: T, mode: ThreadsafeFunctionCallMode) -> Status {
if self.aborted.load(Ordering::Acquire) { let is_aborted = self.aborted.lock().unwrap();
if *is_aborted {
return Status::Closing; return Status::Closing;
} }
let status = unsafe { unsafe {
sys::napi_call_threadsafe_function( sys::napi_call_threadsafe_function(
self.raw_tsfn, self.raw_tsfn,
Box::into_raw(Box::new(value)) as *mut _, Box::into_raw(Box::new(value)) as *mut c_void,
mode.into(), mode.into(),
) )
} }
.into(); .into()
if status == Status::Closing {
self.aborted.store(true, Ordering::Release);
}
status
} }
} }
impl<T: 'static, ES: ErrorStrategy::T> Drop for ThreadsafeFunction<T, ES> { impl<T: 'static, ES: ErrorStrategy::T> Drop for ThreadsafeFunction<T, ES> {
fn drop(&mut self) { fn drop(&mut self) {
if !self.aborted.load(Ordering::Acquire) { let mut is_aborted = self.aborted.lock().unwrap();
if !*is_aborted && self.ref_count.load(Ordering::Acquire) <= 1 {
let release_status = unsafe { let release_status = unsafe {
sys::napi_release_threadsafe_function( sys::napi_release_threadsafe_function(
self.raw_tsfn, self.raw_tsfn,
@ -328,17 +335,17 @@ impl<T: 'static, ES: ErrorStrategy::T> Drop for ThreadsafeFunction<T, ES> {
}; };
assert!( assert!(
release_status == sys::Status::napi_ok, release_status == sys::Status::napi_ok,
"Threadsafe Function release failed" "Threadsafe Function release failed {:?}",
Status::from(release_status)
); );
*is_aborted = true;
} else {
self.ref_count.fetch_sub(1, Ordering::Release);
} }
drop(is_aborted);
} }
} }
unsafe extern "C" fn cleanup_cb(cleanup_data: *mut c_void) {
let aborted = unsafe { Arc::<AtomicBool>::from_raw(cleanup_data.cast()) };
aborted.store(true, Ordering::Release);
}
unsafe extern "C" fn thread_finalize_cb<T: 'static, V: ToNapiValue, R>( unsafe extern "C" fn thread_finalize_cb<T: 'static, V: ToNapiValue, R>(
_raw_env: sys::napi_env, _raw_env: sys::napi_env,
finalize_data: *mut c_void, finalize_data: *mut c_void,
@ -348,8 +355,9 @@ unsafe extern "C" fn thread_finalize_cb<T: 'static, V: ToNapiValue, R>(
{ {
// cleanup // cleanup
drop(unsafe { Box::<R>::from_raw(finalize_data.cast()) }); drop(unsafe { Box::<R>::from_raw(finalize_data.cast()) });
let aborted = unsafe { Arc::<AtomicBool>::from_raw(finalize_hint.cast()) }; let aborted = unsafe { Arc::<Mutex<bool>>::from_raw(finalize_hint.cast()) };
aborted.store(true, Ordering::Release); let mut is_aborted = aborted.lock().unwrap();
*is_aborted = true;
} }
unsafe extern "C" fn call_js_cb<T: 'static, V: ToNapiValue, R, ES>( unsafe extern "C" fn call_js_cb<T: 'static, V: ToNapiValue, R, ES>(

View file

@ -3,3 +3,4 @@ import { createSuite } from './test-util.mjs'
await createSuite('reference') await createSuite('reference')
await createSuite('tokio-future') await createSuite('tokio-future')
await createSuite('serde') await createSuite('serde')
await createSuite('tsfn')

View file

@ -1,4 +1,9 @@
use napi::{bindgen_prelude::*, Env}; use std::thread::spawn;
use napi::{
bindgen_prelude::*,
threadsafe_function::{ThreadSafeCallContext, ThreadsafeFunction, ThreadsafeFunctionCallMode},
};
#[macro_use] #[macro_use]
extern crate napi_derive; extern crate napi_derive;
@ -111,10 +116,28 @@ impl MemoryHolder {
pub struct ChildReference(SharedReference<MemoryHolder, ChildHolder>); pub struct ChildReference(SharedReference<MemoryHolder, ChildHolder>);
#[napi] #[napi]
impl ChildReference { impl ChildReference {
#[napi] #[napi]
pub fn count(&self) -> u32 { pub fn count(&self) -> u32 {
self.0.count() as u32 self.0.count() as u32
} }
} }
#[napi]
pub fn leaking_func(env: Env, func: JsFunction) -> napi::Result<()> {
let mut tsfn: ThreadsafeFunction<String> =
func.create_threadsafe_function(0, |mut ctx: ThreadSafeCallContext<String>| {
ctx.env.adjust_external_memory(ctx.value.len() as i64)?;
ctx
.env
.create_string_from_std(ctx.value)
.map(|js_string| vec![js_string])
})?;
tsfn.unref(&env)?;
spawn(move || {
tsfn.call(Ok("foo".into()), ThreadsafeFunctionCallMode::Blocking);
});
Ok(())
}

22
memory-testing/tsfn.mjs Normal file
View file

@ -0,0 +1,22 @@
import { createRequire } from 'module'
import { setTimeout } from 'timers/promises'
import { displayMemoryUsageFromNode } from './util.mjs'
const initialMemoryUsage = process.memoryUsage()
const require = createRequire(import.meta.url)
const api = require(`./index.node`)
let i = 1
// eslint-disable-next-line no-constant-condition
while (true) {
api.leakingFunc(() => {})
if (i % 100000 === 0) {
await setTimeout(100)
global.gc?.()
displayMemoryUsageFromNode(initialMemoryUsage)
}
i++
}