fix(napi): use Mutex instead of Atomic in ThreadSafeFunction
This commit is contained in:
parent
a4448d3e24
commit
552ec43fae
4 changed files with 97 additions and 33 deletions
|
@ -5,8 +5,8 @@ use std::ffi::CString;
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
use std::os::raw::c_void;
|
use std::os::raw::c_void;
|
||||||
use std::ptr;
|
use std::ptr;
|
||||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
use std::sync::Arc;
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
use crate::bindgen_runtime::ToNapiValue;
|
use crate::bindgen_runtime::ToNapiValue;
|
||||||
use crate::{check_status, sys, Env, Error, JsError, Result, Status};
|
use crate::{check_status, sys, Env, Error, JsError, Result, Status};
|
||||||
|
@ -146,21 +146,28 @@ type_level_enum! {
|
||||||
/// ```
|
/// ```
|
||||||
pub struct ThreadsafeFunction<T: 'static, ES: ErrorStrategy::T = ErrorStrategy::CalleeHandled> {
|
pub struct ThreadsafeFunction<T: 'static, ES: ErrorStrategy::T = ErrorStrategy::CalleeHandled> {
|
||||||
raw_tsfn: sys::napi_threadsafe_function,
|
raw_tsfn: sys::napi_threadsafe_function,
|
||||||
aborted: Arc<AtomicBool>,
|
aborted: Arc<Mutex<bool>>,
|
||||||
ref_count: Arc<AtomicUsize>,
|
ref_count: Arc<AtomicUsize>,
|
||||||
_phantom: PhantomData<(T, ES)>,
|
_phantom: PhantomData<(T, ES)>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: 'static, ES: ErrorStrategy::T> Clone for ThreadsafeFunction<T, ES> {
|
impl<T: 'static, ES: ErrorStrategy::T> Clone for ThreadsafeFunction<T, ES> {
|
||||||
fn clone(&self) -> Self {
|
fn clone(&self) -> Self {
|
||||||
if !self.aborted.load(Ordering::Acquire) {
|
let is_aborted = self.aborted.lock().unwrap();
|
||||||
|
if !*is_aborted {
|
||||||
let acquire_status = unsafe { sys::napi_acquire_threadsafe_function(self.raw_tsfn) };
|
let acquire_status = unsafe { sys::napi_acquire_threadsafe_function(self.raw_tsfn) };
|
||||||
debug_assert!(
|
debug_assert!(
|
||||||
acquire_status == sys::Status::napi_ok,
|
acquire_status == sys::Status::napi_ok,
|
||||||
"Acquire threadsafe function failed in clone"
|
"Acquire threadsafe function failed in clone"
|
||||||
);
|
);
|
||||||
|
} else {
|
||||||
|
panic!("ThreadsafeFunction was aborted, can not clone it");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
self.ref_count.fetch_add(1, Ordering::AcqRel);
|
||||||
|
|
||||||
|
drop(is_aborted);
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
raw_tsfn: self.raw_tsfn,
|
raw_tsfn: self.raw_tsfn,
|
||||||
aborted: Arc::clone(&self.aborted),
|
aborted: Arc::clone(&self.aborted),
|
||||||
|
@ -196,6 +203,8 @@ impl<T: 'static, ES: ErrorStrategy::T> ThreadsafeFunction<T, ES> {
|
||||||
let initial_thread_count = 1usize;
|
let initial_thread_count = 1usize;
|
||||||
let mut raw_tsfn = ptr::null_mut();
|
let mut raw_tsfn = ptr::null_mut();
|
||||||
let ptr = Box::into_raw(Box::new(callback)) as *mut c_void;
|
let ptr = Box::into_raw(Box::new(callback)) as *mut c_void;
|
||||||
|
let aborted = Arc::new(Mutex::new(false));
|
||||||
|
let aborted_ptr = Arc::into_raw(aborted.clone());
|
||||||
check_status!(unsafe {
|
check_status!(unsafe {
|
||||||
sys::napi_create_threadsafe_function(
|
sys::napi_create_threadsafe_function(
|
||||||
env,
|
env,
|
||||||
|
@ -206,16 +215,12 @@ impl<T: 'static, ES: ErrorStrategy::T> ThreadsafeFunction<T, ES> {
|
||||||
initial_thread_count,
|
initial_thread_count,
|
||||||
ptr,
|
ptr,
|
||||||
Some(thread_finalize_cb::<T, V, R>),
|
Some(thread_finalize_cb::<T, V, R>),
|
||||||
ptr,
|
aborted_ptr as *mut c_void,
|
||||||
Some(call_js_cb::<T, V, R, ES>),
|
Some(call_js_cb::<T, V, R, ES>),
|
||||||
&mut raw_tsfn,
|
&mut raw_tsfn,
|
||||||
)
|
)
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let aborted = Arc::new(AtomicBool::new(false));
|
|
||||||
let aborted_ptr = Arc::into_raw(aborted.clone()) as *mut c_void;
|
|
||||||
check_status!(unsafe { sys::napi_add_env_cleanup_hook(env, Some(cleanup_cb), aborted_ptr) })?;
|
|
||||||
|
|
||||||
Ok(ThreadsafeFunction {
|
Ok(ThreadsafeFunction {
|
||||||
raw_tsfn,
|
raw_tsfn,
|
||||||
aborted,
|
aborted,
|
||||||
|
@ -229,12 +234,14 @@ impl<T: 'static, ES: ErrorStrategy::T> ThreadsafeFunction<T, ES> {
|
||||||
///
|
///
|
||||||
/// "ref" is a keyword so that we use "refer" here.
|
/// "ref" is a keyword so that we use "refer" here.
|
||||||
pub fn refer(&mut self, env: &Env) -> Result<()> {
|
pub fn refer(&mut self, env: &Env) -> Result<()> {
|
||||||
if self.aborted.load(Ordering::Acquire) {
|
let is_aborted = self.aborted.lock().unwrap();
|
||||||
|
if *is_aborted {
|
||||||
return Err(Error::new(
|
return Err(Error::new(
|
||||||
Status::Closing,
|
Status::Closing,
|
||||||
"Can not ref, Thread safe function already aborted".to_string(),
|
"Can not ref, Thread safe function already aborted".to_string(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
drop(is_aborted);
|
||||||
self.ref_count.fetch_add(1, Ordering::AcqRel);
|
self.ref_count.fetch_add(1, Ordering::AcqRel);
|
||||||
check_status!(unsafe { sys::napi_ref_threadsafe_function(env.0, self.raw_tsfn) })
|
check_status!(unsafe { sys::napi_ref_threadsafe_function(env.0, self.raw_tsfn) })
|
||||||
}
|
}
|
||||||
|
@ -242,7 +249,8 @@ impl<T: 'static, ES: ErrorStrategy::T> ThreadsafeFunction<T, ES> {
|
||||||
/// See [napi_unref_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_unref_threadsafe_function)
|
/// See [napi_unref_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_unref_threadsafe_function)
|
||||||
/// for more information.
|
/// for more information.
|
||||||
pub fn unref(&mut self, env: &Env) -> Result<()> {
|
pub fn unref(&mut self, env: &Env) -> Result<()> {
|
||||||
if self.aborted.load(Ordering::Acquire) {
|
let is_aborted = self.aborted.lock().unwrap();
|
||||||
|
if *is_aborted {
|
||||||
return Err(Error::new(
|
return Err(Error::new(
|
||||||
Status::Closing,
|
Status::Closing,
|
||||||
"Can not unref, Thread safe function already aborted".to_string(),
|
"Can not unref, Thread safe function already aborted".to_string(),
|
||||||
|
@ -253,17 +261,21 @@ impl<T: 'static, ES: ErrorStrategy::T> ThreadsafeFunction<T, ES> {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn aborted(&self) -> bool {
|
pub fn aborted(&self) -> bool {
|
||||||
self.aborted.load(Ordering::Relaxed)
|
let is_aborted = self.aborted.lock().unwrap();
|
||||||
|
*is_aborted
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn abort(self) -> Result<()> {
|
pub fn abort(self) -> Result<()> {
|
||||||
|
let mut is_aborted = self.aborted.lock().unwrap();
|
||||||
|
if !*is_aborted {
|
||||||
check_status!(unsafe {
|
check_status!(unsafe {
|
||||||
sys::napi_release_threadsafe_function(
|
sys::napi_release_threadsafe_function(
|
||||||
self.raw_tsfn,
|
self.raw_tsfn,
|
||||||
sys::ThreadsafeFunctionReleaseMode::abort,
|
sys::ThreadsafeFunctionReleaseMode::abort,
|
||||||
)
|
)
|
||||||
})?;
|
})?;
|
||||||
self.aborted.store(true, Ordering::Release);
|
}
|
||||||
|
*is_aborted = true;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -277,13 +289,14 @@ impl<T: 'static> ThreadsafeFunction<T, ErrorStrategy::CalleeHandled> {
|
||||||
/// See [napi_call_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_call_threadsafe_function)
|
/// See [napi_call_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_call_threadsafe_function)
|
||||||
/// for more information.
|
/// for more information.
|
||||||
pub fn call(&self, value: Result<T>, mode: ThreadsafeFunctionCallMode) -> Status {
|
pub fn call(&self, value: Result<T>, mode: ThreadsafeFunctionCallMode) -> Status {
|
||||||
if self.aborted.load(Ordering::Acquire) {
|
let is_aborted = self.aborted.lock().unwrap();
|
||||||
|
if *is_aborted {
|
||||||
return Status::Closing;
|
return Status::Closing;
|
||||||
}
|
}
|
||||||
unsafe {
|
unsafe {
|
||||||
sys::napi_call_threadsafe_function(
|
sys::napi_call_threadsafe_function(
|
||||||
self.raw_tsfn,
|
self.raw_tsfn,
|
||||||
Box::into_raw(Box::new(value)) as *mut _,
|
Box::into_raw(Box::new(value)) as *mut c_void,
|
||||||
mode.into(),
|
mode.into(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -295,13 +308,14 @@ impl<T: 'static> ThreadsafeFunction<T, ErrorStrategy::Fatal> {
|
||||||
/// See [napi_call_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_call_threadsafe_function)
|
/// See [napi_call_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_call_threadsafe_function)
|
||||||
/// for more information.
|
/// for more information.
|
||||||
pub fn call(&self, value: T, mode: ThreadsafeFunctionCallMode) -> Status {
|
pub fn call(&self, value: T, mode: ThreadsafeFunctionCallMode) -> Status {
|
||||||
if self.aborted.load(Ordering::Acquire) {
|
let is_aborted = self.aborted.lock().unwrap();
|
||||||
|
if *is_aborted {
|
||||||
return Status::Closing;
|
return Status::Closing;
|
||||||
}
|
}
|
||||||
unsafe {
|
unsafe {
|
||||||
sys::napi_call_threadsafe_function(
|
sys::napi_call_threadsafe_function(
|
||||||
self.raw_tsfn,
|
self.raw_tsfn,
|
||||||
Box::into_raw(Box::new(value)) as *mut _,
|
Box::into_raw(Box::new(value)) as *mut c_void,
|
||||||
mode.into(),
|
mode.into(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -311,7 +325,8 @@ impl<T: 'static> ThreadsafeFunction<T, ErrorStrategy::Fatal> {
|
||||||
|
|
||||||
impl<T: 'static, ES: ErrorStrategy::T> Drop for ThreadsafeFunction<T, ES> {
|
impl<T: 'static, ES: ErrorStrategy::T> Drop for ThreadsafeFunction<T, ES> {
|
||||||
fn drop(&mut self) {
|
fn drop(&mut self) {
|
||||||
if !self.aborted.load(Ordering::Acquire) && self.ref_count.load(Ordering::Acquire) > 0usize {
|
let mut is_aborted = self.aborted.lock().unwrap();
|
||||||
|
if !*is_aborted && self.ref_count.load(Ordering::Acquire) <= 1 {
|
||||||
let release_status = unsafe {
|
let release_status = unsafe {
|
||||||
sys::napi_release_threadsafe_function(
|
sys::napi_release_threadsafe_function(
|
||||||
self.raw_tsfn,
|
self.raw_tsfn,
|
||||||
|
@ -320,26 +335,29 @@ impl<T: 'static, ES: ErrorStrategy::T> Drop for ThreadsafeFunction<T, ES> {
|
||||||
};
|
};
|
||||||
assert!(
|
assert!(
|
||||||
release_status == sys::Status::napi_ok,
|
release_status == sys::Status::napi_ok,
|
||||||
"Threadsafe Function release failed"
|
"Threadsafe Function release failed {:?}",
|
||||||
|
Status::from(release_status)
|
||||||
);
|
);
|
||||||
|
*is_aborted = true;
|
||||||
|
} else {
|
||||||
|
self.ref_count.fetch_sub(1, Ordering::Release);
|
||||||
}
|
}
|
||||||
|
drop(is_aborted);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe extern "C" fn cleanup_cb(cleanup_data: *mut c_void) {
|
|
||||||
let aborted = unsafe { Arc::<AtomicBool>::from_raw(cleanup_data.cast()) };
|
|
||||||
aborted.store(true, Ordering::SeqCst);
|
|
||||||
}
|
|
||||||
|
|
||||||
unsafe extern "C" fn thread_finalize_cb<T: 'static, V: ToNapiValue, R>(
|
unsafe extern "C" fn thread_finalize_cb<T: 'static, V: ToNapiValue, R>(
|
||||||
_raw_env: sys::napi_env,
|
_raw_env: sys::napi_env,
|
||||||
finalize_data: *mut c_void,
|
finalize_data: *mut c_void,
|
||||||
_finalize_hint: *mut c_void,
|
finalize_hint: *mut c_void,
|
||||||
) where
|
) where
|
||||||
R: 'static + Send + FnMut(ThreadSafeCallContext<T>) -> Result<Vec<V>>,
|
R: 'static + Send + FnMut(ThreadSafeCallContext<T>) -> Result<Vec<V>>,
|
||||||
{
|
{
|
||||||
// cleanup
|
// cleanup
|
||||||
drop(unsafe { Box::<R>::from_raw(finalize_data.cast()) });
|
drop(unsafe { Box::<R>::from_raw(finalize_data.cast()) });
|
||||||
|
let aborted = unsafe { Arc::<Mutex<bool>>::from_raw(finalize_hint.cast()) };
|
||||||
|
let mut is_aborted = aborted.lock().unwrap();
|
||||||
|
*is_aborted = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe extern "C" fn call_js_cb<T: 'static, V: ToNapiValue, R, ES>(
|
unsafe extern "C" fn call_js_cb<T: 'static, V: ToNapiValue, R, ES>(
|
||||||
|
|
|
@ -3,3 +3,4 @@ import { createSuite } from './test-util.mjs'
|
||||||
await createSuite('reference')
|
await createSuite('reference')
|
||||||
await createSuite('tokio-future')
|
await createSuite('tokio-future')
|
||||||
await createSuite('serde')
|
await createSuite('serde')
|
||||||
|
await createSuite('tsfn')
|
||||||
|
|
|
@ -1,4 +1,9 @@
|
||||||
use napi::{bindgen_prelude::*, Env};
|
use std::thread::spawn;
|
||||||
|
|
||||||
|
use napi::{
|
||||||
|
bindgen_prelude::*,
|
||||||
|
threadsafe_function::{ThreadSafeCallContext, ThreadsafeFunction, ThreadsafeFunctionCallMode},
|
||||||
|
};
|
||||||
|
|
||||||
#[macro_use]
|
#[macro_use]
|
||||||
extern crate napi_derive;
|
extern crate napi_derive;
|
||||||
|
@ -111,10 +116,28 @@ impl MemoryHolder {
|
||||||
pub struct ChildReference(SharedReference<MemoryHolder, ChildHolder>);
|
pub struct ChildReference(SharedReference<MemoryHolder, ChildHolder>);
|
||||||
|
|
||||||
#[napi]
|
#[napi]
|
||||||
|
|
||||||
impl ChildReference {
|
impl ChildReference {
|
||||||
#[napi]
|
#[napi]
|
||||||
pub fn count(&self) -> u32 {
|
pub fn count(&self) -> u32 {
|
||||||
self.0.count() as u32
|
self.0.count() as u32
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[napi]
|
||||||
|
pub fn leaking_func(env: Env, func: JsFunction) -> napi::Result<()> {
|
||||||
|
let mut tsfn: ThreadsafeFunction<String> =
|
||||||
|
func.create_threadsafe_function(0, |mut ctx: ThreadSafeCallContext<String>| {
|
||||||
|
ctx.env.adjust_external_memory(ctx.value.len() as i64)?;
|
||||||
|
ctx
|
||||||
|
.env
|
||||||
|
.create_string_from_std(ctx.value)
|
||||||
|
.map(|js_string| vec![js_string])
|
||||||
|
})?;
|
||||||
|
|
||||||
|
tsfn.unref(&env)?;
|
||||||
|
spawn(move || {
|
||||||
|
tsfn.call(Ok("foo".into()), ThreadsafeFunctionCallMode::Blocking);
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
22
memory-testing/tsfn.mjs
Normal file
22
memory-testing/tsfn.mjs
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
import { createRequire } from 'module'
|
||||||
|
import { setTimeout } from 'timers/promises'
|
||||||
|
|
||||||
|
import { displayMemoryUsageFromNode } from './util.mjs'
|
||||||
|
|
||||||
|
const initialMemoryUsage = process.memoryUsage()
|
||||||
|
|
||||||
|
const require = createRequire(import.meta.url)
|
||||||
|
|
||||||
|
const api = require(`./index.node`)
|
||||||
|
|
||||||
|
let i = 1
|
||||||
|
// eslint-disable-next-line no-constant-condition
|
||||||
|
while (true) {
|
||||||
|
api.leakingFunc(() => {})
|
||||||
|
if (i % 100000 === 0) {
|
||||||
|
await setTimeout(100)
|
||||||
|
global.gc?.()
|
||||||
|
displayMemoryUsageFromNode(initialMemoryUsage)
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
}
|
Loading…
Reference in a new issue