fix(napi): use weak arc for passing thread_finalize_data (#1525)

* fix(napi): use weak arc for passing thread_finalize_data

* fix: try to fix test of tsfn_return_promise_timeout
This commit is contained in:
Bo 2023-03-20 11:56:54 +08:00 committed by GitHub
parent 347e81b3cc
commit a6e1ff471c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 19 deletions

View file

@ -4,10 +4,9 @@ use std::convert::Into;
use std::ffi::CString; use std::ffi::CString;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::os::raw::c_void; use std::os::raw::c_void;
use std::pin::Pin;
use std::ptr::{self, null_mut}; use std::ptr::{self, null_mut};
use std::sync::atomic::{AtomicBool, AtomicPtr, Ordering}; use std::sync::atomic::{AtomicBool, AtomicPtr, Ordering};
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock, Weak};
use crate::bindgen_runtime::{FromNapiValue, ToNapiValue, TypeName, ValidateNapiValue}; use crate::bindgen_runtime::{FromNapiValue, ToNapiValue, TypeName, ValidateNapiValue};
use crate::{check_status, sys, Env, JsError, JsUnknown, Result, Status}; use crate::{check_status, sys, Env, JsError, JsUnknown, Result, Status};
@ -103,20 +102,17 @@ struct ThreadsafeFunctionHandle {
referred: AtomicBool, referred: AtomicBool,
} }
unsafe impl Send for ThreadsafeFunctionHandle {}
unsafe impl Sync for ThreadsafeFunctionHandle {}
impl ThreadsafeFunctionHandle { impl ThreadsafeFunctionHandle {
/// create a pinned Arc to hold the `ThreadsafeFunctionHandle` /// create a Arc to hold the `ThreadsafeFunctionHandle`
fn new(raw: sys::napi_threadsafe_function) -> Pin<Arc<Self>> { fn new(raw: sys::napi_threadsafe_function) -> Arc<Self> {
Arc::pin(Self { Arc::new(Self {
raw: AtomicPtr::new(raw), raw: AtomicPtr::new(raw),
aborted: RwLock::new(false), aborted: RwLock::new(false),
referred: AtomicBool::new(true), referred: AtomicBool::new(true),
}) })
} }
fn null() -> Pin<Arc<Self>> { fn null() -> Arc<Self> {
Self::new(null_mut()) Self::new(null_mut())
} }
@ -215,7 +211,7 @@ struct ThreadsafeFunctionCallJsBackData<T> {
/// } /// }
/// ``` /// ```
pub struct ThreadsafeFunction<T: 'static, ES: ErrorStrategy::T = ErrorStrategy::CalleeHandled> { pub struct ThreadsafeFunction<T: 'static, ES: ErrorStrategy::T = ErrorStrategy::CalleeHandled> {
handle: Pin<Arc<ThreadsafeFunctionHandle>>, handle: Arc<ThreadsafeFunctionHandle>,
_phantom: PhantomData<(T, ES)>, _phantom: PhantomData<(T, ES)>,
} }
@ -337,7 +333,7 @@ impl<T: 'static, ES: ErrorStrategy::T> ThreadsafeFunction<T, ES> {
async_resource_name, async_resource_name,
max_queue_size, max_queue_size,
1, 1,
Arc::into_raw(Pin::into_inner(handle.clone())) as *mut c_void, // pass handler to thread_finalize_cb Arc::downgrade(&handle).into_raw() as *mut c_void, // pass handler to thread_finalize_cb
Some(thread_finalize_cb::<T, V, R>), Some(thread_finalize_cb::<T, V, R>),
callback_ptr.cast(), callback_ptr.cast(),
Some(call_js_cb::<T, V, R, ES>), Some(call_js_cb::<T, V, R, ES>),
@ -569,13 +565,18 @@ unsafe extern "C" fn thread_finalize_cb<T: 'static, V: ToNapiValue, R>(
) where ) where
R: 'static + Send + FnMut(ThreadSafeCallContext<T>) -> Result<Vec<V>>, R: 'static + Send + FnMut(ThreadSafeCallContext<T>) -> Result<Vec<V>>,
{ {
let handle = unsafe { Arc::from_raw(finalize_data.cast::<ThreadsafeFunctionHandle>()) }; let handle_option =
let mut aborted_guard = handle unsafe { Weak::from_raw(finalize_data.cast::<ThreadsafeFunctionHandle>()).upgrade() };
.aborted
.write() if let Some(handle) = handle_option {
.expect("Threadsafe Function Handle aborted lock failed"); let mut aborted_guard = handle
if !*aborted_guard { .aborted
*aborted_guard = true; .write()
.expect("Threadsafe Function Handle aborted lock failed");
if !*aborted_guard {
*aborted_guard = true;
}
} }
// cleanup // cleanup

View file

@ -137,7 +137,7 @@ pub async fn tsfn_return_promise(func: ThreadsafeFunction<u32>) -> Result<u32> {
pub async fn tsfn_return_promise_timeout(func: ThreadsafeFunction<u32>) -> Result<u32> { pub async fn tsfn_return_promise_timeout(func: ThreadsafeFunction<u32>) -> Result<u32> {
use tokio::time::{self, Duration}; use tokio::time::{self, Duration};
let promise = func.call_async::<Promise<u32>>(Ok(1)).await?; let promise = func.call_async::<Promise<u32>>(Ok(1)).await?;
let sleep = time::sleep(Duration::from_millis(200)); let sleep = time::sleep(Duration::from_millis(100));
tokio::select! { tokio::select! {
_ = sleep => { _ = sleep => {
return Err(Error::new(Status::GenericFailure, "Timeout".to_owned())); return Err(Error::new(Status::GenericFailure, "Timeout".to_owned()));