fix(napi-derive): unsound behavior while using reference and async together

This commit is contained in:
Jacob Kiesel 2022-11-21 09:17:19 -07:00 committed by GitHub
parent 91890456da
commit 618d0f8046
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 160 additions and 28 deletions

View file

@ -27,6 +27,7 @@ pub struct NapiFn {
pub enumerable: bool, pub enumerable: bool,
pub configurable: bool, pub configurable: bool,
pub catch_unwind: bool, pub catch_unwind: bool,
pub unsafe_: bool,
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -64,7 +65,7 @@ pub enum FnKind {
Setter, Setter,
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone, PartialEq, Eq)]
pub enum FnSelf { pub enum FnSelf {
Value, Value,
Ref, Ref,

View file

@ -1,9 +1,10 @@
use proc_macro2::{Ident, Span, TokenStream}; use proc_macro2::{Ident, Span, TokenStream};
use quote::ToTokens; use quote::ToTokens;
use syn::spanned::Spanned;
use crate::{ use crate::{
codegen::{get_intermediate_ident, get_register_ident, js_mod_to_token_stream}, codegen::{get_intermediate_ident, get_register_ident, js_mod_to_token_stream},
BindgenResult, CallbackArg, FnKind, FnSelf, NapiFn, NapiFnArgKind, TryToTokens, BindgenResult, CallbackArg, Diagnostic, FnKind, FnSelf, NapiFn, NapiFnArgKind, TryToTokens,
}; };
impl TryToTokens for NapiFn { impl TryToTokens for NapiFn {
@ -12,13 +13,78 @@ impl TryToTokens for NapiFn {
let intermediate_ident = get_intermediate_ident(&name_str); let intermediate_ident = get_intermediate_ident(&name_str);
let args_len = self.args.len(); let args_len = self.args.len();
let (arg_conversions, arg_names) = self.gen_arg_conversions()?; let ArgConversions {
arg_conversions,
args: arg_names,
refs,
mut_ref_spans,
unsafe_,
} = self.gen_arg_conversions()?;
// The JS engine can't properly track mutability in an async context, so refuse to compile
// code that tries to use async and mutability together without `unsafe` mark.
if self.is_async && !mut_ref_spans.is_empty() && !unsafe_ {
return Diagnostic::from_vec(
mut_ref_spans
.into_iter()
.map(|s| Diagnostic::span_error(s, "mutable reference is unsafe with async"))
.collect(),
);
}
if Some(FnSelf::MutRef) == self.fn_self && self.is_async {
return Err(Diagnostic::span_error(
self.name.span(),
"&mut self is incompatible with async napi methods",
));
}
let arg_ref_count = refs.len();
let receiver = self.gen_fn_receiver(); let receiver = self.gen_fn_receiver();
let receiver_ret_name = Ident::new("_ret", Span::call_site()); let receiver_ret_name = Ident::new("_ret", Span::call_site());
let ret = self.gen_fn_return(&receiver_ret_name); let ret = self.gen_fn_return(&receiver_ret_name);
let register = self.gen_fn_register(); let register = self.gen_fn_register();
let attrs = &self.attrs; let attrs = &self.attrs;
let build_ref_container = if self.is_async {
quote! {
struct NapiRefContainer([napi::sys::napi_ref; #arg_ref_count]);
impl NapiRefContainer {
fn drop(self, env: napi::sys::napi_env) {
for r in self.0.into_iter() {
assert_eq!(
unsafe { napi::sys::napi_delete_reference(env, r) },
napi::sys::Status::napi_ok,
"failed to delete napi ref"
);
}
}
}
unsafe impl Send for NapiRefContainer {}
unsafe impl Sync for NapiRefContainer {}
let _make_ref = |a: ::std::ptr::NonNull<napi::bindgen_prelude::sys::napi_value__>| {
let mut node_ref = ::std::mem::MaybeUninit::uninit();
assert_eq!(unsafe {
napi::bindgen_prelude::sys::napi_create_reference(env, a.as_ptr(), 1, node_ref.as_mut_ptr())
},
napi::bindgen_prelude::sys::Status::napi_ok,
"failed to create napi ref"
);
unsafe { node_ref.assume_init() }
};
let mut _args_array = [::std::ptr::null_mut::<napi::bindgen_prelude::sys::napi_ref__>(); #arg_ref_count];
let mut _arg_write_index = 0;
#(#refs)*
#[cfg(debug_assert)]
{
for a in &_args_array {
assert!(!a.is_null(), "failed to initialize napi ref");
}
}
let _args_ref = NapiRefContainer(_args_array);
}
} else {
quote! {}
};
let native_call = if !self.is_async { let native_call = if !self.is_async {
quote! { quote! {
napi::bindgen_prelude::within_runtime_if_available(move || { napi::bindgen_prelude::within_runtime_if_available(move || {
@ -35,16 +101,26 @@ impl TryToTokens for NapiFn {
quote! { Ok(#receiver(#(#arg_names),*).await) } quote! { Ok(#receiver(#(#arg_names),*).await) }
}; };
quote! { quote! {
napi::bindgen_prelude::execute_tokio_future(env, async move { #call }, |env, #receiver_ret_name| { napi::bindgen_prelude::execute_tokio_future(env, async move { #call }, move |env, #receiver_ret_name| {
_args_ref.drop(env);
#ret #ret
}) })
} }
}; };
let function_call_inner = quote! {
napi::bindgen_prelude::CallbackInfo::<#args_len>::new(env, cb, None).and_then(|mut cb| {
#build_ref_container
#(#arg_conversions)*
#native_call
})
};
let function_call = if args_len == 0 let function_call = if args_len == 0
&& self.fn_self.is_none() && self.fn_self.is_none()
&& self.kind != FnKind::Constructor && self.kind != FnKind::Constructor
&& self.kind != FnKind::Factory && self.kind != FnKind::Factory
&& !self.is_async
{ {
quote! { #native_call } quote! { #native_call }
} else if self.kind == FnKind::Constructor { } else if self.kind == FnKind::Constructor {
@ -55,18 +131,10 @@ impl TryToTokens for NapiFn {
if inner.load(std::sync::atomic::Ordering::Relaxed) { if inner.load(std::sync::atomic::Ordering::Relaxed) {
return std::ptr::null_mut(); return std::ptr::null_mut();
} }
napi::bindgen_prelude::CallbackInfo::<#args_len>::new(env, cb, None).and_then(|mut cb| { #function_call_inner
#(#arg_conversions)*
#native_call
})
} }
} else { } else {
quote! { function_call_inner
napi::bindgen_prelude::CallbackInfo::<#args_len>::new(env, cb, None).and_then(|mut cb| {
#(#arg_conversions)*
#native_call
})
}
}; };
let function_call = if self.catch_unwind { let function_call = if self.catch_unwind {
@ -109,20 +177,30 @@ impl TryToTokens for NapiFn {
} }
impl NapiFn { impl NapiFn {
fn gen_arg_conversions(&self) -> BindgenResult<(Vec<TokenStream>, Vec<TokenStream>)> { fn gen_arg_conversions(&self) -> BindgenResult<ArgConversions> {
let mut arg_conversions = vec![]; let mut arg_conversions = vec![];
let mut args = vec![]; let mut args = vec![];
let mut refs = vec![];
let mut mut_ref_spans = vec![];
let make_ref = |input| {
quote! {
_args_array[_arg_write_index] = _make_ref(::std::ptr::NonNull::new(#input).expect("ref ptr was null"));
_arg_write_index += 1;
}
};
// fetch this // fetch this
if let Some(parent) = &self.parent { if let Some(parent) = &self.parent {
match self.fn_self { match self.fn_self {
Some(FnSelf::Ref) => { Some(FnSelf::Ref) => {
refs.push(make_ref(quote! { cb.this }));
arg_conversions.push(quote! { arg_conversions.push(quote! {
let this_ptr = unsafe { cb.unwrap_raw::<#parent>()? }; let this_ptr = unsafe { cb.unwrap_raw::<#parent>()? };
let this: &#parent = Box::leak(Box::from_raw(this_ptr)); let this: &#parent = Box::leak(Box::from_raw(this_ptr));
}); });
} }
Some(FnSelf::MutRef) => { Some(FnSelf::MutRef) => {
refs.push(make_ref(quote! { cb.this }));
arg_conversions.push(quote! { arg_conversions.push(quote! {
let this_ptr = unsafe { cb.unwrap_raw::<#parent>()? }; let this_ptr = unsafe { cb.unwrap_raw::<#parent>()? };
let this: &mut #parent = Box::leak(Box::from_raw(this_ptr)); let this: &mut #parent = Box::leak(Box::from_raw(this_ptr));
@ -215,7 +293,9 @@ impl NapiFn {
}) = elem.as_ref() }) = elem.as_ref()
{ {
if let Some(syn::PathSegment { ident, .. }) = segments.first() { if let Some(syn::PathSegment { ident, .. }) = segments.first() {
refs.push(make_ref(quote! { cb.this }));
let token = if mutability.is_some() { let token = if mutability.is_some() {
mut_ref_spans.push(generic_type.span());
quote! { <#ident as napi::bindgen_prelude::FromNapiMutRef>::from_napi_mut_ref(env, cb.this)? } quote! { <#ident as napi::bindgen_prelude::FromNapiMutRef>::from_napi_mut_ref(env, cb.this)? }
} else { } else {
quote! { <#ident as napi::bindgen_prelude::FromNapiRef>::from_napi_ref(env, cb.this)? } quote! { <#ident as napi::bindgen_prelude::FromNapiRef>::from_napi_ref(env, cb.this)? }
@ -228,15 +308,21 @@ impl NapiFn {
} }
} }
} }
args.push( refs.push(make_ref(quote! { cb.this }));
quote! { <napi::bindgen_prelude::This as napi::NapiValue>::from_raw_unchecked(env, cb.this) }, args.push(quote! { <napi::bindgen_prelude::This as napi::NapiValue>::from_raw_unchecked(env, cb.this) });
);
skipped_arg_count += 1; skipped_arg_count += 1;
continue; continue;
} }
} }
} }
arg_conversions.push(self.gen_ty_arg_conversion(&ident, i, path)); let (arg_conversion, arg_type) = self.gen_ty_arg_conversion(&ident, i, path);
if NapiArgType::MutRef == arg_type {
mut_ref_spans.push(path.ty.span());
}
if arg_type.is_ref() {
refs.push(make_ref(quote! { cb.get_arg(#i) }));
}
arg_conversions.push(arg_conversion);
args.push(quote! { #ident }); args.push(quote! { #ident });
} }
} }
@ -247,17 +333,24 @@ impl NapiFn {
} }
} }
Ok((arg_conversions, args)) Ok(ArgConversions {
arg_conversions,
args,
refs,
mut_ref_spans,
unsafe_: self.unsafe_,
})
} }
/// Returns a type conversion, and a boolean indicating whether this value needs to have a reference created to extend the lifetime
/// for async functions.
fn gen_ty_arg_conversion( fn gen_ty_arg_conversion(
&self, &self,
arg_name: &Ident, arg_name: &Ident,
index: usize, index: usize,
path: &syn::PatType, path: &syn::PatType,
) -> TokenStream { ) -> (TokenStream, NapiArgType) {
let ty = &*path.ty; let ty = &*path.ty;
let type_check = if self.return_if_invalid { let type_check = if self.return_if_invalid {
quote! { quote! {
if let Ok(maybe_promise) = <#ty as napi::bindgen_prelude::ValidateNapiValue>::validate(env, cb.get_arg(#index)) { if let Ok(maybe_promise) = <#ty as napi::bindgen_prelude::ValidateNapiValue>::validate(env, cb.get_arg(#index)) {
@ -285,28 +378,31 @@ impl NapiFn {
elem, elem,
.. ..
}) => { }) => {
quote! { let q = quote! {
let #arg_name = { let #arg_name = {
#type_check #type_check
<#elem as napi::bindgen_prelude::FromNapiMutRef>::from_napi_mut_ref(env, cb.get_arg(#index))? <#elem as napi::bindgen_prelude::FromNapiMutRef>::from_napi_mut_ref(env, cb.get_arg(#index))?
}; };
} };
(q, NapiArgType::MutRef)
} }
syn::Type::Reference(syn::TypeReference { elem, .. }) => { syn::Type::Reference(syn::TypeReference { elem, .. }) => {
quote! { let q = quote! {
let #arg_name = { let #arg_name = {
#type_check #type_check
<#elem as napi::bindgen_prelude::FromNapiRef>::from_napi_ref(env, cb.get_arg(#index))? <#elem as napi::bindgen_prelude::FromNapiRef>::from_napi_ref(env, cb.get_arg(#index))?
}; };
} };
(q, NapiArgType::Ref)
} }
_ => { _ => {
quote! { let q = quote! {
let #arg_name = { let #arg_name = {
#type_check #type_check
<#ty as napi::bindgen_prelude::FromNapiValue>::from_napi_value(env, cb.get_arg(#index))? <#ty as napi::bindgen_prelude::FromNapiValue>::from_napi_value(env, cb.get_arg(#index))?
}; };
} };
(q, NapiArgType::Value)
} }
} }
} }
@ -482,3 +578,24 @@ impl NapiFn {
} }
} }
} }
struct ArgConversions {
pub args: Vec<TokenStream>,
pub arg_conversions: Vec<TokenStream>,
pub refs: Vec<TokenStream>,
pub mut_ref_spans: Vec<Span>,
pub unsafe_: bool,
}
#[derive(Debug, PartialEq, Eq)]
enum NapiArgType {
Ref,
MutRef,
Value,
}
impl NapiArgType {
fn is_ref(&self) -> bool {
matches!(self, NapiArgType::Ref | NapiArgType::MutRef)
}
}

View file

@ -696,6 +696,7 @@ fn napi_fn_from_decl(
enumerable: opts.enumerable(), enumerable: opts.enumerable(),
configurable: opts.configurable(), configurable: opts.configurable(),
catch_unwind: opts.catch_unwind().is_some(), catch_unwind: opts.catch_unwind().is_some(),
unsafe_: sig.unsafety.is_some(),
} }
}) })
} }

