diff --git a/crates/relax-macros/src/lib.rs b/crates/relax-macros/src/lib.rs index cab40ff..e2d2710 100644 --- a/crates/relax-macros/src/lib.rs +++ b/crates/relax-macros/src/lib.rs @@ -135,21 +135,15 @@ fn derive_impl(input: syn::DeriveInput) -> syn::Result { let vis = &field.vis; let ty = &field.ty; - let relax_attr = field - .attrs - .iter() - .find(|attr| attr.path().is_ident("relax")); - - if relax_attr.is_none() { + if field.attrs.iter().any(|attr| attr.path().is_ident("relax")) { + // nested + Ok(quote! { #vis #name: ::std::option::Option<<#ty as ::relax::Relax>::Partial> }) + } else { // not nested match get_generic_ty("Option", ty) { Some(ty) => Ok(quote! { #vis #name: ::std::option::Option<#ty> }), None => Ok(quote! { #vis #name: ::std::option::Option<#ty> }), } - } else { - // nested - let ty = get_ty_name_from_helper_attr(relax_attr.unwrap())?; - Ok(quote! { #vis #name: ::std::option::Option<#ty> }) } }) .collect::>>()?; @@ -169,18 +163,16 @@ fn derive_impl(input: syn::DeriveInput) -> syn::Result { let try_from_fields = fields.iter().map(|field| -> syn::Result { let name = field.ident.as_ref(); let ty = &field.ty; - let relaxed_ty = field + let nested = field .attrs .iter() - .find(|attr| attr.path().is_ident("relax")) - .map(get_ty_name_from_helper_attr) - .transpose()?; + .any(|attr| attr.path().is_ident("relax")); - Ok(match (relaxed_ty, get_generic_ty("Option", ty)) { - (None, None) => quote! { #name: value.#name.ok_or(::relax::RequiredFieldNotSet(stringify!(#name)))? }, - (None, Some(_)) => quote! { #name: value.#name }, - (Some(_), None) => quote! { #name: value.#name.ok_or(::relax::RequiredFieldNotSet(stringify!(#name)))?.try_into()? }, - (Some(_), Some(_)) => quote! { #name: value.#name.map(|val| val.try_into()).transpose().ok().flatten() }, + Ok(match (nested, get_generic_ty("Option", ty).is_some()) { + (false, false) => quote! { #name: value.#name.ok_or(::relax::MissingRequiredField(stringify!(#name)))? }, + (false, true) => quote! { #name: value.#name }, + (true, false) => quote! { #name: value.#name.ok_or(::relax::MissingRequiredField(stringify!(#name)))?.try_into()? }, + (true, true) => quote! { #name: value.#name.map(|val| val.try_into()).transpose().ok().flatten() }, }) }).collect::>>()?; @@ -204,7 +196,7 @@ fn derive_impl(input: syn::DeriveInput) -> syn::Result { } impl #generics ::std::convert::TryFrom<#partial #generics> for #base #generics { - type Error = ::relax::RequiredFieldNotSet; + type Error = ::relax::MissingRequiredField; fn try_from(value: #partial) -> ::std::result::Result { Ok(Self{ @@ -250,29 +242,3 @@ fn get_generic_ty<'a>(wrapper: &str, ty: &'a syn::Type) -> Option<&'a syn::Type> None } } - -fn get_ty_name_from_helper_attr(attr: &syn::Attribute) -> syn::Result { - let tokens = match attr { - syn::Attribute { - meta: syn::Meta::List(syn::MetaList { ref tokens, .. }), - .. - } => tokens, - _ => { - return Err(syn::Error::new_spanned( - attr, - "helper attribute should be #[relax(StructName)]", - )) - } - }; - - let tokens: Vec = tokens.clone().into_iter().collect(); - - if tokens.len() != 1 { - return Err(syn::Error::new_spanned( - attr, - "helper attribute should be #[relax(StructName)]", - )); - } - - Ok(tokens[0].to_owned()) -}