fix(napi): ensure that napi_call_threadsafe_function
cannot be called after abort (#1533)
* refactor(napi): reduce boilerplate code for accessing `aborted` lock * refactor: ensure that `napi_call_threadsafe_function` cannot be called after abort --------- Co-authored-by: LongYinan <lynweklm@gmail.com>
This commit is contained in:
parent
e47c13f177
commit
d8cfcfdfda
1 changed files with 216 additions and 173 deletions
|
@ -6,7 +6,7 @@ use std::marker::PhantomData;
|
|||
use std::os::raw::c_void;
|
||||
use std::ptr::{self, null_mut};
|
||||
use std::sync::atomic::{AtomicBool, AtomicPtr, Ordering};
|
||||
use std::sync::{Arc, RwLock, Weak};
|
||||
use std::sync::{Arc, RwLock, RwLockWriteGuard, Weak};
|
||||
|
||||
use crate::bindgen_runtime::{FromNapiValue, ToNapiValue, TypeName, ValidateNapiValue};
|
||||
use crate::{check_status, sys, Env, JsError, JsUnknown, Result, Status};
|
||||
|
@ -112,6 +112,30 @@ impl ThreadsafeFunctionHandle {
|
|||
})
|
||||
}
|
||||
|
||||
/// Lock `aborted` with read access, call `f` with the value of `aborted`, then unlock it
|
||||
fn with_read_aborted<RT, F>(&self, f: F) -> RT
|
||||
where
|
||||
F: FnOnce(bool) -> RT,
|
||||
{
|
||||
let aborted_guard = self
|
||||
.aborted
|
||||
.read()
|
||||
.expect("Threadsafe Function aborted lock failed");
|
||||
f(*aborted_guard)
|
||||
}
|
||||
|
||||
/// Lock `aborted` with write access, call `f` with the `RwLockWriteGuard`, then unlock it
|
||||
fn with_write_aborted<RT, F>(&self, f: F) -> RT
|
||||
where
|
||||
F: FnOnce(RwLockWriteGuard<bool>) -> RT,
|
||||
{
|
||||
let aborted_guard = self
|
||||
.aborted
|
||||
.write()
|
||||
.expect("Threadsafe Function aborted lock failed");
|
||||
f(aborted_guard)
|
||||
}
|
||||
|
||||
fn null() -> Arc<Self> {
|
||||
Self::new(null_mut())
|
||||
}
|
||||
|
@ -127,23 +151,21 @@ impl ThreadsafeFunctionHandle {
|
|||
|
||||
impl Drop for ThreadsafeFunctionHandle {
|
||||
fn drop(&mut self) {
|
||||
let aborted_guard = self
|
||||
.aborted
|
||||
.read()
|
||||
.expect("Threadsafe Function aborted lock failed");
|
||||
if !*aborted_guard {
|
||||
let release_status = unsafe {
|
||||
sys::napi_release_threadsafe_function(
|
||||
self.get_raw(),
|
||||
sys::ThreadsafeFunctionReleaseMode::release,
|
||||
)
|
||||
};
|
||||
assert!(
|
||||
release_status == sys::Status::napi_ok,
|
||||
"Threadsafe Function release failed {}",
|
||||
Status::from(release_status)
|
||||
);
|
||||
}
|
||||
self.with_read_aborted(|aborted| {
|
||||
if !aborted {
|
||||
let release_status = unsafe {
|
||||
sys::napi_release_threadsafe_function(
|
||||
self.get_raw(),
|
||||
sys::ThreadsafeFunctionReleaseMode::release,
|
||||
)
|
||||
};
|
||||
assert!(
|
||||
release_status == sys::Status::napi_ok,
|
||||
"Threadsafe Function release failed {}",
|
||||
Status::from(release_status)
|
||||
);
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -215,19 +237,16 @@ pub struct ThreadsafeFunction<T: 'static, ES: ErrorStrategy::T = ErrorStrategy::
|
|||
|
||||
impl<T: 'static, ES: ErrorStrategy::T> Clone for ThreadsafeFunction<T, ES> {
|
||||
fn clone(&self) -> Self {
|
||||
let aborted_guard = self
|
||||
.handle
|
||||
.aborted
|
||||
.read()
|
||||
.expect("Threadsafe Function aborted lock failed");
|
||||
if *aborted_guard {
|
||||
panic!("ThreadsafeFunction was aborted, can not clone it");
|
||||
}
|
||||
self.handle.with_read_aborted(|aborted| {
|
||||
if aborted {
|
||||
panic!("ThreadsafeFunction was aborted, can not clone it");
|
||||
};
|
||||
|
||||
Self {
|
||||
handle: self.handle.clone(),
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
Self {
|
||||
handle: self.handle.clone(),
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -351,58 +370,46 @@ impl<T: 'static, ES: ErrorStrategy::T> ThreadsafeFunction<T, ES> {
|
|||
///
|
||||
/// "ref" is a keyword so that we use "refer" here.
|
||||
pub fn refer(&mut self, env: &Env) -> Result<()> {
|
||||
let aborted_guard = self
|
||||
.handle
|
||||
.aborted
|
||||
.read()
|
||||
.expect("Threadsafe Function aborted lock failed");
|
||||
if !*aborted_guard && !self.handle.referred.load(Ordering::Relaxed) {
|
||||
check_status!(unsafe { sys::napi_ref_threadsafe_function(env.0, self.handle.get_raw()) })?;
|
||||
self.handle.referred.store(true, Ordering::Relaxed);
|
||||
}
|
||||
Ok(())
|
||||
self.handle.with_read_aborted(|aborted| {
|
||||
if !aborted && !self.handle.referred.load(Ordering::Relaxed) {
|
||||
check_status!(unsafe { sys::napi_ref_threadsafe_function(env.0, self.handle.get_raw()) })?;
|
||||
self.handle.referred.store(true, Ordering::Relaxed);
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
/// See [napi_unref_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_unref_threadsafe_function)
|
||||
/// for more information.
|
||||
pub fn unref(&mut self, env: &Env) -> Result<()> {
|
||||
let aborted_guard = self
|
||||
.handle
|
||||
.aborted
|
||||
.read()
|
||||
.expect("Threadsafe Function aborted lock failed");
|
||||
if !*aborted_guard && self.handle.referred.load(Ordering::Relaxed) {
|
||||
check_status!(unsafe { sys::napi_unref_threadsafe_function(env.0, self.handle.get_raw()) })?;
|
||||
self.handle.referred.store(false, Ordering::Relaxed);
|
||||
}
|
||||
Ok(())
|
||||
self.handle.with_read_aborted(|aborted| {
|
||||
if !aborted && self.handle.referred.load(Ordering::Relaxed) {
|
||||
check_status!(unsafe {
|
||||
sys::napi_unref_threadsafe_function(env.0, self.handle.get_raw())
|
||||
})?;
|
||||
self.handle.referred.store(false, Ordering::Relaxed);
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn aborted(&self) -> bool {
|
||||
let aborted_guard = self
|
||||
.handle
|
||||
.aborted
|
||||
.read()
|
||||
.expect("Threadsafe Function aborted lock failed");
|
||||
*aborted_guard
|
||||
self.handle.with_read_aborted(|aborted| aborted)
|
||||
}
|
||||
|
||||
pub fn abort(self) -> Result<()> {
|
||||
let mut aborted_guard = self
|
||||
.handle
|
||||
.aborted
|
||||
.write()
|
||||
.expect("Threadsafe Function aborted lock failed");
|
||||
if !*aborted_guard {
|
||||
check_status!(unsafe {
|
||||
sys::napi_release_threadsafe_function(
|
||||
self.handle.get_raw(),
|
||||
sys::ThreadsafeFunctionReleaseMode::abort,
|
||||
)
|
||||
})?;
|
||||
*aborted_guard = true;
|
||||
}
|
||||
Ok(())
|
||||
self.handle.with_write_aborted(|mut aborted_guard| {
|
||||
if !*aborted_guard {
|
||||
check_status!(unsafe {
|
||||
sys::napi_release_threadsafe_function(
|
||||
self.handle.get_raw(),
|
||||
sys::ThreadsafeFunctionReleaseMode::abort,
|
||||
)
|
||||
})?;
|
||||
*aborted_guard = true;
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
/// Get the raw `ThreadSafeFunction` pointer
|
||||
|
@ -415,21 +422,27 @@ impl<T: 'static> ThreadsafeFunction<T, ErrorStrategy::CalleeHandled> {
|
|||
/// See [napi_call_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_call_threadsafe_function)
|
||||
/// for more information.
|
||||
pub fn call(&self, value: Result<T>, mode: ThreadsafeFunctionCallMode) -> Status {
|
||||
unsafe {
|
||||
sys::napi_call_threadsafe_function(
|
||||
self.handle.get_raw(),
|
||||
Box::into_raw(Box::new(value.map(|data| {
|
||||
ThreadsafeFunctionCallJsBackData {
|
||||
data,
|
||||
call_variant: ThreadsafeFunctionCallVariant::Direct,
|
||||
callback: Box::new(|_d: JsUnknown| Ok(())),
|
||||
}
|
||||
})))
|
||||
.cast(),
|
||||
mode.into(),
|
||||
)
|
||||
}
|
||||
.into()
|
||||
self.handle.with_read_aborted(|aborted| {
|
||||
if aborted {
|
||||
return Status::Closing;
|
||||
}
|
||||
|
||||
unsafe {
|
||||
sys::napi_call_threadsafe_function(
|
||||
self.handle.get_raw(),
|
||||
Box::into_raw(Box::new(value.map(|data| {
|
||||
ThreadsafeFunctionCallJsBackData {
|
||||
data,
|
||||
call_variant: ThreadsafeFunctionCallVariant::Direct,
|
||||
callback: Box::new(|_d: JsUnknown| Ok(())),
|
||||
}
|
||||
})))
|
||||
.cast(),
|
||||
mode.into(),
|
||||
)
|
||||
}
|
||||
.into()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn call_with_return_value<D: FromNapiValue, F: 'static + FnOnce(D) -> Result<()>>(
|
||||
|
@ -438,47 +451,60 @@ impl<T: 'static> ThreadsafeFunction<T, ErrorStrategy::CalleeHandled> {
|
|||
mode: ThreadsafeFunctionCallMode,
|
||||
cb: F,
|
||||
) -> Status {
|
||||
unsafe {
|
||||
sys::napi_call_threadsafe_function(
|
||||
self.handle.get_raw(),
|
||||
Box::into_raw(Box::new(value.map(|data| {
|
||||
ThreadsafeFunctionCallJsBackData {
|
||||
data,
|
||||
call_variant: ThreadsafeFunctionCallVariant::WithCallback,
|
||||
callback: Box::new(move |d: JsUnknown| {
|
||||
D::from_napi_value(d.0.env, d.0.value).and_then(cb)
|
||||
}),
|
||||
}
|
||||
})))
|
||||
.cast(),
|
||||
mode.into(),
|
||||
)
|
||||
}
|
||||
.into()
|
||||
self.handle.with_read_aborted(|aborted| {
|
||||
if aborted {
|
||||
return Status::Closing;
|
||||
}
|
||||
|
||||
unsafe {
|
||||
sys::napi_call_threadsafe_function(
|
||||
self.handle.get_raw(),
|
||||
Box::into_raw(Box::new(value.map(|data| {
|
||||
ThreadsafeFunctionCallJsBackData {
|
||||
data,
|
||||
call_variant: ThreadsafeFunctionCallVariant::WithCallback,
|
||||
callback: Box::new(move |d: JsUnknown| {
|
||||
D::from_napi_value(d.0.env, d.0.value).and_then(cb)
|
||||
}),
|
||||
}
|
||||
})))
|
||||
.cast(),
|
||||
mode.into(),
|
||||
)
|
||||
}
|
||||
.into()
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(feature = "tokio_rt")]
|
||||
pub async fn call_async<D: 'static + FromNapiValue>(&self, value: Result<T>) -> Result<D> {
|
||||
let (sender, receiver) = tokio::sync::oneshot::channel::<D>();
|
||||
check_status!(unsafe {
|
||||
sys::napi_call_threadsafe_function(
|
||||
self.handle.get_raw(),
|
||||
Box::into_raw(Box::new(value.map(|data| {
|
||||
ThreadsafeFunctionCallJsBackData {
|
||||
data,
|
||||
call_variant: ThreadsafeFunctionCallVariant::WithCallback,
|
||||
callback: Box::new(move |d: JsUnknown| {
|
||||
D::from_napi_value(d.0.env, d.0.value).and_then(move |d| {
|
||||
sender.send(d).map_err(|_| {
|
||||
crate::Error::from_reason("Failed to send return value to tokio sender")
|
||||
|
||||
self.handle.with_read_aborted(|aborted| {
|
||||
if aborted {
|
||||
return Err(crate::Error::from_status(Status::Closing));
|
||||
}
|
||||
|
||||
check_status!(unsafe {
|
||||
sys::napi_call_threadsafe_function(
|
||||
self.handle.get_raw(),
|
||||
Box::into_raw(Box::new(value.map(|data| {
|
||||
ThreadsafeFunctionCallJsBackData {
|
||||
data,
|
||||
call_variant: ThreadsafeFunctionCallVariant::WithCallback,
|
||||
callback: Box::new(move |d: JsUnknown| {
|
||||
D::from_napi_value(d.0.env, d.0.value).and_then(move |d| {
|
||||
sender.send(d).map_err(|_| {
|
||||
crate::Error::from_reason("Failed to send return value to tokio sender")
|
||||
})
|
||||
})
|
||||
})
|
||||
}),
|
||||
}
|
||||
})))
|
||||
.cast(),
|
||||
ThreadsafeFunctionCallMode::NonBlocking.into(),
|
||||
)
|
||||
}),
|
||||
}
|
||||
})))
|
||||
.cast(),
|
||||
ThreadsafeFunctionCallMode::NonBlocking.into(),
|
||||
)
|
||||
})
|
||||
})?;
|
||||
receiver
|
||||
.await
|
||||
|
@ -490,19 +516,25 @@ impl<T: 'static> ThreadsafeFunction<T, ErrorStrategy::Fatal> {
|
|||
/// See [napi_call_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_call_threadsafe_function)
|
||||
/// for more information.
|
||||
pub fn call(&self, value: T, mode: ThreadsafeFunctionCallMode) -> Status {
|
||||
unsafe {
|
||||
sys::napi_call_threadsafe_function(
|
||||
self.handle.get_raw(),
|
||||
Box::into_raw(Box::new(ThreadsafeFunctionCallJsBackData {
|
||||
data: value,
|
||||
call_variant: ThreadsafeFunctionCallVariant::Direct,
|
||||
callback: Box::new(|_d: JsUnknown| Ok(())),
|
||||
}))
|
||||
.cast(),
|
||||
mode.into(),
|
||||
)
|
||||
}
|
||||
.into()
|
||||
self.handle.with_read_aborted(|aborted| {
|
||||
if aborted {
|
||||
return Status::Closing;
|
||||
}
|
||||
|
||||
unsafe {
|
||||
sys::napi_call_threadsafe_function(
|
||||
self.handle.get_raw(),
|
||||
Box::into_raw(Box::new(ThreadsafeFunctionCallJsBackData {
|
||||
data: value,
|
||||
call_variant: ThreadsafeFunctionCallVariant::Direct,
|
||||
callback: Box::new(|_d: JsUnknown| Ok(())),
|
||||
}))
|
||||
.cast(),
|
||||
mode.into(),
|
||||
)
|
||||
}
|
||||
.into()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn call_with_return_value<D: FromNapiValue, F: 'static + FnOnce(D) -> Result<()>>(
|
||||
|
@ -511,44 +543,58 @@ impl<T: 'static> ThreadsafeFunction<T, ErrorStrategy::Fatal> {
|
|||
mode: ThreadsafeFunctionCallMode,
|
||||
cb: F,
|
||||
) -> Status {
|
||||
unsafe {
|
||||
sys::napi_call_threadsafe_function(
|
||||
self.handle.get_raw(),
|
||||
Box::into_raw(Box::new(ThreadsafeFunctionCallJsBackData {
|
||||
data: value,
|
||||
call_variant: ThreadsafeFunctionCallVariant::WithCallback,
|
||||
callback: Box::new(move |d: JsUnknown| {
|
||||
D::from_napi_value(d.0.env, d.0.value).and_then(cb)
|
||||
}),
|
||||
}))
|
||||
.cast(),
|
||||
mode.into(),
|
||||
)
|
||||
}
|
||||
.into()
|
||||
self.handle.with_read_aborted(|aborted| {
|
||||
if aborted {
|
||||
return Status::Closing;
|
||||
}
|
||||
|
||||
unsafe {
|
||||
sys::napi_call_threadsafe_function(
|
||||
self.handle.get_raw(),
|
||||
Box::into_raw(Box::new(ThreadsafeFunctionCallJsBackData {
|
||||
data: value,
|
||||
call_variant: ThreadsafeFunctionCallVariant::WithCallback,
|
||||
callback: Box::new(move |d: JsUnknown| {
|
||||
D::from_napi_value(d.0.env, d.0.value).and_then(cb)
|
||||
}),
|
||||
}))
|
||||
.cast(),
|
||||
mode.into(),
|
||||
)
|
||||
}
|
||||
.into()
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(feature = "tokio_rt")]
|
||||
pub async fn call_async<D: 'static + FromNapiValue>(&self, value: T) -> Result<D> {
|
||||
let (sender, receiver) = tokio::sync::oneshot::channel::<D>();
|
||||
check_status!(unsafe {
|
||||
sys::napi_call_threadsafe_function(
|
||||
self.handle.get_raw(),
|
||||
Box::into_raw(Box::new(ThreadsafeFunctionCallJsBackData {
|
||||
data: value,
|
||||
call_variant: ThreadsafeFunctionCallVariant::WithCallback,
|
||||
callback: Box::new(move |d: JsUnknown| {
|
||||
D::from_napi_value(d.0.env, d.0.value).and_then(move |d| {
|
||||
sender.send(d).map_err(|_| {
|
||||
crate::Error::from_reason("Failed to send return value to tokio sender")
|
||||
|
||||
self.handle.with_read_aborted(|aborted| {
|
||||
if aborted {
|
||||
return Err(crate::Error::from_status(Status::Closing));
|
||||
}
|
||||
|
||||
check_status!(unsafe {
|
||||
sys::napi_call_threadsafe_function(
|
||||
self.handle.get_raw(),
|
||||
Box::into_raw(Box::new(ThreadsafeFunctionCallJsBackData {
|
||||
data: value,
|
||||
call_variant: ThreadsafeFunctionCallVariant::WithCallback,
|
||||
callback: Box::new(move |d: JsUnknown| {
|
||||
D::from_napi_value(d.0.env, d.0.value).and_then(move |d| {
|
||||
sender.send(d).map_err(|_| {
|
||||
crate::Error::from_reason("Failed to send return value to tokio sender")
|
||||
})
|
||||
})
|
||||
})
|
||||
}),
|
||||
}))
|
||||
.cast(),
|
||||
ThreadsafeFunctionCallMode::NonBlocking.into(),
|
||||
)
|
||||
}),
|
||||
}))
|
||||
.cast(),
|
||||
ThreadsafeFunctionCallMode::NonBlocking.into(),
|
||||
)
|
||||
})
|
||||
})?;
|
||||
|
||||
receiver
|
||||
.await
|
||||
.map_err(|err| crate::Error::new(Status::GenericFailure, format!("{}", err)))
|
||||
|
@ -567,14 +613,11 @@ unsafe extern "C" fn thread_finalize_cb<T: 'static, V: ToNapiValue, R>(
|
|||
unsafe { Weak::from_raw(finalize_data.cast::<ThreadsafeFunctionHandle>()).upgrade() };
|
||||
|
||||
if let Some(handle) = handle_option {
|
||||
let mut aborted_guard = handle
|
||||
.aborted
|
||||
.write()
|
||||
.expect("Threadsafe Function Handle aborted lock failed");
|
||||
|
||||
if !*aborted_guard {
|
||||
*aborted_guard = true;
|
||||
}
|
||||
handle.with_write_aborted(|mut aborted_guard| {
|
||||
if !*aborted_guard {
|
||||
*aborted_guard = true;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// cleanup
|
||||
|
|
Loading…
Reference in a new issue