fix(napi): memory leak in ThreadsafeFunction

This commit is contained in:
LongYinan 2022-07-08 00:09:14 +08:00
parent d0a9cbfa86
commit 4dfc770c2a
No known key found for this signature in database
GPG key ID: C3666B7FC82ADAD7

View file

@ -5,7 +5,7 @@ use std::ffi::CString;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::os::raw::c_void; use std::os::raw::c_void;
use std::ptr; use std::ptr;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc; use std::sync::Arc;
use crate::bindgen_runtime::ToNapiValue; use crate::bindgen_runtime::ToNapiValue;
@ -147,7 +147,6 @@ type_level_enum! {
pub struct ThreadsafeFunction<T: 'static, ES: ErrorStrategy::T = ErrorStrategy::CalleeHandled> { pub struct ThreadsafeFunction<T: 'static, ES: ErrorStrategy::T = ErrorStrategy::CalleeHandled> {
raw_tsfn: sys::napi_threadsafe_function, raw_tsfn: sys::napi_threadsafe_function,
aborted: Arc<AtomicBool>, aborted: Arc<AtomicBool>,
ref_count: Arc<AtomicUsize>,
_phantom: PhantomData<(T, ES)>, _phantom: PhantomData<(T, ES)>,
} }
@ -164,7 +163,6 @@ impl<T: 'static, ES: ErrorStrategy::T> Clone for ThreadsafeFunction<T, ES> {
Self { Self {
raw_tsfn: self.raw_tsfn, raw_tsfn: self.raw_tsfn,
aborted: Arc::clone(&self.aborted), aborted: Arc::clone(&self.aborted),
ref_count: Arc::clone(&self.ref_count),
_phantom: PhantomData, _phantom: PhantomData,
} }
} }
@ -196,6 +194,13 @@ impl<T: 'static, ES: ErrorStrategy::T> ThreadsafeFunction<T, ES> {
let initial_thread_count = 1usize; let initial_thread_count = 1usize;
let mut raw_tsfn = ptr::null_mut(); let mut raw_tsfn = ptr::null_mut();
let ptr = Box::into_raw(Box::new(callback)) as *mut c_void; let ptr = Box::into_raw(Box::new(callback)) as *mut c_void;
let aborted = Arc::new(AtomicBool::new(false));
let aborted_ptr = Arc::into_raw(aborted.clone()) as *mut c_void;
// `aborted_ptr` is passed into both `finalize_callback` and `env_cleanup_callback`.
// So increase strong count here to prevent it from being dropped twice.
unsafe {
Arc::increment_strong_count(aborted_ptr);
}
check_status!(unsafe { check_status!(unsafe {
sys::napi_create_threadsafe_function( sys::napi_create_threadsafe_function(
env, env,
@ -206,20 +211,17 @@ impl<T: 'static, ES: ErrorStrategy::T> ThreadsafeFunction<T, ES> {
initial_thread_count, initial_thread_count,
ptr, ptr,
Some(thread_finalize_cb::<T, V, R>), Some(thread_finalize_cb::<T, V, R>),
ptr, aborted_ptr,
Some(call_js_cb::<T, V, R, ES>), Some(call_js_cb::<T, V, R, ES>),
&mut raw_tsfn, &mut raw_tsfn,
) )
})?; })?;
let aborted = Arc::new(AtomicBool::new(false));
let aborted_ptr = Arc::into_raw(aborted.clone()) as *mut c_void;
check_status!(unsafe { sys::napi_add_env_cleanup_hook(env, Some(cleanup_cb), aborted_ptr) })?; check_status!(unsafe { sys::napi_add_env_cleanup_hook(env, Some(cleanup_cb), aborted_ptr) })?;
Ok(ThreadsafeFunction { Ok(ThreadsafeFunction {
raw_tsfn, raw_tsfn,
aborted, aborted,
ref_count: Arc::new(AtomicUsize::new(initial_thread_count)),
_phantom: PhantomData, _phantom: PhantomData,
}) })
} }
@ -235,7 +237,6 @@ impl<T: 'static, ES: ErrorStrategy::T> ThreadsafeFunction<T, ES> {
"Can not ref, Thread safe function already aborted".to_string(), "Can not ref, Thread safe function already aborted".to_string(),
)); ));
} }
self.ref_count.fetch_add(1, Ordering::AcqRel);
check_status!(unsafe { sys::napi_ref_threadsafe_function(env.0, self.raw_tsfn) }) check_status!(unsafe { sys::napi_ref_threadsafe_function(env.0, self.raw_tsfn) })
} }
@ -248,12 +249,11 @@ impl<T: 'static, ES: ErrorStrategy::T> ThreadsafeFunction<T, ES> {
"Can not unref, Thread safe function already aborted".to_string(), "Can not unref, Thread safe function already aborted".to_string(),
)); ));
} }
self.ref_count.fetch_sub(1, Ordering::AcqRel);
check_status!(unsafe { sys::napi_unref_threadsafe_function(env.0, self.raw_tsfn) }) check_status!(unsafe { sys::napi_unref_threadsafe_function(env.0, self.raw_tsfn) })
} }
pub fn aborted(&self) -> bool { pub fn aborted(&self) -> bool {
self.aborted.load(Ordering::Relaxed) self.aborted.load(Ordering::Acquire)
} }
pub fn abort(self) -> Result<()> { pub fn abort(self) -> Result<()> {
@ -280,14 +280,18 @@ impl<T: 'static> ThreadsafeFunction<T, ErrorStrategy::CalleeHandled> {
if self.aborted.load(Ordering::Acquire) { if self.aborted.load(Ordering::Acquire) {
return Status::Closing; return Status::Closing;
} }
unsafe { let status = unsafe {
sys::napi_call_threadsafe_function( sys::napi_call_threadsafe_function(
self.raw_tsfn, self.raw_tsfn,
Box::into_raw(Box::new(value)) as *mut _, Box::into_raw(Box::new(value)) as *mut _,
mode.into(), mode.into(),
) )
} }
.into() .into();
if status == Status::Closing {
self.aborted.store(true, Ordering::Release);
}
status
} }
} }
@ -298,20 +302,24 @@ impl<T: 'static> ThreadsafeFunction<T, ErrorStrategy::Fatal> {
if self.aborted.load(Ordering::Acquire) { if self.aborted.load(Ordering::Acquire) {
return Status::Closing; return Status::Closing;
} }
unsafe { let status = unsafe {
sys::napi_call_threadsafe_function( sys::napi_call_threadsafe_function(
self.raw_tsfn, self.raw_tsfn,
Box::into_raw(Box::new(value)) as *mut _, Box::into_raw(Box::new(value)) as *mut _,
mode.into(), mode.into(),
) )
} }
.into() .into();
if status == Status::Closing {
self.aborted.store(true, Ordering::Release);
}
status
} }
} }
impl<T: 'static, ES: ErrorStrategy::T> Drop for ThreadsafeFunction<T, ES> { impl<T: 'static, ES: ErrorStrategy::T> Drop for ThreadsafeFunction<T, ES> {
fn drop(&mut self) { fn drop(&mut self) {
if !self.aborted.load(Ordering::Acquire) && self.ref_count.load(Ordering::Acquire) > 0usize { if !self.aborted.load(Ordering::Acquire) {
let release_status = unsafe { let release_status = unsafe {
sys::napi_release_threadsafe_function( sys::napi_release_threadsafe_function(
self.raw_tsfn, self.raw_tsfn,
@ -328,18 +336,20 @@ impl<T: 'static, ES: ErrorStrategy::T> Drop for ThreadsafeFunction<T, ES> {
unsafe extern "C" fn cleanup_cb(cleanup_data: *mut c_void) { unsafe extern "C" fn cleanup_cb(cleanup_data: *mut c_void) {
let aborted = unsafe { Arc::<AtomicBool>::from_raw(cleanup_data.cast()) }; let aborted = unsafe { Arc::<AtomicBool>::from_raw(cleanup_data.cast()) };
aborted.store(true, Ordering::SeqCst); aborted.store(true, Ordering::Release);
} }
unsafe extern "C" fn thread_finalize_cb<T: 'static, V: ToNapiValue, R>( unsafe extern "C" fn thread_finalize_cb<T: 'static, V: ToNapiValue, R>(
_raw_env: sys::napi_env, _raw_env: sys::napi_env,
finalize_data: *mut c_void, finalize_data: *mut c_void,
_finalize_hint: *mut c_void, finalize_hint: *mut c_void,
) where ) where
R: 'static + Send + FnMut(ThreadSafeCallContext<T>) -> Result<Vec<V>>, R: 'static + Send + FnMut(ThreadSafeCallContext<T>) -> Result<Vec<V>>,
{ {
// cleanup // cleanup
drop(unsafe { Box::<R>::from_raw(finalize_data.cast()) }); drop(unsafe { Box::<R>::from_raw(finalize_data.cast()) });
let aborted = unsafe { Arc::<AtomicBool>::from_raw(finalize_hint.cast()) };
aborted.store(true, Ordering::Release);
} }
unsafe extern "C" fn call_js_cb<T: 'static, V: ToNapiValue, R, ES>( unsafe extern "C" fn call_js_cb<T: 'static, V: ToNapiValue, R, ES>(