View file

@ -247,6 +247,7 @@ Generated by [AVA](https://avajs.dev).
name: string␊ name: string␊
constructor(name: string)␊ constructor(name: string)␊
getCount(): number␊ getCount(): number␊
getNameAsync(): Promise<string>
}␊ }␊
export type Blake2bHasher = Blake2BHasher␊ export type Blake2bHasher = Blake2BHasher␊
/** Smoking test for type generation */␊ /** Smoking test for type generation */␊

View file

@ -206,6 +206,11 @@ test('class', (t) => {
}) })
}) })
test('async self in class', async (t) => {
const b = new Bird('foo')
t.is(await b.getNameAsync(), 'foo')
})
test('class factory', (t) => { test('class factory', (t) => {
const duck = ClassWithFactory.withName('Default') const duck = ClassWithFactory.withName('Default')
t.is(duck.name, 'Default') t.is(duck.name, 'Default')

View file

@ -237,6 +237,7 @@ export class Bird {
name: string name: string
constructor(name: string) constructor(name: string)
getCount(): number getCount(): number
getNameAsync(): Promise<string>
} }
export type Blake2bHasher = Blake2BHasher export type Blake2bHasher = Blake2BHasher
/** Smoking test for type generation */ /** Smoking test for type generation */

View file

@ -123,6 +123,12 @@ impl Bird {
pub fn get_count(&self) -> u32 { pub fn get_count(&self) -> u32 {
1234 1234
} }
#[napi]
pub async fn get_name_async(&self) -> &str {
tokio::time::sleep(std::time::Duration::new(1, 0)).await;
self.name.as_str()
}
} }
/// Smoking test for type generation /// Smoking test for type generation