support async functions

This commit is contained in:
forehalo 2021-10-25 00:00:31 +08:00 committed by LongYinan
parent cf0b5785cd
commit 0ee80662be
15 changed files with 138 additions and 66 deletions

View file

@ -7,7 +7,8 @@ pub struct NapiFn {
pub js_name: String,
pub attrs: Vec<Attribute>,
pub args: Vec<NapiFnArgKind>,
pub ret: Option<(syn::Type, /* is_result */ bool)>,
pub ret: Option<syn::Type>,
pub is_ret_result: bool,
pub is_async: bool,
pub fn_self: Option<FnSelf>,
pub kind: FnKind,

View file

@ -17,27 +17,39 @@ impl TryToTokens for NapiFn {
let receiver_ret_name = Ident::new("_ret", Span::call_site());
let ret = self.gen_fn_return(&receiver_ret_name);
let register = self.gen_fn_register();
let attrs = &self.attrs;
let function_call_tokens = if args_len == 0 && self.fn_self.is_none() {
let native_call = if !self.is_async {
quote! {
{
let #receiver_ret_name = #receiver();
#ret
}
}
} else {
quote! {
CallbackInfo::<#args_len>::new(env, cb, None).and_then(|mut cb| {
#(#arg_conversions)*
let #receiver_ret_name = {
#receiver(#(#arg_names),*)
};
#ret
}
} else {
let call = if self.is_ret_result {
quote! { #receiver(#(#arg_names),*).await }
} else {
quote! { Ok(#receiver(#(#arg_names),*).await) }
};
quote! {
execute_tokio_future(env, async { #call }, |env, #receiver_ret_name| {
#ret
})
}
};
let function_call = if args_len == 0 && self.fn_self.is_none() {
quote! { #native_call }
} else {
quote! {
CallbackInfo::<#args_len>::new(env, cb, None).and_then(|mut cb| {
#(#arg_conversions)*
#native_call
})
}
};
(quote! {
#(#attrs)*
#[doc(hidden)]
@ -48,7 +60,7 @@ impl TryToTokens for NapiFn {
cb: sys::napi_callback_info
) -> sys::napi_value {
unsafe {
#function_call_tokens.unwrap_or_else(|e| {
#function_call.unwrap_or_else(|e| {
JsError::from(e).throw_into(env);
std::ptr::null_mut::<sys::napi_value__>()
})
@ -120,12 +132,12 @@ impl NapiFn {
..
}) => {
quote! {
let #arg_name = unsafe { <#elem as FromNapiMutRef>::from_napi_mut_ref(env, cb.get_arg(#index))? };
let #arg_name = <#elem as FromNapiMutRef>::from_napi_mut_ref(env, cb.get_arg(#index))?;
}
}
syn::Type::Reference(syn::TypeReference { elem, .. }) => {
quote! {
let #arg_name = unsafe { <#elem as FromNapiRef>::from_napi_ref(env, cb.get_arg(#index))? };
let #arg_name = <#elem as FromNapiRef>::from_napi_ref(env, cb.get_arg(#index))?;
}
}
_ => {
@ -138,7 +150,7 @@ impl NapiFn {
};
quote! {
let #arg_name = unsafe {
let #arg_name = {
#type_check
<#ty as FromNapiValue>::from_napi_value(env, cb.get_arg(#index))?
};
@ -212,20 +224,25 @@ impl NapiFn {
fn gen_fn_return(&self, ret: &Ident) -> TokenStream {
let js_name = &self.js_name;
let ret_ty = &self.ret;
if let Some((ref ty, is_result)) = ret_ty {
if let Some(ty) = &self.ret {
if self.kind == FnKind::Constructor {
quote! { cb.construct(#js_name, #ret) }
} else if *is_result {
} else if self.is_ret_result {
if self.is_async {
quote! {
<#ty as ToNapiValue>::to_napi_value(env, #ret)
}
} else {
quote! {
if #ret.is_ok() {
<#ty as ToNapiValue>::to_napi_value(env, #ret)
<Result<#ty> as ToNapiValue>::to_napi_value(env, #ret)
} else {
JsError::from(#ret.unwrap_err()).throw_into(env);
Ok(std::ptr::null_mut())
}
}
}
} else {
quote! {
<#ty as ToNapiValue>::to_napi_value(env, #ret)

View file

@ -54,6 +54,8 @@ static KNOWN_TYPES: Lazy<HashMap<&'static str, &'static str>> = Lazy::new(|| {
("Value", "any"),
("Map", "Record<string, any>"),
("HashMap", "Record<{}, {}>"),
("Buffer", "Buffer"),
// TODO: Vec<u8> should be Buffer, now is Array<number>
("Vec", "Array<{}>"),
("Option", "{} | null"),
("Result", "Error | {}"),

View file

@ -87,8 +87,8 @@ impl NapiFn {
match self.kind {
FnKind::Constructor | FnKind::Setter => "".to_owned(),
_ => {
let ret = if let Some((ref ret, is_result)) = self.ret {
let ts_type = ty_to_ts_type(ret, is_result);
let ret = if let Some(ret) = &self.ret {
let ts_type = ty_to_ts_type(ret, true);
if ts_type == "undefined" {
"void".to_owned()
} else {

View file

@ -194,7 +194,7 @@ fn extract_path_ident(path: &syn::Path) -> BindgenResult<Ident> {
}
}
fn extract_fn_types(
fn extract_callback_trait_types(
arguments: &syn::PathArguments,
) -> BindgenResult<(Vec<syn::Type>, Option<syn::Type>)> {
match arguments {
@ -209,39 +209,14 @@ fn extract_fn_types(
let ret = match &arguments.output {
syn::ReturnType::Type(_, ret_ty) => {
let ret_ty = &**ret_ty;
match ret_ty {
syn::Type::Path(syn::TypePath {
qself: None,
ref path,
}) if path.segments.len() == 1 => {
let segment = path.segments.first().unwrap();
if segment.ident != "Result" {
bail_span!(ret_ty, "The return type of callback can only be `Result`");
} else {
match &segment.arguments {
syn::PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments {
args,
..
}) => {
// fast test
if args.to_token_stream().to_string() == "()" {
if let Some(ty_of_result) = extract_result_ty(ret_ty)? {
if ty_of_result.to_token_stream().to_string() == "()" {
None
} else {
let ok_arg = args.first().unwrap();
match ok_arg {
syn::GenericArgument::Type(ty) => Some(ty.clone()),
_ => bail_span!(ok_arg, "unsupported generic type"),
Some(ty_of_result)
}
}
}
_ => {
bail_span!(segment, "Too many arguments")
}
}
}
}
_ => bail_span!(ret_ty, "The return type of callback can only be `Result`"),
} else {
bail_span!(ret_ty, "The return type of callback can only be `Result`");
}
}
_ => {
@ -257,6 +232,34 @@ fn extract_fn_types(
}
}
fn extract_result_ty(ty: &syn::Type) -> BindgenResult<Option<syn::Type>> {
match ty {
syn::Type::Path(syn::TypePath { qself: None, path }) if path.segments.len() == 1 => {
let segment = path.segments.first().unwrap();
if segment.ident != "Result" {
Ok(None)
} else {
match &segment.arguments {
syn::PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments {
args, ..
}) => {
let ok_arg = args.first().unwrap();
match ok_arg {
syn::GenericArgument::Type(ty) => Ok(Some(ty.clone())),
_ => bail_span!(ok_arg, "unsupported generic type"),
}
}
_ => {
bail_span!(segment, "unsupported generic type")
}
}
}
}
_ => Ok(None),
}
}
fn get_expr(mut expr: &syn::Expr) -> &syn::Expr {
while let syn::Expr::Group(g) = expr {
expr = &g.expr;
@ -480,7 +483,7 @@ fn napi_fn_from_decl(
syn::FnArg::Typed(mut p) => {
let ty_str = p.ty.to_token_stream().to_string();
if let Some(path_arguments) = callback_traits.get(&ty_str) {
match extract_fn_types(path_arguments) {
match extract_callback_trait_types(path_arguments) {
Ok((fn_args, fn_ret)) => Some(NapiFnArgKind::Callback(Box::new(CallbackArg {
pat: p.pat,
args: fn_args,
@ -518,11 +521,15 @@ fn napi_fn_from_decl(
})
.collect::<Vec<_>>();
let ret = match output {
syn::ReturnType::Default => None,
let (ret, is_ret_result) = match output {
syn::ReturnType::Default => (None, false),
syn::ReturnType::Type(_, ty) => {
let is_result = ty.to_token_stream().to_string().starts_with("Result <");
Some((replace_self(*ty, parent), is_result))
let result_ty = extract_result_ty(&ty)?;
if result_ty.is_some() {
(result_ty, true)
} else {
(Some(replace_self(*ty, parent)), false)
}
}
};
@ -559,6 +566,7 @@ fn napi_fn_from_decl(
js_name,
args,
ret,
is_ret_result,
is_async: asyncness.is_some(),
vis,
kind: fn_kind(opts),

View file

@ -23,6 +23,8 @@ napi7 = ["napi6", "napi-sys/napi7"]
napi8 = ["napi7", "napi-sys/napi8"]
serde-json = ["serde", "serde_json"]
tokio_rt = ["tokio", "once_cell", "napi4"]
async = ["tokio_rt"]
full = ["latin1", "napi8", "async", "serde-json"]
[dependencies]
ctor = "0.1"

View file

@ -103,3 +103,9 @@ impl ValidateNapiValue for Buffer {
vec![ValueType::Object]
}
}
impl ToNapiValue for Vec<u8> {
unsafe fn to_napi_value(env: sys::napi_env, val: Self) -> Result<sys::napi_value> {
Buffer::to_napi_value(env, val.into())
}
}

View file

@ -158,6 +158,7 @@ macro_rules! assert_type_of {
pub mod bindgen_prelude {
#[cfg(feature = "compat-mode")]
pub use crate::bindgen_runtime::register_module_exports;
pub use crate::tokio_runtime::*;
pub use crate::{
assert_type_of, bindgen_runtime::*, check_status, check_status_or_throw, error, error::*, sys,
type_of, JsError, Property, PropertyAttributes, Result, Status, Task, ValueType,

View file

@ -13,10 +13,12 @@ napi3 = ["napi/napi3"]
[dependencies]
napi-derive = { path = "../../crates/macro", features = ["type-def"] }
napi = { path = "../../crates/napi", features = ["latin1", "serde-json"] }
napi = { path = "../../crates/napi", features = ["full"] }
serde = "1"
serde_derive = "1"
serde_json = "1"
tokio = {version = "1", features = ["default", "fs"]}
futures = "0.3"
[build-dependencies]
napi-build = { path = "../../crates/build" }

View file

@ -11,6 +11,7 @@ Generated by [AVA](https://avajs.dev).
`export function getWords(): Array<string>
export function getNums(): Array<number>
export function sumNums(nums: Array<number>): number␊
export function readFileAsync(path: string): Promise<Array<number>>␊
export function getCwd(callback: (arg0: string) => void): void␊
export function readFile(callback: (arg0: Error | undefined, arg1: string | null) => void): void␊
export enum Kind { Dog = 0, Cat = 1, Duck = 2 }␊

View file

@ -1,3 +1,5 @@
import { join } from 'path'
import test from 'ava'
import {
@ -23,6 +25,7 @@ import {
readPackageJson,
getPackageJsonName,
getBuffer,
readFileAsync,
} from '../'
test('number', (t) => {
@ -117,3 +120,13 @@ test('serde-json', (t) => {
test('buffer', (t) => {
t.is(getBuffer().toString('utf-8'), 'Hello world')
})
test('async', async (t) => {
const bufPromise = readFileAsync(join(__dirname, '../package.json'))
await t.notThrowsAsync(bufPromise)
const buf = await bufPromise
const { name } = JSON.parse(buf.toString())
t.is(name, 'napi-examples')
await t.throwsAsync(() => readFileAsync('some_nonexist_path.file'))
})

View file

@ -1,6 +1,7 @@
export function getWords(): Array<string>
export function getNums(): Array<number>
export function sumNums(nums: Array<number>): number
export function readFileAsync(path: string): Promise<Array<number>>
export function getCwd(callback: (arg0: string) => void): void
export function readFile(callback: (arg0: Error | undefined, arg1: string | null) => void): void
export enum Kind { Dog = 0, Cat = 1, Duck = 2 }

View file

@ -0,0 +1,17 @@
use futures::prelude::*;
use napi::bindgen_prelude::*;
use tokio::fs;
#[napi]
async fn read_file_async(path: String) -> Result<Vec<u8>> {
fs::read(path)
.map(|v| {
v.map_err(|e| {
Error::new(
Status::GenericFailure,
format!("failed to read file, {}", e),
)
})
})
.await
}

View file

@ -4,6 +4,7 @@ extern crate napi_derive;
extern crate serde_derive;
mod array;
mod r#async;
mod callback;
mod class;
mod r#enum;