This commit is contained in:
parent
c1072462a5
commit
0af728a601
1 changed files with 57 additions and 21 deletions
|
@ -4,8 +4,9 @@ use std::convert::Into;
|
||||||
use std::ffi::CString;
|
use std::ffi::CString;
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
use std::os::raw::c_void;
|
use std::os::raw::c_void;
|
||||||
use std::ptr;
|
use std::pin::Pin;
|
||||||
use std::sync::atomic::{AtomicBool, Ordering};
|
use std::ptr::{self, null_mut};
|
||||||
|
use std::sync::atomic::{AtomicBool, AtomicPtr, Ordering};
|
||||||
use std::sync::{Arc, RwLock};
|
use std::sync::{Arc, RwLock};
|
||||||
|
|
||||||
use crate::bindgen_runtime::{FromNapiValue, ToNapiValue, TypeName, ValidateNapiValue};
|
use crate::bindgen_runtime::{FromNapiValue, ToNapiValue, TypeName, ValidateNapiValue};
|
||||||
|
@ -97,7 +98,7 @@ type_level_enum! {
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ThreadsafeFunctionHandle {
|
struct ThreadsafeFunctionHandle {
|
||||||
raw: sys::napi_threadsafe_function,
|
raw: AtomicPtr<sys::napi_threadsafe_function__>,
|
||||||
aborted: RwLock<bool>,
|
aborted: RwLock<bool>,
|
||||||
referred: AtomicBool,
|
referred: AtomicBool,
|
||||||
}
|
}
|
||||||
|
@ -105,6 +106,31 @@ struct ThreadsafeFunctionHandle {
|
||||||
unsafe impl Send for ThreadsafeFunctionHandle {}
|
unsafe impl Send for ThreadsafeFunctionHandle {}
|
||||||
unsafe impl Sync for ThreadsafeFunctionHandle {}
|
unsafe impl Sync for ThreadsafeFunctionHandle {}
|
||||||
|
|
||||||
|
impl ThreadsafeFunctionHandle {
|
||||||
|
/// create a pinned Arc to hold the `ThreadsafeFunctionHandle`
|
||||||
|
fn new(raw: sys::napi_threadsafe_function) -> Pin<Arc<Self>> {
|
||||||
|
Arc::pin(Self {
|
||||||
|
raw: AtomicPtr::new(raw),
|
||||||
|
aborted: RwLock::new(false),
|
||||||
|
referred: AtomicBool::new(true),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn null() -> Pin<Arc<Self>> {
|
||||||
|
Self::new(null_mut())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_raw(&self) -> sys::napi_threadsafe_function {
|
||||||
|
let raw = self.raw.load(Ordering::SeqCst);
|
||||||
|
assert!(!raw.is_null());
|
||||||
|
raw
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_raw(&self, raw: sys::napi_threadsafe_function) {
|
||||||
|
self.raw.store(raw, Ordering::SeqCst)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Drop for ThreadsafeFunctionHandle {
|
impl Drop for ThreadsafeFunctionHandle {
|
||||||
fn drop(&mut self) {
|
fn drop(&mut self) {
|
||||||
let aborted_guard = self
|
let aborted_guard = self
|
||||||
|
@ -113,7 +139,10 @@ impl Drop for ThreadsafeFunctionHandle {
|
||||||
.expect("Threadsafe Function aborted lock failed");
|
.expect("Threadsafe Function aborted lock failed");
|
||||||
if !*aborted_guard && self.referred.load(Ordering::Acquire) {
|
if !*aborted_guard && self.referred.load(Ordering::Acquire) {
|
||||||
let release_status = unsafe {
|
let release_status = unsafe {
|
||||||
sys::napi_release_threadsafe_function(self.raw, sys::ThreadsafeFunctionReleaseMode::release)
|
sys::napi_release_threadsafe_function(
|
||||||
|
self.get_raw(),
|
||||||
|
sys::ThreadsafeFunctionReleaseMode::release,
|
||||||
|
)
|
||||||
};
|
};
|
||||||
assert!(
|
assert!(
|
||||||
release_status == sys::Status::napi_ok,
|
release_status == sys::Status::napi_ok,
|
||||||
|
@ -186,7 +215,7 @@ struct ThreadsafeFunctionCallJsBackData<T> {
|
||||||
/// }
|
/// }
|
||||||
/// ```
|
/// ```
|
||||||
pub struct ThreadsafeFunction<T: 'static, ES: ErrorStrategy::T = ErrorStrategy::CalleeHandled> {
|
pub struct ThreadsafeFunction<T: 'static, ES: ErrorStrategy::T = ErrorStrategy::CalleeHandled> {
|
||||||
handle: Arc<ThreadsafeFunctionHandle>,
|
handle: Pin<Arc<ThreadsafeFunctionHandle>>,
|
||||||
_phantom: PhantomData<(T, ES)>,
|
_phantom: PhantomData<(T, ES)>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -299,6 +328,7 @@ impl<T: 'static, ES: ErrorStrategy::T> ThreadsafeFunction<T, ES> {
|
||||||
|
|
||||||
let mut raw_tsfn = ptr::null_mut();
|
let mut raw_tsfn = ptr::null_mut();
|
||||||
let callback_ptr = Box::into_raw(Box::new(callback));
|
let callback_ptr = Box::into_raw(Box::new(callback));
|
||||||
|
let handle = ThreadsafeFunctionHandle::null();
|
||||||
check_status!(unsafe {
|
check_status!(unsafe {
|
||||||
sys::napi_create_threadsafe_function(
|
sys::napi_create_threadsafe_function(
|
||||||
env,
|
env,
|
||||||
|
@ -307,20 +337,17 @@ impl<T: 'static, ES: ErrorStrategy::T> ThreadsafeFunction<T, ES> {
|
||||||
async_resource_name,
|
async_resource_name,
|
||||||
max_queue_size,
|
max_queue_size,
|
||||||
1,
|
1,
|
||||||
ptr::null_mut(),
|
Arc::into_raw(Pin::into_inner(handle.clone())) as *mut c_void, // pass handler to thread_finalize_cb
|
||||||
Some(thread_finalize_cb::<T, V, R>),
|
Some(thread_finalize_cb::<T, V, R>),
|
||||||
callback_ptr.cast(),
|
callback_ptr.cast(),
|
||||||
Some(call_js_cb::<T, V, R, ES>),
|
Some(call_js_cb::<T, V, R, ES>),
|
||||||
&mut raw_tsfn,
|
&mut raw_tsfn,
|
||||||
)
|
)
|
||||||
})?;
|
})?;
|
||||||
|
handle.set_raw(raw_tsfn);
|
||||||
|
|
||||||
Ok(ThreadsafeFunction {
|
Ok(ThreadsafeFunction {
|
||||||
handle: Arc::new(ThreadsafeFunctionHandle {
|
handle,
|
||||||
raw: raw_tsfn,
|
|
||||||
aborted: RwLock::new(false),
|
|
||||||
referred: AtomicBool::new(true),
|
|
||||||
}),
|
|
||||||
_phantom: PhantomData,
|
_phantom: PhantomData,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -336,7 +363,7 @@ impl<T: 'static, ES: ErrorStrategy::T> ThreadsafeFunction<T, ES> {
|
||||||
.read()
|
.read()
|
||||||
.expect("Threadsafe Function aborted lock failed");
|
.expect("Threadsafe Function aborted lock failed");
|
||||||
if !*aborted_guard && !self.handle.referred.load(Ordering::Acquire) {
|
if !*aborted_guard && !self.handle.referred.load(Ordering::Acquire) {
|
||||||
check_status!(unsafe { sys::napi_ref_threadsafe_function(env.0, self.handle.raw) })?;
|
check_status!(unsafe { sys::napi_ref_threadsafe_function(env.0, self.handle.get_raw()) })?;
|
||||||
self.handle.referred.store(true, Ordering::Release);
|
self.handle.referred.store(true, Ordering::Release);
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -351,7 +378,7 @@ impl<T: 'static, ES: ErrorStrategy::T> ThreadsafeFunction<T, ES> {
|
||||||
.read()
|
.read()
|
||||||
.expect("Threadsafe Function aborted lock failed");
|
.expect("Threadsafe Function aborted lock failed");
|
||||||
if !*aborted_guard && self.handle.referred.load(Ordering::Acquire) {
|
if !*aborted_guard && self.handle.referred.load(Ordering::Acquire) {
|
||||||
check_status!(unsafe { sys::napi_unref_threadsafe_function(env.0, self.handle.raw) })?;
|
check_status!(unsafe { sys::napi_unref_threadsafe_function(env.0, self.handle.get_raw()) })?;
|
||||||
self.handle.referred.store(false, Ordering::Release);
|
self.handle.referred.store(false, Ordering::Release);
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -375,7 +402,7 @@ impl<T: 'static, ES: ErrorStrategy::T> ThreadsafeFunction<T, ES> {
|
||||||
if !*aborted_guard {
|
if !*aborted_guard {
|
||||||
check_status!(unsafe {
|
check_status!(unsafe {
|
||||||
sys::napi_release_threadsafe_function(
|
sys::napi_release_threadsafe_function(
|
||||||
self.handle.raw,
|
self.handle.get_raw(),
|
||||||
sys::ThreadsafeFunctionReleaseMode::abort,
|
sys::ThreadsafeFunctionReleaseMode::abort,
|
||||||
)
|
)
|
||||||
})?;
|
})?;
|
||||||
|
@ -386,7 +413,7 @@ impl<T: 'static, ES: ErrorStrategy::T> ThreadsafeFunction<T, ES> {
|
||||||
|
|
||||||
/// Get the raw `ThreadSafeFunction` pointer
|
/// Get the raw `ThreadSafeFunction` pointer
|
||||||
pub fn raw(&self) -> sys::napi_threadsafe_function {
|
pub fn raw(&self) -> sys::napi_threadsafe_function {
|
||||||
self.handle.raw
|
self.handle.get_raw()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -396,7 +423,7 @@ impl<T: 'static> ThreadsafeFunction<T, ErrorStrategy::CalleeHandled> {
|
||||||
pub fn call(&self, value: Result<T>, mode: ThreadsafeFunctionCallMode) -> Status {
|
pub fn call(&self, value: Result<T>, mode: ThreadsafeFunctionCallMode) -> Status {
|
||||||
unsafe {
|
unsafe {
|
||||||
sys::napi_call_threadsafe_function(
|
sys::napi_call_threadsafe_function(
|
||||||
self.handle.raw,
|
self.handle.get_raw(),
|
||||||
Box::into_raw(Box::new(value.map(|data| {
|
Box::into_raw(Box::new(value.map(|data| {
|
||||||
ThreadsafeFunctionCallJsBackData {
|
ThreadsafeFunctionCallJsBackData {
|
||||||
data,
|
data,
|
||||||
|
@ -419,7 +446,7 @@ impl<T: 'static> ThreadsafeFunction<T, ErrorStrategy::CalleeHandled> {
|
||||||
) -> Status {
|
) -> Status {
|
||||||
unsafe {
|
unsafe {
|
||||||
sys::napi_call_threadsafe_function(
|
sys::napi_call_threadsafe_function(
|
||||||
self.handle.raw,
|
self.handle.get_raw(),
|
||||||
Box::into_raw(Box::new(value.map(|data| {
|
Box::into_raw(Box::new(value.map(|data| {
|
||||||
ThreadsafeFunctionCallJsBackData {
|
ThreadsafeFunctionCallJsBackData {
|
||||||
data,
|
data,
|
||||||
|
@ -441,7 +468,7 @@ impl<T: 'static> ThreadsafeFunction<T, ErrorStrategy::CalleeHandled> {
|
||||||
let (sender, receiver) = tokio::sync::oneshot::channel::<D>();
|
let (sender, receiver) = tokio::sync::oneshot::channel::<D>();
|
||||||
check_status!(unsafe {
|
check_status!(unsafe {
|
||||||
sys::napi_call_threadsafe_function(
|
sys::napi_call_threadsafe_function(
|
||||||
self.handle.raw,
|
self.handle.get_raw(),
|
||||||
Box::into_raw(Box::new(value.map(|data| {
|
Box::into_raw(Box::new(value.map(|data| {
|
||||||
ThreadsafeFunctionCallJsBackData {
|
ThreadsafeFunctionCallJsBackData {
|
||||||
data,
|
data,
|
||||||
|
@ -471,7 +498,7 @@ impl<T: 'static> ThreadsafeFunction<T, ErrorStrategy::Fatal> {
|
||||||
pub fn call(&self, value: T, mode: ThreadsafeFunctionCallMode) -> Status {
|
pub fn call(&self, value: T, mode: ThreadsafeFunctionCallMode) -> Status {
|
||||||
unsafe {
|
unsafe {
|
||||||
sys::napi_call_threadsafe_function(
|
sys::napi_call_threadsafe_function(
|
||||||
self.handle.raw,
|
self.handle.get_raw(),
|
||||||
Box::into_raw(Box::new(ThreadsafeFunctionCallJsBackData {
|
Box::into_raw(Box::new(ThreadsafeFunctionCallJsBackData {
|
||||||
data: value,
|
data: value,
|
||||||
call_variant: ThreadsafeFunctionCallVariant::Direct,
|
call_variant: ThreadsafeFunctionCallVariant::Direct,
|
||||||
|
@ -492,7 +519,7 @@ impl<T: 'static> ThreadsafeFunction<T, ErrorStrategy::Fatal> {
|
||||||
) -> Status {
|
) -> Status {
|
||||||
unsafe {
|
unsafe {
|
||||||
sys::napi_call_threadsafe_function(
|
sys::napi_call_threadsafe_function(
|
||||||
self.handle.raw,
|
self.handle.get_raw(),
|
||||||
Box::into_raw(Box::new(ThreadsafeFunctionCallJsBackData {
|
Box::into_raw(Box::new(ThreadsafeFunctionCallJsBackData {
|
||||||
data: value,
|
data: value,
|
||||||
call_variant: ThreadsafeFunctionCallVariant::WithCallback,
|
call_variant: ThreadsafeFunctionCallVariant::WithCallback,
|
||||||
|
@ -512,7 +539,7 @@ impl<T: 'static> ThreadsafeFunction<T, ErrorStrategy::Fatal> {
|
||||||
let (sender, receiver) = tokio::sync::oneshot::channel::<D>();
|
let (sender, receiver) = tokio::sync::oneshot::channel::<D>();
|
||||||
check_status!(unsafe {
|
check_status!(unsafe {
|
||||||
sys::napi_call_threadsafe_function(
|
sys::napi_call_threadsafe_function(
|
||||||
self.handle.raw,
|
self.handle.get_raw(),
|
||||||
Box::into_raw(Box::new(ThreadsafeFunctionCallJsBackData {
|
Box::into_raw(Box::new(ThreadsafeFunctionCallJsBackData {
|
||||||
data: value,
|
data: value,
|
||||||
call_variant: ThreadsafeFunctionCallVariant::WithCallback,
|
call_variant: ThreadsafeFunctionCallVariant::WithCallback,
|
||||||
|
@ -542,6 +569,15 @@ unsafe extern "C" fn thread_finalize_cb<T: 'static, V: ToNapiValue, R>(
|
||||||
) where
|
) where
|
||||||
R: 'static + Send + FnMut(ThreadSafeCallContext<T>) -> Result<Vec<V>>,
|
R: 'static + Send + FnMut(ThreadSafeCallContext<T>) -> Result<Vec<V>>,
|
||||||
{
|
{
|
||||||
|
let handle = unsafe { Arc::from_raw(finalize_data.cast::<ThreadsafeFunctionHandle>()) };
|
||||||
|
let mut aborted_guard = handle
|
||||||
|
.aborted
|
||||||
|
.write()
|
||||||
|
.expect("Threadsafe Function Handle aborted lock failed");
|
||||||
|
if !*aborted_guard {
|
||||||
|
*aborted_guard = true;
|
||||||
|
}
|
||||||
|
|
||||||
// cleanup
|
// cleanup
|
||||||
drop(unsafe { Box::<R>::from_raw(finalize_hint.cast()) });
|
drop(unsafe { Box::<R>::from_raw(finalize_hint.cast()) });
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue