fix(napi): ensure that napi_call_threadsafe_function cannot be called after abort (#1533)

* refactor(napi): reduce boilerplate code for accessing `aborted` lock

* refactor: ensure that `napi_call_threadsafe_function` cannot be called after abort

---------

Co-authored-by: LongYinan <lynweklm@gmail.com>
This commit is contained in:
Bo 2023-03-28 20:54:55 +08:00 committed by GitHub
parent e47c13f177
commit d8cfcfdfda
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -6,7 +6,7 @@ use std::marker::PhantomData;
use std::os::raw::c_void; use std::os::raw::c_void;
use std::ptr::{self, null_mut}; use std::ptr::{self, null_mut};
use std::sync::atomic::{AtomicBool, AtomicPtr, Ordering}; use std::sync::atomic::{AtomicBool, AtomicPtr, Ordering};
use std::sync::{Arc, RwLock, Weak}; use std::sync::{Arc, RwLock, RwLockWriteGuard, Weak};
use crate::bindgen_runtime::{FromNapiValue, ToNapiValue, TypeName, ValidateNapiValue}; use crate::bindgen_runtime::{FromNapiValue, ToNapiValue, TypeName, ValidateNapiValue};
use crate::{check_status, sys, Env, JsError, JsUnknown, Result, Status}; use crate::{check_status, sys, Env, JsError, JsUnknown, Result, Status};
@ -112,6 +112,30 @@ impl ThreadsafeFunctionHandle {
}) })
} }
/// Lock `aborted` with read access, call `f` with the value of `aborted`, then unlock it
fn with_read_aborted<RT, F>(&self, f: F) -> RT
where
F: FnOnce(bool) -> RT,
{
let aborted_guard = self
.aborted
.read()
.expect("Threadsafe Function aborted lock failed");
f(*aborted_guard)
}
/// Lock `aborted` with write access, call `f` with the `RwLockWriteGuard`, then unlock it
fn with_write_aborted<RT, F>(&self, f: F) -> RT
where
F: FnOnce(RwLockWriteGuard<bool>) -> RT,
{
let aborted_guard = self
.aborted
.write()
.expect("Threadsafe Function aborted lock failed");
f(aborted_guard)
}
fn null() -> Arc<Self> { fn null() -> Arc<Self> {
Self::new(null_mut()) Self::new(null_mut())
} }
@ -127,23 +151,21 @@ impl ThreadsafeFunctionHandle {
impl Drop for ThreadsafeFunctionHandle { impl Drop for ThreadsafeFunctionHandle {
fn drop(&mut self) { fn drop(&mut self) {
let aborted_guard = self self.with_read_aborted(|aborted| {
.aborted if !aborted {
.read() let release_status = unsafe {
.expect("Threadsafe Function aborted lock failed"); sys::napi_release_threadsafe_function(
if !*aborted_guard { self.get_raw(),
let release_status = unsafe { sys::ThreadsafeFunctionReleaseMode::release,
sys::napi_release_threadsafe_function( )
self.get_raw(), };
sys::ThreadsafeFunctionReleaseMode::release, assert!(
) release_status == sys::Status::napi_ok,
}; "Threadsafe Function release failed {}",
assert!( Status::from(release_status)
release_status == sys::Status::napi_ok, );
"Threadsafe Function release failed {}", }
Status::from(release_status) })
);
}
} }
} }
@ -215,19 +237,16 @@ pub struct ThreadsafeFunction<T: 'static, ES: ErrorStrategy::T = ErrorStrategy::
impl<T: 'static, ES: ErrorStrategy::T> Clone for ThreadsafeFunction<T, ES> { impl<T: 'static, ES: ErrorStrategy::T> Clone for ThreadsafeFunction<T, ES> {
fn clone(&self) -> Self { fn clone(&self) -> Self {
let aborted_guard = self self.handle.with_read_aborted(|aborted| {
.handle if aborted {
.aborted panic!("ThreadsafeFunction was aborted, can not clone it");
.read() };
.expect("Threadsafe Function aborted lock failed");
if *aborted_guard {
panic!("ThreadsafeFunction was aborted, can not clone it");
}
Self { Self {
handle: self.handle.clone(), handle: self.handle.clone(),
_phantom: PhantomData, _phantom: PhantomData,
} }
})
} }
} }
@ -351,58 +370,46 @@ impl<T: 'static, ES: ErrorStrategy::T> ThreadsafeFunction<T, ES> {
/// ///
/// "ref" is a keyword so that we use "refer" here. /// "ref" is a keyword so that we use "refer" here.
pub fn refer(&mut self, env: &Env) -> Result<()> { pub fn refer(&mut self, env: &Env) -> Result<()> {
let aborted_guard = self self.handle.with_read_aborted(|aborted| {
.handle if !aborted && !self.handle.referred.load(Ordering::Relaxed) {
.aborted check_status!(unsafe { sys::napi_ref_threadsafe_function(env.0, self.handle.get_raw()) })?;
.read() self.handle.referred.store(true, Ordering::Relaxed);
.expect("Threadsafe Function aborted lock failed"); }
if !*aborted_guard && !self.handle.referred.load(Ordering::Relaxed) { Ok(())
check_status!(unsafe { sys::napi_ref_threadsafe_function(env.0, self.handle.get_raw()) })?; })
self.handle.referred.store(true, Ordering::Relaxed);
}
Ok(())
} }
/// See [napi_unref_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_unref_threadsafe_function) /// See [napi_unref_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_unref_threadsafe_function)
/// for more information. /// for more information.
pub fn unref(&mut self, env: &Env) -> Result<()> { pub fn unref(&mut self, env: &Env) -> Result<()> {
let aborted_guard = self self.handle.with_read_aborted(|aborted| {
.handle if !aborted && self.handle.referred.load(Ordering::Relaxed) {
.aborted check_status!(unsafe {
.read() sys::napi_unref_threadsafe_function(env.0, self.handle.get_raw())
.expect("Threadsafe Function aborted lock failed"); })?;
if !*aborted_guard && self.handle.referred.load(Ordering::Relaxed) { self.handle.referred.store(false, Ordering::Relaxed);
check_status!(unsafe { sys::napi_unref_threadsafe_function(env.0, self.handle.get_raw()) })?; }
self.handle.referred.store(false, Ordering::Relaxed); Ok(())
} })
Ok(())
} }
pub fn aborted(&self) -> bool { pub fn aborted(&self) -> bool {
let aborted_guard = self self.handle.with_read_aborted(|aborted| aborted)
.handle
.aborted
.read()
.expect("Threadsafe Function aborted lock failed");
*aborted_guard
} }
pub fn abort(self) -> Result<()> { pub fn abort(self) -> Result<()> {
let mut aborted_guard = self self.handle.with_write_aborted(|mut aborted_guard| {
.handle if !*aborted_guard {
.aborted check_status!(unsafe {
.write() sys::napi_release_threadsafe_function(
.expect("Threadsafe Function aborted lock failed"); self.handle.get_raw(),
if !*aborted_guard { sys::ThreadsafeFunctionReleaseMode::abort,
check_status!(unsafe { )
sys::napi_release_threadsafe_function( })?;
self.handle.get_raw(), *aborted_guard = true;
sys::ThreadsafeFunctionReleaseMode::abort, }
) Ok(())
})?; })
*aborted_guard = true;
}
Ok(())
} }
/// Get the raw `ThreadSafeFunction` pointer /// Get the raw `ThreadSafeFunction` pointer
@ -415,21 +422,27 @@ impl<T: 'static> ThreadsafeFunction<T, ErrorStrategy::CalleeHandled> {
/// See [napi_call_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_call_threadsafe_function) /// See [napi_call_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_call_threadsafe_function)
/// for more information. /// for more information.
pub fn call(&self, value: Result<T>, mode: ThreadsafeFunctionCallMode) -> Status { pub fn call(&self, value: Result<T>, mode: ThreadsafeFunctionCallMode) -> Status {
unsafe { self.handle.with_read_aborted(|aborted| {
sys::napi_call_threadsafe_function( if aborted {
self.handle.get_raw(), return Status::Closing;
Box::into_raw(Box::new(value.map(|data| { }
ThreadsafeFunctionCallJsBackData {
data, unsafe {
call_variant: ThreadsafeFunctionCallVariant::Direct, sys::napi_call_threadsafe_function(
callback: Box::new(|_d: JsUnknown| Ok(())), self.handle.get_raw(),
} Box::into_raw(Box::new(value.map(|data| {
}))) ThreadsafeFunctionCallJsBackData {
.cast(), data,
mode.into(), call_variant: ThreadsafeFunctionCallVariant::Direct,
) callback: Box::new(|_d: JsUnknown| Ok(())),
} }
.into() })))
.cast(),
mode.into(),
)
}
.into()
})
} }
pub fn call_with_return_value<D: FromNapiValue, F: 'static + FnOnce(D) -> Result<()>>( pub fn call_with_return_value<D: FromNapiValue, F: 'static + FnOnce(D) -> Result<()>>(
@ -438,47 +451,60 @@ impl<T: 'static> ThreadsafeFunction<T, ErrorStrategy::CalleeHandled> {
mode: ThreadsafeFunctionCallMode, mode: ThreadsafeFunctionCallMode,
cb: F, cb: F,
) -> Status { ) -> Status {
unsafe { self.handle.with_read_aborted(|aborted| {
sys::napi_call_threadsafe_function( if aborted {
self.handle.get_raw(), return Status::Closing;
Box::into_raw(Box::new(value.map(|data| { }
ThreadsafeFunctionCallJsBackData {
data, unsafe {
call_variant: ThreadsafeFunctionCallVariant::WithCallback, sys::napi_call_threadsafe_function(
callback: Box::new(move |d: JsUnknown| { self.handle.get_raw(),
D::from_napi_value(d.0.env, d.0.value).and_then(cb) Box::into_raw(Box::new(value.map(|data| {
}), ThreadsafeFunctionCallJsBackData {
} data,
}))) call_variant: ThreadsafeFunctionCallVariant::WithCallback,
.cast(), callback: Box::new(move |d: JsUnknown| {
mode.into(), D::from_napi_value(d.0.env, d.0.value).and_then(cb)
) }),
} }
.into() })))
.cast(),
mode.into(),
)
}
.into()
})
} }
#[cfg(feature = "tokio_rt")] #[cfg(feature = "tokio_rt")]
pub async fn call_async<D: 'static + FromNapiValue>(&self, value: Result<T>) -> Result<D> { pub async fn call_async<D: 'static + FromNapiValue>(&self, value: Result<T>) -> Result<D> {
let (sender, receiver) = tokio::sync::oneshot::channel::<D>(); let (sender, receiver) = tokio::sync::oneshot::channel::<D>();
check_status!(unsafe {
sys::napi_call_threadsafe_function( self.handle.with_read_aborted(|aborted| {
self.handle.get_raw(), if aborted {
Box::into_raw(Box::new(value.map(|data| { return Err(crate::Error::from_status(Status::Closing));
ThreadsafeFunctionCallJsBackData { }
data,
call_variant: ThreadsafeFunctionCallVariant::WithCallback, check_status!(unsafe {
callback: Box::new(move |d: JsUnknown| { sys::napi_call_threadsafe_function(
D::from_napi_value(d.0.env, d.0.value).and_then(move |d| { self.handle.get_raw(),
sender.send(d).map_err(|_| { Box::into_raw(Box::new(value.map(|data| {
crate::Error::from_reason("Failed to send return value to tokio sender") ThreadsafeFunctionCallJsBackData {
data,
call_variant: ThreadsafeFunctionCallVariant::WithCallback,
callback: Box::new(move |d: JsUnknown| {
D::from_napi_value(d.0.env, d.0.value).and_then(move |d| {
sender.send(d).map_err(|_| {
crate::Error::from_reason("Failed to send return value to tokio sender")
})
}) })
}) }),
}), }
} })))
}))) .cast(),
.cast(), ThreadsafeFunctionCallMode::NonBlocking.into(),
ThreadsafeFunctionCallMode::NonBlocking.into(), )
) })
})?; })?;
receiver receiver
.await .await
@ -490,19 +516,25 @@ impl<T: 'static> ThreadsafeFunction<T, ErrorStrategy::Fatal> {
/// See [napi_call_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_call_threadsafe_function) /// See [napi_call_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_call_threadsafe_function)
/// for more information. /// for more information.
pub fn call(&self, value: T, mode: ThreadsafeFunctionCallMode) -> Status { pub fn call(&self, value: T, mode: ThreadsafeFunctionCallMode) -> Status {
unsafe { self.handle.with_read_aborted(|aborted| {
sys::napi_call_threadsafe_function( if aborted {
self.handle.get_raw(), return Status::Closing;
Box::into_raw(Box::new(ThreadsafeFunctionCallJsBackData { }
data: value,
call_variant: ThreadsafeFunctionCallVariant::Direct, unsafe {
callback: Box::new(|_d: JsUnknown| Ok(())), sys::napi_call_threadsafe_function(
})) self.handle.get_raw(),
.cast(), Box::into_raw(Box::new(ThreadsafeFunctionCallJsBackData {
mode.into(), data: value,
) call_variant: ThreadsafeFunctionCallVariant::Direct,
} callback: Box::new(|_d: JsUnknown| Ok(())),
.into() }))
.cast(),
mode.into(),
)
}
.into()
})
} }
pub fn call_with_return_value<D: FromNapiValue, F: 'static + FnOnce(D) -> Result<()>>( pub fn call_with_return_value<D: FromNapiValue, F: 'static + FnOnce(D) -> Result<()>>(
@ -511,44 +543,58 @@ impl<T: 'static> ThreadsafeFunction<T, ErrorStrategy::Fatal> {
mode: ThreadsafeFunctionCallMode, mode: ThreadsafeFunctionCallMode,
cb: F, cb: F,
) -> Status { ) -> Status {
unsafe { self.handle.with_read_aborted(|aborted| {
sys::napi_call_threadsafe_function( if aborted {
self.handle.get_raw(), return Status::Closing;
Box::into_raw(Box::new(ThreadsafeFunctionCallJsBackData { }
data: value,
call_variant: ThreadsafeFunctionCallVariant::WithCallback, unsafe {
callback: Box::new(move |d: JsUnknown| { sys::napi_call_threadsafe_function(
D::from_napi_value(d.0.env, d.0.value).and_then(cb) self.handle.get_raw(),
}), Box::into_raw(Box::new(ThreadsafeFunctionCallJsBackData {
})) data: value,
.cast(), call_variant: ThreadsafeFunctionCallVariant::WithCallback,
mode.into(), callback: Box::new(move |d: JsUnknown| {
) D::from_napi_value(d.0.env, d.0.value).and_then(cb)
} }),
.into() }))
.cast(),
mode.into(),
)
}
.into()
})
} }
#[cfg(feature = "tokio_rt")] #[cfg(feature = "tokio_rt")]
pub async fn call_async<D: 'static + FromNapiValue>(&self, value: T) -> Result<D> { pub async fn call_async<D: 'static + FromNapiValue>(&self, value: T) -> Result<D> {
let (sender, receiver) = tokio::sync::oneshot::channel::<D>(); let (sender, receiver) = tokio::sync::oneshot::channel::<D>();
check_status!(unsafe {
sys::napi_call_threadsafe_function( self.handle.with_read_aborted(|aborted| {
self.handle.get_raw(), if aborted {
Box::into_raw(Box::new(ThreadsafeFunctionCallJsBackData { return Err(crate::Error::from_status(Status::Closing));
data: value, }
call_variant: ThreadsafeFunctionCallVariant::WithCallback,
callback: Box::new(move |d: JsUnknown| { check_status!(unsafe {
D::from_napi_value(d.0.env, d.0.value).and_then(move |d| { sys::napi_call_threadsafe_function(
sender.send(d).map_err(|_| { self.handle.get_raw(),
crate::Error::from_reason("Failed to send return value to tokio sender") Box::into_raw(Box::new(ThreadsafeFunctionCallJsBackData {
data: value,
call_variant: ThreadsafeFunctionCallVariant::WithCallback,
callback: Box::new(move |d: JsUnknown| {
D::from_napi_value(d.0.env, d.0.value).and_then(move |d| {
sender.send(d).map_err(|_| {
crate::Error::from_reason("Failed to send return value to tokio sender")
})
}) })
}) }),
}), }))
})) .cast(),
.cast(), ThreadsafeFunctionCallMode::NonBlocking.into(),
ThreadsafeFunctionCallMode::NonBlocking.into(), )
) })
})?; })?;
receiver receiver
.await .await
.map_err(|err| crate::Error::new(Status::GenericFailure, format!("{}", err))) .map_err(|err| crate::Error::new(Status::GenericFailure, format!("{}", err)))
@ -567,14 +613,11 @@ unsafe extern "C" fn thread_finalize_cb<T: 'static, V: ToNapiValue, R>(
unsafe { Weak::from_raw(finalize_data.cast::<ThreadsafeFunctionHandle>()).upgrade() }; unsafe { Weak::from_raw(finalize_data.cast::<ThreadsafeFunctionHandle>()).upgrade() };
if let Some(handle) = handle_option { if let Some(handle) = handle_option {
let mut aborted_guard = handle handle.with_write_aborted(|mut aborted_guard| {
.aborted if !*aborted_guard {
.write() *aborted_guard = true;
.expect("Threadsafe Function Handle aborted lock failed"); }
});
if !*aborted_guard {
*aborted_guard = true;
}
} }
// cleanup // cleanup