diff --git a/napi/src/env.rs b/napi/src/env.rs index d73cab24..4bff1702 100644 --- a/napi/src/env.rs +++ b/napi/src/env.rs @@ -637,6 +637,53 @@ impl Env { } } + #[inline] + /// This API create a new reference with the specified reference count to the Object passed in. + pub fn create_reference(&self, value: T) -> Result> + where + T: NapiValue, + { + let mut raw_ref = ptr::null_mut(); + let initial_ref_count = 1; + check_status!(unsafe { + sys::napi_create_reference(self.0, value.raw(), initial_ref_count, &mut raw_ref) + })?; + Ok(Ref { + raw_ref, + count: 1, + inner: (), + }) + } + + #[inline] + /// Get reference value from `Ref` with type check + /// Return error if the type of `reference` provided is mismatched with `T` + pub fn get_reference_value(&self, reference: &Ref<()>) -> Result + where + T: NapiValue, + { + let mut js_value = ptr::null_mut(); + check_status!(unsafe { + sys::napi_get_reference_value(self.0, reference.raw_ref, &mut js_value) + })?; + unsafe { T::from_raw(self.0, js_value) } + } + + #[inline] + /// Get reference value from `Ref` without type check + /// Using this API if you are sure the type of `T` is matched with provided `Ref<()>`. + /// If type mismatched, calling `T::method` would return `Err`. + pub fn get_reference_value_unchecked(&self, reference: &Ref<()>) -> Result + where + T: NapiValue, + { + let mut js_value = ptr::null_mut(); + check_status!(unsafe { + sys::napi_get_reference_value(self.0, reference.raw_ref, &mut js_value) + })?; + Ok(unsafe { T::from_raw_unchecked(self.0, js_value) }) + } + #[inline] pub fn create_external(&self, native_object: T) -> Result { let mut object_value = ptr::null_mut(); diff --git a/napi/src/js_values/mod.rs b/napi/src/js_values/mod.rs index a2a9c76e..c2eb7d63 100644 --- a/napi/src/js_values/mod.rs +++ b/napi/src/js_values/mod.rs @@ -55,7 +55,7 @@ pub use string::*; pub(crate) use tagged_object::TaggedObject; pub use undefined::JsUndefined; pub(crate) use value::Value; -pub use value_ref::Ref; +pub use value_ref::*; pub use value_type::ValueType; // Value types diff --git a/napi/src/js_values/value_ref.rs b/napi/src/js_values/value_ref.rs index 90bb4c24..908eaa43 100644 --- a/napi/src/js_values/value_ref.rs +++ b/napi/src/js_values/value_ref.rs @@ -6,8 +6,8 @@ use crate::{sys, Env, Result}; pub struct Ref { pub(crate) raw_ref: sys::napi_ref, - count: u32, - inner: T, + pub(crate) count: u32, + pub(crate) inner: T, } unsafe impl Send for Ref {} diff --git a/test_module/__test__/napi4/threadsafe_function.spec.ts b/test_module/__test__/napi4/threadsafe_function.spec.ts index 6f9c2858..9bb067a4 100644 --- a/test_module/__test__/napi4/threadsafe_function.spec.ts +++ b/test_module/__test__/napi4/threadsafe_function.spec.ts @@ -12,7 +12,7 @@ test('should get js function called from a thread', async (t) => { return } - await new Promise((resolve, reject) => { + await new Promise((resolve, reject) => { bindings.testThreadsafeFunction((...args: any[]) => { called += 1 try { @@ -55,3 +55,20 @@ test('should return Closing while calling aborted tsfn', (t) => { } t.notThrows(() => bindings.testCallAbortedThreadsafeFunction(() => {})) }) + +test('should work with napi ref', (t) => { + if (napiVersion < 4) { + t.is(bindings.testTsfnWithRef, undefined) + } else { + const obj = { + foo: Symbol(), + } + return new Promise((resolve) => { + bindings.testTsfnWithRef((err: Error | null, returnObj: any) => { + t.is(err, null) + t.is(obj, returnObj) + resolve() + }, obj) + }) + } +}) diff --git a/test_module/src/napi4/mod.rs b/test_module/src/napi4/mod.rs index ea79413e..ab4b989d 100644 --- a/test_module/src/napi4/mod.rs +++ b/test_module/src/napi4/mod.rs @@ -20,5 +20,6 @@ pub fn register_js(exports: &mut JsObject) -> Result<()> { "testCallAbortedThreadsafeFunction", test_call_aborted_threadsafe_function, )?; + exports.create_named_method("testTsfnWithRef", test_tsfn_with_ref)?; Ok(()) } diff --git a/test_module/src/napi4/tsfn.rs b/test_module/src/napi4/tsfn.rs index 7097157d..1f5cfd2b 100644 --- a/test_module/src/napi4/tsfn.rs +++ b/test_module/src/napi4/tsfn.rs @@ -3,7 +3,8 @@ use std::thread; use napi::{ threadsafe_function::{ThreadSafeCallContext, ThreadsafeFunctionCallMode}, - CallContext, Error, JsBoolean, JsFunction, JsNumber, JsString, JsUndefined, Result, Status, + CallContext, Error, JsBoolean, JsFunction, JsNumber, JsObject, JsString, JsUndefined, Ref, + Result, Status, }; use tokio; @@ -148,3 +149,25 @@ pub fn test_tokio_readfile(ctx: CallContext) -> Result { ctx.env.get_undefined() } + +#[js_function(2)] +pub fn test_tsfn_with_ref(ctx: CallContext) -> Result { + let callback = ctx.get::(0)?; + let options = ctx.get::(1)?; + let options_ref = ctx.env.create_reference(options)?; + let tsfn = + ctx + .env + .create_threadsafe_function(&callback, 0, |ctx: ThreadSafeCallContext>| { + ctx + .env + .get_reference_value_unchecked::(&ctx.value) + .and_then(|obj| ctx.value.unref(ctx.env).map(|_| vec![obj])) + })?; + + thread::spawn(move || { + tsfn.call(Ok(options_ref), ThreadsafeFunctionCallMode::Blocking); + }); + + ctx.env.get_undefined() +}