diff --git a/napi/src/env.rs b/napi/src/env.rs index e5894997..ecc3713a 100644 --- a/napi/src/env.rs +++ b/napi/src/env.rs @@ -673,6 +673,32 @@ impl Env { } } + #[inline] + pub fn unwrap_from_ref(&self, js_ref: &Ref<()>) -> Result<&'static mut T> { + unsafe { + let mut unknown_tagged_object: *mut c_void = ptr::null_mut(); + check_status!(sys::napi_unwrap( + self.0, + js_ref.raw_value, + &mut unknown_tagged_object, + ))?; + + let type_id = unknown_tagged_object as *const TypeId; + if *type_id == TypeId::of::() { + let tagged_object = unknown_tagged_object as *mut TaggedObject; + (*tagged_object).object.as_mut().ok_or(Error { + status: Status::InvalidArg, + reason: "Invalid argument, nothing attach to js_object".to_owned(), + }) + } else { + Err(Error { + status: Status::InvalidArg, + reason: "Invalid argument, T on unrwap is not the type of wrapped object".to_owned(), + }) + } + } + } + #[inline] pub fn drop_wrapped(&self, js_object: JsObject) -> Result<()> { unsafe { @@ -706,13 +732,15 @@ impl Env { { let mut raw_ref = ptr::null_mut(); let initial_ref_count = 1; + let raw_value = unsafe { value.raw() }; check_status!(unsafe { - sys::napi_create_reference(self.0, value.raw(), initial_ref_count, &mut raw_ref) + sys::napi_create_reference(self.0, raw_value, initial_ref_count, &mut raw_ref) })?; Ok(Ref { raw_ref, count: 1, inner: (), + raw_value, }) } diff --git a/napi/src/js_values/value_ref.rs b/napi/src/js_values/value_ref.rs index 908eaa43..1e22c090 100644 --- a/napi/src/js_values/value_ref.rs +++ b/napi/src/js_values/value_ref.rs @@ -8,6 +8,7 @@ pub struct Ref { pub(crate) raw_ref: sys::napi_ref, pub(crate) count: u32, pub(crate) inner: T, + pub(crate) raw_value: sys::napi_value, } unsafe impl Send for Ref {} @@ -30,6 +31,7 @@ impl Ref { raw_ref, count: ref_count, inner, + raw_value: js_value.value, }) } diff --git a/napi/src/threadsafe_function.rs b/napi/src/threadsafe_function.rs index f60dd8f4..61ce0c2e 100644 --- a/napi/src/threadsafe_function.rs +++ b/napi/src/threadsafe_function.rs @@ -68,7 +68,7 @@ impl Into for ThreadsafeFunctionCallMode { /// .collect::>>() /// })?; /// -/// let tsfn_cloned = tsfn.try_clone()?; +/// let tsfn_cloned = tsfn.clone(); /// /// thread::spawn(move || { /// let output: Vec = vec![0, 1, 2, 3]; @@ -91,6 +91,24 @@ pub struct ThreadsafeFunction { _phantom: PhantomData, } +impl Clone for ThreadsafeFunction { + fn clone(&self) -> Self { + if !self.aborted.load(Ordering::Acquire) { + let acquire_status = unsafe { sys::napi_acquire_threadsafe_function(self.raw_tsfn) }; + debug_assert!( + acquire_status == sys::Status::napi_ok, + "Acquire threadsafe function failed in clone" + ); + } + + Self { + raw_tsfn: self.raw_tsfn, + aborted: Arc::clone(&self.aborted), + _phantom: PhantomData, + } + } +} + unsafe impl Send for ThreadsafeFunction {} unsafe impl Sync for ThreadsafeFunction {} @@ -203,21 +221,6 @@ impl ThreadsafeFunction { Ok(()) } - pub fn try_clone(&self) -> Result { - if self.aborted.load(Ordering::Acquire) { - return Err(Error::new( - Status::Closing, - format!("Can not clone, Thread safe function already aborted"), - )); - } - check_status!(unsafe { sys::napi_acquire_threadsafe_function(self.raw_tsfn) })?; - Ok(Self { - raw_tsfn: self.raw_tsfn, - aborted: Arc::clone(&self.aborted), - _phantom: PhantomData, - }) - } - /// Get the raw `ThreadSafeFunction` pointer pub fn raw(&self) -> sys::napi_threadsafe_function { self.raw_tsfn diff --git a/test_module/src/napi4/tsfn.rs b/test_module/src/napi4/tsfn.rs index 1f5cfd2b..9358fa89 100644 --- a/test_module/src/napi4/tsfn.rs +++ b/test_module/src/napi4/tsfn.rs @@ -23,7 +23,7 @@ pub fn test_threadsafe_function(ctx: CallContext) -> Result { .collect::>>() })?; - let tsfn_cloned = tsfn.try_clone()?; + let tsfn_cloned = tsfn.clone(); thread::spawn(move || { let output: Vec = vec![0, 1, 2, 3]; @@ -55,7 +55,7 @@ pub fn test_abort_threadsafe_function(ctx: CallContext) -> Result { .collect::>>() })?; - let tsfn_cloned = tsfn.try_clone()?; + let tsfn_cloned = tsfn.clone(); tsfn_cloned.abort()?; ctx.env.get_boolean(tsfn.aborted()) @@ -92,7 +92,7 @@ pub fn test_call_aborted_threadsafe_function(ctx: CallContext) -> Result