From 0af728a6016719ad2b6b6d0ba77bdc1406453d38 Mon Sep 17 00:00:00 2001 From: HotQ Date: Tue, 14 Mar 2023 21:31:52 +0800 Subject: [PATCH] fix(napi): prevent access to tsfn-raw after tsfn finalized(#1514) (#1515) --- crates/napi/src/threadsafe_function.rs | 78 +++++++++++++++++++------- 1 file changed, 57 insertions(+), 21 deletions(-) diff --git a/crates/napi/src/threadsafe_function.rs b/crates/napi/src/threadsafe_function.rs index 736b5c58..7f4953be 100644 --- a/crates/napi/src/threadsafe_function.rs +++ b/crates/napi/src/threadsafe_function.rs @@ -4,8 +4,9 @@ use std::convert::Into; use std::ffi::CString; use std::marker::PhantomData; use std::os::raw::c_void; -use std::ptr; -use std::sync::atomic::{AtomicBool, Ordering}; +use std::pin::Pin; +use std::ptr::{self, null_mut}; +use std::sync::atomic::{AtomicBool, AtomicPtr, Ordering}; use std::sync::{Arc, RwLock}; use crate::bindgen_runtime::{FromNapiValue, ToNapiValue, TypeName, ValidateNapiValue}; @@ -97,7 +98,7 @@ type_level_enum! { } struct ThreadsafeFunctionHandle { - raw: sys::napi_threadsafe_function, + raw: AtomicPtr, aborted: RwLock, referred: AtomicBool, } @@ -105,6 +106,31 @@ struct ThreadsafeFunctionHandle { unsafe impl Send for ThreadsafeFunctionHandle {} unsafe impl Sync for ThreadsafeFunctionHandle {} +impl ThreadsafeFunctionHandle { + /// create a pinned Arc to hold the `ThreadsafeFunctionHandle` + fn new(raw: sys::napi_threadsafe_function) -> Pin> { + Arc::pin(Self { + raw: AtomicPtr::new(raw), + aborted: RwLock::new(false), + referred: AtomicBool::new(true), + }) + } + + fn null() -> Pin> { + Self::new(null_mut()) + } + + fn get_raw(&self) -> sys::napi_threadsafe_function { + let raw = self.raw.load(Ordering::SeqCst); + assert!(!raw.is_null()); + raw + } + + fn set_raw(&self, raw: sys::napi_threadsafe_function) { + self.raw.store(raw, Ordering::SeqCst) + } +} + impl Drop for ThreadsafeFunctionHandle { fn drop(&mut self) { let aborted_guard = self @@ -113,7 +139,10 @@ impl Drop for ThreadsafeFunctionHandle { .expect("Threadsafe Function aborted lock failed"); if !*aborted_guard && self.referred.load(Ordering::Acquire) { let release_status = unsafe { - sys::napi_release_threadsafe_function(self.raw, sys::ThreadsafeFunctionReleaseMode::release) + sys::napi_release_threadsafe_function( + self.get_raw(), + sys::ThreadsafeFunctionReleaseMode::release, + ) }; assert!( release_status == sys::Status::napi_ok, @@ -186,7 +215,7 @@ struct ThreadsafeFunctionCallJsBackData { /// } /// ``` pub struct ThreadsafeFunction { - handle: Arc, + handle: Pin>, _phantom: PhantomData<(T, ES)>, } @@ -299,6 +328,7 @@ impl ThreadsafeFunction { let mut raw_tsfn = ptr::null_mut(); let callback_ptr = Box::into_raw(Box::new(callback)); + let handle = ThreadsafeFunctionHandle::null(); check_status!(unsafe { sys::napi_create_threadsafe_function( env, @@ -307,20 +337,17 @@ impl ThreadsafeFunction { async_resource_name, max_queue_size, 1, - ptr::null_mut(), + Arc::into_raw(Pin::into_inner(handle.clone())) as *mut c_void, // pass handler to thread_finalize_cb Some(thread_finalize_cb::), callback_ptr.cast(), Some(call_js_cb::), &mut raw_tsfn, ) })?; + handle.set_raw(raw_tsfn); Ok(ThreadsafeFunction { - handle: Arc::new(ThreadsafeFunctionHandle { - raw: raw_tsfn, - aborted: RwLock::new(false), - referred: AtomicBool::new(true), - }), + handle, _phantom: PhantomData, }) } @@ -336,7 +363,7 @@ impl ThreadsafeFunction { .read() .expect("Threadsafe Function aborted lock failed"); if !*aborted_guard && !self.handle.referred.load(Ordering::Acquire) { - check_status!(unsafe { sys::napi_ref_threadsafe_function(env.0, self.handle.raw) })?; + check_status!(unsafe { sys::napi_ref_threadsafe_function(env.0, self.handle.get_raw()) })?; self.handle.referred.store(true, Ordering::Release); } Ok(()) @@ -351,7 +378,7 @@ impl ThreadsafeFunction { .read() .expect("Threadsafe Function aborted lock failed"); if !*aborted_guard && self.handle.referred.load(Ordering::Acquire) { - check_status!(unsafe { sys::napi_unref_threadsafe_function(env.0, self.handle.raw) })?; + check_status!(unsafe { sys::napi_unref_threadsafe_function(env.0, self.handle.get_raw()) })?; self.handle.referred.store(false, Ordering::Release); } Ok(()) @@ -375,7 +402,7 @@ impl ThreadsafeFunction { if !*aborted_guard { check_status!(unsafe { sys::napi_release_threadsafe_function( - self.handle.raw, + self.handle.get_raw(), sys::ThreadsafeFunctionReleaseMode::abort, ) })?; @@ -386,7 +413,7 @@ impl ThreadsafeFunction { /// Get the raw `ThreadSafeFunction` pointer pub fn raw(&self) -> sys::napi_threadsafe_function { - self.handle.raw + self.handle.get_raw() } } @@ -396,7 +423,7 @@ impl ThreadsafeFunction { pub fn call(&self, value: Result, mode: ThreadsafeFunctionCallMode) -> Status { unsafe { sys::napi_call_threadsafe_function( - self.handle.raw, + self.handle.get_raw(), Box::into_raw(Box::new(value.map(|data| { ThreadsafeFunctionCallJsBackData { data, @@ -419,7 +446,7 @@ impl ThreadsafeFunction { ) -> Status { unsafe { sys::napi_call_threadsafe_function( - self.handle.raw, + self.handle.get_raw(), Box::into_raw(Box::new(value.map(|data| { ThreadsafeFunctionCallJsBackData { data, @@ -441,7 +468,7 @@ impl ThreadsafeFunction { let (sender, receiver) = tokio::sync::oneshot::channel::(); check_status!(unsafe { sys::napi_call_threadsafe_function( - self.handle.raw, + self.handle.get_raw(), Box::into_raw(Box::new(value.map(|data| { ThreadsafeFunctionCallJsBackData { data, @@ -471,7 +498,7 @@ impl ThreadsafeFunction { pub fn call(&self, value: T, mode: ThreadsafeFunctionCallMode) -> Status { unsafe { sys::napi_call_threadsafe_function( - self.handle.raw, + self.handle.get_raw(), Box::into_raw(Box::new(ThreadsafeFunctionCallJsBackData { data: value, call_variant: ThreadsafeFunctionCallVariant::Direct, @@ -492,7 +519,7 @@ impl ThreadsafeFunction { ) -> Status { unsafe { sys::napi_call_threadsafe_function( - self.handle.raw, + self.handle.get_raw(), Box::into_raw(Box::new(ThreadsafeFunctionCallJsBackData { data: value, call_variant: ThreadsafeFunctionCallVariant::WithCallback, @@ -512,7 +539,7 @@ impl ThreadsafeFunction { let (sender, receiver) = tokio::sync::oneshot::channel::(); check_status!(unsafe { sys::napi_call_threadsafe_function( - self.handle.raw, + self.handle.get_raw(), Box::into_raw(Box::new(ThreadsafeFunctionCallJsBackData { data: value, call_variant: ThreadsafeFunctionCallVariant::WithCallback, @@ -542,6 +569,15 @@ unsafe extern "C" fn thread_finalize_cb( ) where R: 'static + Send + FnMut(ThreadSafeCallContext) -> Result>, { + let handle = unsafe { Arc::from_raw(finalize_data.cast::()) }; + let mut aborted_guard = handle + .aborted + .write() + .expect("Threadsafe Function Handle aborted lock failed"); + if !*aborted_guard { + *aborted_guard = true; + } + // cleanup drop(unsafe { Box::::from_raw(finalize_hint.cast()) }); }