fix(napi): prevent access to tsfn-raw after tsfn finalized(#1514) (#1515)

This commit is contained in:
HotQ 2023-03-14 21:31:52 +08:00 committed by GitHub
parent c1072462a5
commit 0af728a601
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -4,8 +4,9 @@ use std::convert::Into;
use std::ffi::CString; use std::ffi::CString;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::os::raw::c_void; use std::os::raw::c_void;
use std::ptr; use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering}; use std::ptr::{self, null_mut};
use std::sync::atomic::{AtomicBool, AtomicPtr, Ordering};
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
use crate::bindgen_runtime::{FromNapiValue, ToNapiValue, TypeName, ValidateNapiValue}; use crate::bindgen_runtime::{FromNapiValue, ToNapiValue, TypeName, ValidateNapiValue};
@ -97,7 +98,7 @@ type_level_enum! {
} }
struct ThreadsafeFunctionHandle { struct ThreadsafeFunctionHandle {
raw: sys::napi_threadsafe_function, raw: AtomicPtr<sys::napi_threadsafe_function__>,
aborted: RwLock<bool>, aborted: RwLock<bool>,
referred: AtomicBool, referred: AtomicBool,
} }
@ -105,6 +106,31 @@ struct ThreadsafeFunctionHandle {
unsafe impl Send for ThreadsafeFunctionHandle {} unsafe impl Send for ThreadsafeFunctionHandle {}
unsafe impl Sync for ThreadsafeFunctionHandle {} unsafe impl Sync for ThreadsafeFunctionHandle {}
impl ThreadsafeFunctionHandle {
/// create a pinned Arc to hold the `ThreadsafeFunctionHandle`
fn new(raw: sys::napi_threadsafe_function) -> Pin<Arc<Self>> {
Arc::pin(Self {
raw: AtomicPtr::new(raw),
aborted: RwLock::new(false),
referred: AtomicBool::new(true),
})
}
fn null() -> Pin<Arc<Self>> {
Self::new(null_mut())
}
fn get_raw(&self) -> sys::napi_threadsafe_function {
let raw = self.raw.load(Ordering::SeqCst);
assert!(!raw.is_null());
raw
}
fn set_raw(&self, raw: sys::napi_threadsafe_function) {
self.raw.store(raw, Ordering::SeqCst)
}
}
impl Drop for ThreadsafeFunctionHandle { impl Drop for ThreadsafeFunctionHandle {
fn drop(&mut self) { fn drop(&mut self) {
let aborted_guard = self let aborted_guard = self
@ -113,7 +139,10 @@ impl Drop for ThreadsafeFunctionHandle {
.expect("Threadsafe Function aborted lock failed"); .expect("Threadsafe Function aborted lock failed");
if !*aborted_guard && self.referred.load(Ordering::Acquire) { if !*aborted_guard && self.referred.load(Ordering::Acquire) {
let release_status = unsafe { let release_status = unsafe {
sys::napi_release_threadsafe_function(self.raw, sys::ThreadsafeFunctionReleaseMode::release) sys::napi_release_threadsafe_function(
self.get_raw(),
sys::ThreadsafeFunctionReleaseMode::release,
)
}; };
assert!( assert!(
release_status == sys::Status::napi_ok, release_status == sys::Status::napi_ok,
@ -186,7 +215,7 @@ struct ThreadsafeFunctionCallJsBackData<T> {
/// } /// }
/// ``` /// ```
pub struct ThreadsafeFunction<T: 'static, ES: ErrorStrategy::T = ErrorStrategy::CalleeHandled> { pub struct ThreadsafeFunction<T: 'static, ES: ErrorStrategy::T = ErrorStrategy::CalleeHandled> {
handle: Arc<ThreadsafeFunctionHandle>, handle: Pin<Arc<ThreadsafeFunctionHandle>>,
_phantom: PhantomData<(T, ES)>, _phantom: PhantomData<(T, ES)>,
} }
@ -299,6 +328,7 @@ impl<T: 'static, ES: ErrorStrategy::T> ThreadsafeFunction<T, ES> {
let mut raw_tsfn = ptr::null_mut(); let mut raw_tsfn = ptr::null_mut();
let callback_ptr = Box::into_raw(Box::new(callback)); let callback_ptr = Box::into_raw(Box::new(callback));
let handle = ThreadsafeFunctionHandle::null();
check_status!(unsafe { check_status!(unsafe {
sys::napi_create_threadsafe_function( sys::napi_create_threadsafe_function(
env, env,
@ -307,20 +337,17 @@ impl<T: 'static, ES: ErrorStrategy::T> ThreadsafeFunction<T, ES> {
async_resource_name, async_resource_name,
max_queue_size, max_queue_size,
1, 1,
ptr::null_mut(), Arc::into_raw(Pin::into_inner(handle.clone())) as *mut c_void, // pass handler to thread_finalize_cb
Some(thread_finalize_cb::<T, V, R>), Some(thread_finalize_cb::<T, V, R>),
callback_ptr.cast(), callback_ptr.cast(),
Some(call_js_cb::<T, V, R, ES>), Some(call_js_cb::<T, V, R, ES>),
&mut raw_tsfn, &mut raw_tsfn,
) )
})?; })?;
handle.set_raw(raw_tsfn);
Ok(ThreadsafeFunction { Ok(ThreadsafeFunction {
handle: Arc::new(ThreadsafeFunctionHandle { handle,
raw: raw_tsfn,
aborted: RwLock::new(false),
referred: AtomicBool::new(true),
}),
_phantom: PhantomData, _phantom: PhantomData,
}) })
} }
@ -336,7 +363,7 @@ impl<T: 'static, ES: ErrorStrategy::T> ThreadsafeFunction<T, ES> {
.read() .read()
.expect("Threadsafe Function aborted lock failed"); .expect("Threadsafe Function aborted lock failed");
if !*aborted_guard && !self.handle.referred.load(Ordering::Acquire) { if !*aborted_guard && !self.handle.referred.load(Ordering::Acquire) {
check_status!(unsafe { sys::napi_ref_threadsafe_function(env.0, self.handle.raw) })?; check_status!(unsafe { sys::napi_ref_threadsafe_function(env.0, self.handle.get_raw()) })?;
self.handle.referred.store(true, Ordering::Release); self.handle.referred.store(true, Ordering::Release);
} }
Ok(()) Ok(())
@ -351,7 +378,7 @@ impl<T: 'static, ES: ErrorStrategy::T> ThreadsafeFunction<T, ES> {
.read() .read()
.expect("Threadsafe Function aborted lock failed"); .expect("Threadsafe Function aborted lock failed");
if !*aborted_guard && self.handle.referred.load(Ordering::Acquire) { if !*aborted_guard && self.handle.referred.load(Ordering::Acquire) {
check_status!(unsafe { sys::napi_unref_threadsafe_function(env.0, self.handle.raw) })?; check_status!(unsafe { sys::napi_unref_threadsafe_function(env.0, self.handle.get_raw()) })?;
self.handle.referred.store(false, Ordering::Release); self.handle.referred.store(false, Ordering::Release);
} }
Ok(()) Ok(())
@ -375,7 +402,7 @@ impl<T: 'static, ES: ErrorStrategy::T> ThreadsafeFunction<T, ES> {
if !*aborted_guard { if !*aborted_guard {
check_status!(unsafe { check_status!(unsafe {
sys::napi_release_threadsafe_function( sys::napi_release_threadsafe_function(
self.handle.raw, self.handle.get_raw(),
sys::ThreadsafeFunctionReleaseMode::abort, sys::ThreadsafeFunctionReleaseMode::abort,
) )
})?; })?;
@ -386,7 +413,7 @@ impl<T: 'static, ES: ErrorStrategy::T> ThreadsafeFunction<T, ES> {
/// Get the raw `ThreadSafeFunction` pointer /// Get the raw `ThreadSafeFunction` pointer
pub fn raw(&self) -> sys::napi_threadsafe_function { pub fn raw(&self) -> sys::napi_threadsafe_function {
self.handle.raw self.handle.get_raw()
} }
} }
@ -396,7 +423,7 @@ impl<T: 'static> ThreadsafeFunction<T, ErrorStrategy::CalleeHandled> {
pub fn call(&self, value: Result<T>, mode: ThreadsafeFunctionCallMode) -> Status { pub fn call(&self, value: Result<T>, mode: ThreadsafeFunctionCallMode) -> Status {
unsafe { unsafe {
sys::napi_call_threadsafe_function( sys::napi_call_threadsafe_function(
self.handle.raw, self.handle.get_raw(),
Box::into_raw(Box::new(value.map(|data| { Box::into_raw(Box::new(value.map(|data| {
ThreadsafeFunctionCallJsBackData { ThreadsafeFunctionCallJsBackData {
data, data,
@ -419,7 +446,7 @@ impl<T: 'static> ThreadsafeFunction<T, ErrorStrategy::CalleeHandled> {
) -> Status { ) -> Status {
unsafe { unsafe {
sys::napi_call_threadsafe_function( sys::napi_call_threadsafe_function(
self.handle.raw, self.handle.get_raw(),
Box::into_raw(Box::new(value.map(|data| { Box::into_raw(Box::new(value.map(|data| {
ThreadsafeFunctionCallJsBackData { ThreadsafeFunctionCallJsBackData {
data, data,
@ -441,7 +468,7 @@ impl<T: 'static> ThreadsafeFunction<T, ErrorStrategy::CalleeHandled> {
let (sender, receiver) = tokio::sync::oneshot::channel::<D>(); let (sender, receiver) = tokio::sync::oneshot::channel::<D>();
check_status!(unsafe { check_status!(unsafe {
sys::napi_call_threadsafe_function( sys::napi_call_threadsafe_function(
self.handle.raw, self.handle.get_raw(),
Box::into_raw(Box::new(value.map(|data| { Box::into_raw(Box::new(value.map(|data| {
ThreadsafeFunctionCallJsBackData { ThreadsafeFunctionCallJsBackData {
data, data,
@ -471,7 +498,7 @@ impl<T: 'static> ThreadsafeFunction<T, ErrorStrategy::Fatal> {
pub fn call(&self, value: T, mode: ThreadsafeFunctionCallMode) -> Status { pub fn call(&self, value: T, mode: ThreadsafeFunctionCallMode) -> Status {
unsafe { unsafe {
sys::napi_call_threadsafe_function( sys::napi_call_threadsafe_function(
self.handle.raw, self.handle.get_raw(),
Box::into_raw(Box::new(ThreadsafeFunctionCallJsBackData { Box::into_raw(Box::new(ThreadsafeFunctionCallJsBackData {
data: value, data: value,
call_variant: ThreadsafeFunctionCallVariant::Direct, call_variant: ThreadsafeFunctionCallVariant::Direct,
@ -492,7 +519,7 @@ impl<T: 'static> ThreadsafeFunction<T, ErrorStrategy::Fatal> {
) -> Status { ) -> Status {
unsafe { unsafe {
sys::napi_call_threadsafe_function( sys::napi_call_threadsafe_function(
self.handle.raw, self.handle.get_raw(),
Box::into_raw(Box::new(ThreadsafeFunctionCallJsBackData { Box::into_raw(Box::new(ThreadsafeFunctionCallJsBackData {
data: value, data: value,
call_variant: ThreadsafeFunctionCallVariant::WithCallback, call_variant: ThreadsafeFunctionCallVariant::WithCallback,
@ -512,7 +539,7 @@ impl<T: 'static> ThreadsafeFunction<T, ErrorStrategy::Fatal> {
let (sender, receiver) = tokio::sync::oneshot::channel::<D>(); let (sender, receiver) = tokio::sync::oneshot::channel::<D>();
check_status!(unsafe { check_status!(unsafe {
sys::napi_call_threadsafe_function( sys::napi_call_threadsafe_function(
self.handle.raw, self.handle.get_raw(),
Box::into_raw(Box::new(ThreadsafeFunctionCallJsBackData { Box::into_raw(Box::new(ThreadsafeFunctionCallJsBackData {
data: value, data: value,
call_variant: ThreadsafeFunctionCallVariant::WithCallback, call_variant: ThreadsafeFunctionCallVariant::WithCallback,
@ -542,6 +569,15 @@ unsafe extern "C" fn thread_finalize_cb<T: 'static, V: ToNapiValue, R>(
) where ) where
R: 'static + Send + FnMut(ThreadSafeCallContext<T>) -> Result<Vec<V>>, R: 'static + Send + FnMut(ThreadSafeCallContext<T>) -> Result<Vec<V>>,
{ {
let handle = unsafe { Arc::from_raw(finalize_data.cast::<ThreadsafeFunctionHandle>()) };
let mut aborted_guard = handle
.aborted
.write()
.expect("Threadsafe Function Handle aborted lock failed");
if !*aborted_guard {
*aborted_guard = true;
}
// cleanup // cleanup
drop(unsafe { Box::<R>::from_raw(finalize_hint.cast()) }); drop(unsafe { Box::<R>::from_raw(finalize_hint.cast()) });
} }