From d8cfcfdfda5852679ec85b8e54d8acc561c67beb Mon Sep 17 00:00:00 2001 From: Bo Date: Tue, 28 Mar 2023 20:54:55 +0800 Subject: [PATCH] fix(napi): ensure that `napi_call_threadsafe_function` cannot be called after abort (#1533) * refactor(napi): reduce boilerplate code for accessing `aborted` lock * refactor: ensure that `napi_call_threadsafe_function` cannot be called after abort --------- Co-authored-by: LongYinan --- crates/napi/src/threadsafe_function.rs | 389 ++++++++++++++----------- 1 file changed, 216 insertions(+), 173 deletions(-) diff --git a/crates/napi/src/threadsafe_function.rs b/crates/napi/src/threadsafe_function.rs index 28850903..65b8c4c9 100644 --- a/crates/napi/src/threadsafe_function.rs +++ b/crates/napi/src/threadsafe_function.rs @@ -6,7 +6,7 @@ use std::marker::PhantomData; use std::os::raw::c_void; use std::ptr::{self, null_mut}; use std::sync::atomic::{AtomicBool, AtomicPtr, Ordering}; -use std::sync::{Arc, RwLock, Weak}; +use std::sync::{Arc, RwLock, RwLockWriteGuard, Weak}; use crate::bindgen_runtime::{FromNapiValue, ToNapiValue, TypeName, ValidateNapiValue}; use crate::{check_status, sys, Env, JsError, JsUnknown, Result, Status}; @@ -112,6 +112,30 @@ impl ThreadsafeFunctionHandle { }) } + /// Lock `aborted` with read access, call `f` with the value of `aborted`, then unlock it + fn with_read_aborted(&self, f: F) -> RT + where + F: FnOnce(bool) -> RT, + { + let aborted_guard = self + .aborted + .read() + .expect("Threadsafe Function aborted lock failed"); + f(*aborted_guard) + } + + /// Lock `aborted` with write access, call `f` with the `RwLockWriteGuard`, then unlock it + fn with_write_aborted(&self, f: F) -> RT + where + F: FnOnce(RwLockWriteGuard) -> RT, + { + let aborted_guard = self + .aborted + .write() + .expect("Threadsafe Function aborted lock failed"); + f(aborted_guard) + } + fn null() -> Arc { Self::new(null_mut()) } @@ -127,23 +151,21 @@ impl ThreadsafeFunctionHandle { impl Drop for ThreadsafeFunctionHandle { fn drop(&mut self) { - let aborted_guard = self - .aborted - .read() - .expect("Threadsafe Function aborted lock failed"); - if !*aborted_guard { - let release_status = unsafe { - sys::napi_release_threadsafe_function( - self.get_raw(), - sys::ThreadsafeFunctionReleaseMode::release, - ) - }; - assert!( - release_status == sys::Status::napi_ok, - "Threadsafe Function release failed {}", - Status::from(release_status) - ); - } + self.with_read_aborted(|aborted| { + if !aborted { + let release_status = unsafe { + sys::napi_release_threadsafe_function( + self.get_raw(), + sys::ThreadsafeFunctionReleaseMode::release, + ) + }; + assert!( + release_status == sys::Status::napi_ok, + "Threadsafe Function release failed {}", + Status::from(release_status) + ); + } + }) } } @@ -215,19 +237,16 @@ pub struct ThreadsafeFunction Clone for ThreadsafeFunction { fn clone(&self) -> Self { - let aborted_guard = self - .handle - .aborted - .read() - .expect("Threadsafe Function aborted lock failed"); - if *aborted_guard { - panic!("ThreadsafeFunction was aborted, can not clone it"); - } + self.handle.with_read_aborted(|aborted| { + if aborted { + panic!("ThreadsafeFunction was aborted, can not clone it"); + }; - Self { - handle: self.handle.clone(), - _phantom: PhantomData, - } + Self { + handle: self.handle.clone(), + _phantom: PhantomData, + } + }) } } @@ -351,58 +370,46 @@ impl ThreadsafeFunction { /// /// "ref" is a keyword so that we use "refer" here. pub fn refer(&mut self, env: &Env) -> Result<()> { - let aborted_guard = self - .handle - .aborted - .read() - .expect("Threadsafe Function aborted lock failed"); - if !*aborted_guard && !self.handle.referred.load(Ordering::Relaxed) { - check_status!(unsafe { sys::napi_ref_threadsafe_function(env.0, self.handle.get_raw()) })?; - self.handle.referred.store(true, Ordering::Relaxed); - } - Ok(()) + self.handle.with_read_aborted(|aborted| { + if !aborted && !self.handle.referred.load(Ordering::Relaxed) { + check_status!(unsafe { sys::napi_ref_threadsafe_function(env.0, self.handle.get_raw()) })?; + self.handle.referred.store(true, Ordering::Relaxed); + } + Ok(()) + }) } /// See [napi_unref_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_unref_threadsafe_function) /// for more information. pub fn unref(&mut self, env: &Env) -> Result<()> { - let aborted_guard = self - .handle - .aborted - .read() - .expect("Threadsafe Function aborted lock failed"); - if !*aborted_guard && self.handle.referred.load(Ordering::Relaxed) { - check_status!(unsafe { sys::napi_unref_threadsafe_function(env.0, self.handle.get_raw()) })?; - self.handle.referred.store(false, Ordering::Relaxed); - } - Ok(()) + self.handle.with_read_aborted(|aborted| { + if !aborted && self.handle.referred.load(Ordering::Relaxed) { + check_status!(unsafe { + sys::napi_unref_threadsafe_function(env.0, self.handle.get_raw()) + })?; + self.handle.referred.store(false, Ordering::Relaxed); + } + Ok(()) + }) } pub fn aborted(&self) -> bool { - let aborted_guard = self - .handle - .aborted - .read() - .expect("Threadsafe Function aborted lock failed"); - *aborted_guard + self.handle.with_read_aborted(|aborted| aborted) } pub fn abort(self) -> Result<()> { - let mut aborted_guard = self - .handle - .aborted - .write() - .expect("Threadsafe Function aborted lock failed"); - if !*aborted_guard { - check_status!(unsafe { - sys::napi_release_threadsafe_function( - self.handle.get_raw(), - sys::ThreadsafeFunctionReleaseMode::abort, - ) - })?; - *aborted_guard = true; - } - Ok(()) + self.handle.with_write_aborted(|mut aborted_guard| { + if !*aborted_guard { + check_status!(unsafe { + sys::napi_release_threadsafe_function( + self.handle.get_raw(), + sys::ThreadsafeFunctionReleaseMode::abort, + ) + })?; + *aborted_guard = true; + } + Ok(()) + }) } /// Get the raw `ThreadSafeFunction` pointer @@ -415,21 +422,27 @@ impl ThreadsafeFunction { /// See [napi_call_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_call_threadsafe_function) /// for more information. pub fn call(&self, value: Result, mode: ThreadsafeFunctionCallMode) -> Status { - unsafe { - sys::napi_call_threadsafe_function( - self.handle.get_raw(), - Box::into_raw(Box::new(value.map(|data| { - ThreadsafeFunctionCallJsBackData { - data, - call_variant: ThreadsafeFunctionCallVariant::Direct, - callback: Box::new(|_d: JsUnknown| Ok(())), - } - }))) - .cast(), - mode.into(), - ) - } - .into() + self.handle.with_read_aborted(|aborted| { + if aborted { + return Status::Closing; + } + + unsafe { + sys::napi_call_threadsafe_function( + self.handle.get_raw(), + Box::into_raw(Box::new(value.map(|data| { + ThreadsafeFunctionCallJsBackData { + data, + call_variant: ThreadsafeFunctionCallVariant::Direct, + callback: Box::new(|_d: JsUnknown| Ok(())), + } + }))) + .cast(), + mode.into(), + ) + } + .into() + }) } pub fn call_with_return_value Result<()>>( @@ -438,47 +451,60 @@ impl ThreadsafeFunction { mode: ThreadsafeFunctionCallMode, cb: F, ) -> Status { - unsafe { - sys::napi_call_threadsafe_function( - self.handle.get_raw(), - Box::into_raw(Box::new(value.map(|data| { - ThreadsafeFunctionCallJsBackData { - data, - call_variant: ThreadsafeFunctionCallVariant::WithCallback, - callback: Box::new(move |d: JsUnknown| { - D::from_napi_value(d.0.env, d.0.value).and_then(cb) - }), - } - }))) - .cast(), - mode.into(), - ) - } - .into() + self.handle.with_read_aborted(|aborted| { + if aborted { + return Status::Closing; + } + + unsafe { + sys::napi_call_threadsafe_function( + self.handle.get_raw(), + Box::into_raw(Box::new(value.map(|data| { + ThreadsafeFunctionCallJsBackData { + data, + call_variant: ThreadsafeFunctionCallVariant::WithCallback, + callback: Box::new(move |d: JsUnknown| { + D::from_napi_value(d.0.env, d.0.value).and_then(cb) + }), + } + }))) + .cast(), + mode.into(), + ) + } + .into() + }) } #[cfg(feature = "tokio_rt")] pub async fn call_async(&self, value: Result) -> Result { let (sender, receiver) = tokio::sync::oneshot::channel::(); - check_status!(unsafe { - sys::napi_call_threadsafe_function( - self.handle.get_raw(), - Box::into_raw(Box::new(value.map(|data| { - ThreadsafeFunctionCallJsBackData { - data, - call_variant: ThreadsafeFunctionCallVariant::WithCallback, - callback: Box::new(move |d: JsUnknown| { - D::from_napi_value(d.0.env, d.0.value).and_then(move |d| { - sender.send(d).map_err(|_| { - crate::Error::from_reason("Failed to send return value to tokio sender") + + self.handle.with_read_aborted(|aborted| { + if aborted { + return Err(crate::Error::from_status(Status::Closing)); + } + + check_status!(unsafe { + sys::napi_call_threadsafe_function( + self.handle.get_raw(), + Box::into_raw(Box::new(value.map(|data| { + ThreadsafeFunctionCallJsBackData { + data, + call_variant: ThreadsafeFunctionCallVariant::WithCallback, + callback: Box::new(move |d: JsUnknown| { + D::from_napi_value(d.0.env, d.0.value).and_then(move |d| { + sender.send(d).map_err(|_| { + crate::Error::from_reason("Failed to send return value to tokio sender") + }) }) - }) - }), - } - }))) - .cast(), - ThreadsafeFunctionCallMode::NonBlocking.into(), - ) + }), + } + }))) + .cast(), + ThreadsafeFunctionCallMode::NonBlocking.into(), + ) + }) })?; receiver .await @@ -490,19 +516,25 @@ impl ThreadsafeFunction { /// See [napi_call_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_call_threadsafe_function) /// for more information. pub fn call(&self, value: T, mode: ThreadsafeFunctionCallMode) -> Status { - unsafe { - sys::napi_call_threadsafe_function( - self.handle.get_raw(), - Box::into_raw(Box::new(ThreadsafeFunctionCallJsBackData { - data: value, - call_variant: ThreadsafeFunctionCallVariant::Direct, - callback: Box::new(|_d: JsUnknown| Ok(())), - })) - .cast(), - mode.into(), - ) - } - .into() + self.handle.with_read_aborted(|aborted| { + if aborted { + return Status::Closing; + } + + unsafe { + sys::napi_call_threadsafe_function( + self.handle.get_raw(), + Box::into_raw(Box::new(ThreadsafeFunctionCallJsBackData { + data: value, + call_variant: ThreadsafeFunctionCallVariant::Direct, + callback: Box::new(|_d: JsUnknown| Ok(())), + })) + .cast(), + mode.into(), + ) + } + .into() + }) } pub fn call_with_return_value Result<()>>( @@ -511,44 +543,58 @@ impl ThreadsafeFunction { mode: ThreadsafeFunctionCallMode, cb: F, ) -> Status { - unsafe { - sys::napi_call_threadsafe_function( - self.handle.get_raw(), - Box::into_raw(Box::new(ThreadsafeFunctionCallJsBackData { - data: value, - call_variant: ThreadsafeFunctionCallVariant::WithCallback, - callback: Box::new(move |d: JsUnknown| { - D::from_napi_value(d.0.env, d.0.value).and_then(cb) - }), - })) - .cast(), - mode.into(), - ) - } - .into() + self.handle.with_read_aborted(|aborted| { + if aborted { + return Status::Closing; + } + + unsafe { + sys::napi_call_threadsafe_function( + self.handle.get_raw(), + Box::into_raw(Box::new(ThreadsafeFunctionCallJsBackData { + data: value, + call_variant: ThreadsafeFunctionCallVariant::WithCallback, + callback: Box::new(move |d: JsUnknown| { + D::from_napi_value(d.0.env, d.0.value).and_then(cb) + }), + })) + .cast(), + mode.into(), + ) + } + .into() + }) } #[cfg(feature = "tokio_rt")] pub async fn call_async(&self, value: T) -> Result { let (sender, receiver) = tokio::sync::oneshot::channel::(); - check_status!(unsafe { - sys::napi_call_threadsafe_function( - self.handle.get_raw(), - Box::into_raw(Box::new(ThreadsafeFunctionCallJsBackData { - data: value, - call_variant: ThreadsafeFunctionCallVariant::WithCallback, - callback: Box::new(move |d: JsUnknown| { - D::from_napi_value(d.0.env, d.0.value).and_then(move |d| { - sender.send(d).map_err(|_| { - crate::Error::from_reason("Failed to send return value to tokio sender") + + self.handle.with_read_aborted(|aborted| { + if aborted { + return Err(crate::Error::from_status(Status::Closing)); + } + + check_status!(unsafe { + sys::napi_call_threadsafe_function( + self.handle.get_raw(), + Box::into_raw(Box::new(ThreadsafeFunctionCallJsBackData { + data: value, + call_variant: ThreadsafeFunctionCallVariant::WithCallback, + callback: Box::new(move |d: JsUnknown| { + D::from_napi_value(d.0.env, d.0.value).and_then(move |d| { + sender.send(d).map_err(|_| { + crate::Error::from_reason("Failed to send return value to tokio sender") + }) }) - }) - }), - })) - .cast(), - ThreadsafeFunctionCallMode::NonBlocking.into(), - ) + }), + })) + .cast(), + ThreadsafeFunctionCallMode::NonBlocking.into(), + ) + }) })?; + receiver .await .map_err(|err| crate::Error::new(Status::GenericFailure, format!("{}", err))) @@ -567,14 +613,11 @@ unsafe extern "C" fn thread_finalize_cb( unsafe { Weak::from_raw(finalize_data.cast::()).upgrade() }; if let Some(handle) = handle_option { - let mut aborted_guard = handle - .aborted - .write() - .expect("Threadsafe Function Handle aborted lock failed"); - - if !*aborted_guard { - *aborted_guard = true; - } + handle.with_write_aborted(|mut aborted_guard| { + if !*aborted_guard { + *aborted_guard = true; + } + }); } // cleanup