diff --git a/starknet-core-derive/src/generics_visitor.rs b/starknet-core-derive/src/generics_visitor.rs new file mode 100644 index 00000000..844a0866 --- /dev/null +++ b/starknet-core-derive/src/generics_visitor.rs @@ -0,0 +1,195 @@ +use std::collections::HashSet; + +use syn::{punctuated::Pair, Generics, Path, Token, Type, WhereClause}; + +// Adapted from https://github.com/serde-rs/serde/blob/1d7899d671c6f6155b63a39fa6001c9c48260821/serde_derive/src/bound.rs#L91 + +pub struct GenericsVisitor<'ast> { + existing_generics: Generics, + + // Set of all generic type parameters on the current struct. + // Initialized up front. + all_type_params: HashSet, + + // Set of generic type parameters used in fields. + // Filled in as the visitor sees them. + relevant_type_params: HashSet, + + // Fields whose type is an associated type of one of the generic type + // parameters. + associated_type_usage: Vec<&'ast syn::TypePath>, +} + +impl<'ast> GenericsVisitor<'ast> { + pub fn new(existing_generics: &Generics) -> Self { + Self { + existing_generics: existing_generics.clone(), + all_type_params: existing_generics + .type_params() + .map(|param| param.ident.clone()) + .collect(), + relevant_type_params: HashSet::default(), + associated_type_usage: Vec::default(), + } + } + + pub fn extend_where_clause(self, where_clause: &mut WhereClause, bound: &Path) { + where_clause.predicates.extend( + self.existing_generics + .type_params() + .filter_map(|param| { + self.relevant_type_params + .contains(¶m.ident) + .then(|| syn::TypePath { + qself: None, + path: param.ident.clone().into(), + }) + }) + .chain(self.associated_type_usage.into_iter().cloned()) + .map(|bounded_ty| { + syn::WherePredicate::Type(syn::PredicateType { + lifetimes: None, + bounded_ty: syn::Type::Path(bounded_ty), + colon_token: ::default(), + bounds: vec![syn::TypeParamBound::Trait(syn::TraitBound { + paren_token: None, + modifier: syn::TraitBoundModifier::None, + lifetimes: None, + path: bound.clone(), + })] + .into_iter() + .collect(), + }) + }), + ); + } + + pub fn visit_field(&mut self, field: &'ast syn::Field) { + if let syn::Type::Path(ty) = ungroup(&field.ty) { + if let Some(Pair::Punctuated(t, _)) = ty.path.segments.pairs().next() { + if self.all_type_params.contains(&t.ident) { + self.associated_type_usage.push(ty); + } + } + } + self.visit_type(&field.ty); + } + + fn visit_path(&mut self, path: &'ast syn::Path) { + if path.leading_colon.is_none() && path.segments.len() == 1 { + let id = &path.segments[0].ident; + if self.all_type_params.contains(id) { + self.relevant_type_params.insert(id.clone()); + } + } + for segment in &path.segments { + self.visit_path_segment(segment); + } + } + + // Everything below is simply traversing the syntax tree. + + fn visit_type(&mut self, ty: &'ast syn::Type) { + match ty { + syn::Type::Array(ty) => self.visit_type(&ty.elem), + syn::Type::BareFn(ty) => { + for arg in &ty.inputs { + self.visit_type(&arg.ty); + } + self.visit_return_type(&ty.output); + } + syn::Type::Group(ty) => self.visit_type(&ty.elem), + syn::Type::ImplTrait(ty) => { + for bound in &ty.bounds { + self.visit_type_param_bound(bound); + } + } + syn::Type::Macro(ty) => self.visit_macro(&ty.mac), + syn::Type::Paren(ty) => self.visit_type(&ty.elem), + syn::Type::Path(ty) => { + if let Some(qself) = &ty.qself { + self.visit_type(&qself.ty); + } + self.visit_path(&ty.path); + } + syn::Type::Ptr(ty) => self.visit_type(&ty.elem), + syn::Type::Reference(ty) => self.visit_type(&ty.elem), + syn::Type::Slice(ty) => self.visit_type(&ty.elem), + syn::Type::TraitObject(ty) => { + for bound in &ty.bounds { + self.visit_type_param_bound(bound); + } + } + syn::Type::Tuple(ty) => { + for elem in &ty.elems { + self.visit_type(elem); + } + } + + syn::Type::Infer(_) | syn::Type::Never(_) | syn::Type::Verbatim(_) => {} + + _ => {} + } + } + + fn visit_path_segment(&mut self, segment: &'ast syn::PathSegment) { + self.visit_path_arguments(&segment.arguments); + } + + fn visit_path_arguments(&mut self, arguments: &'ast syn::PathArguments) { + match arguments { + syn::PathArguments::None => {} + syn::PathArguments::AngleBracketed(arguments) => { + for arg in &arguments.args { + match arg { + syn::GenericArgument::Type(arg) => self.visit_type(arg), + syn::GenericArgument::AssocType(arg) => self.visit_type(&arg.ty), + syn::GenericArgument::Lifetime(_) + | syn::GenericArgument::Const(_) + | syn::GenericArgument::AssocConst(_) + | syn::GenericArgument::Constraint(_) => {} + _ => {} + } + } + } + syn::PathArguments::Parenthesized(arguments) => { + for argument in &arguments.inputs { + self.visit_type(argument); + } + self.visit_return_type(&arguments.output); + } + } + } + + fn visit_return_type(&mut self, return_type: &'ast syn::ReturnType) { + match return_type { + syn::ReturnType::Default => {} + syn::ReturnType::Type(_, output) => self.visit_type(output), + } + } + + fn visit_type_param_bound(&mut self, bound: &'ast syn::TypeParamBound) { + match bound { + syn::TypeParamBound::Trait(bound) => self.visit_path(&bound.path), + syn::TypeParamBound::Lifetime(_) + | syn::TypeParamBound::PreciseCapture(_) + | syn::TypeParamBound::Verbatim(_) => {} + _ => {} + } + } + + // Type parameter should not be considered used by a macro path. + // + // struct TypeMacro { + // mac: T!(), + // marker: PhantomData, + // } + fn visit_macro(&mut self, _mac: &'ast syn::Macro) {} +} + +fn ungroup(mut ty: &Type) -> &Type { + while let Type::Group(group) = ty { + ty = &group.elem; + } + ty +} diff --git a/starknet-core-derive/src/lib.rs b/starknet-core-derive/src/lib.rs index 7045b6bb..d89f243e 100644 --- a/starknet-core-derive/src/lib.rs +++ b/starknet-core-derive/src/lib.rs @@ -7,9 +7,16 @@ use proc_macro2::Span; use quote::quote; use syn::{ parse::{Error as ParseError, Parse, ParseStream}, - parse_macro_input, DeriveInput, Fields, LitInt, LitStr, Meta, Token, + parse_macro_input, parse_quote, + punctuated::Punctuated, + spanned::Spanned, + DeriveInput, Fields, Lifetime, LifetimeParam, LitInt, LitStr, Meta, Path, Token, }; +use crate::generics_visitor::GenericsVisitor; + +mod generics_visitor; + #[derive(Default)] struct Args { core: Option, @@ -65,14 +72,18 @@ mod kw { /// Derives the `Encode` trait. #[proc_macro_derive(Encode, attributes(starknet))] pub fn derive_encode(input: TokenStream) -> TokenStream { - let input: DeriveInput = parse_macro_input!(input); + let mut input: DeriveInput = parse_macro_input!(input); let ident = &input.ident; let core = derive_core_path(&input); - let impl_block = match input.data { + let mut visitor = GenericsVisitor::new(&input.generics); + + let impl_block = match &input.data { syn::Data::Struct(data) => { let field_impls = data.fields.iter().enumerate().map(|(ind_field, field)| { + visitor.visit_field(field); + let field_ident = match &field.ident { Some(field_ident) => quote! { self.#field_ident }, None => { @@ -108,6 +119,8 @@ pub fn derive_encode(input: TokenStream) -> TokenStream { .map(|field| field.ident.as_ref().unwrap()); let field_impls = fields_named.named.iter().map(|field| { + visitor.visit_field(field); + let field_ident = field.ident.as_ref().unwrap(); let field_type = &field.ty; @@ -136,6 +149,8 @@ pub fn derive_encode(input: TokenStream) -> TokenStream { let field_impls = fields_unnamed.unnamed.iter().enumerate().map( |(ind_field, field)| { + visitor.visit_field(field); + let field_ident = syn::Ident::new( &format!("field_{ind_field}"), Span::call_site(), @@ -175,9 +190,15 @@ pub fn derive_encode(input: TokenStream) -> TokenStream { syn::Data::Union(_) => panic!("union type not supported"), }; + let encode_path = parse_quote!(#core::codec::Encode); + + visitor.extend_where_clause(input.generics.make_where_clause(), &encode_path); + + let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + quote! { #[automatically_derived] - impl #core::codec::Encode for #ident { + impl #impl_generics #encode_path for #ident #ty_generics #where_clause { fn encode(&self, writer: &mut W) -> ::core::result::Result<(), #core::codec::Error> { #impl_block @@ -189,18 +210,39 @@ pub fn derive_encode(input: TokenStream) -> TokenStream { .into() } +const DECODE_LIFETIME_IDENT: &'static str = "'de"; + /// Derives the `Decode` trait. #[proc_macro_derive(Decode, attributes(starknet))] pub fn derive_decode(input: TokenStream) -> TokenStream { - let input: DeriveInput = parse_macro_input!(input); - let ident = &input.ident; + let mut input: DeriveInput = parse_macro_input!(input); + + if let Some(lt) = input + .generics + .lifetimes() + .find(|lt| lt.lifetime.ident == DECODE_LIFETIME_IDENT) + { + return syn::Error::new( + lt.span(), + format!( + "cannot decode when there is a lifetime parameter called {DECODE_LIFETIME_IDENT}" + ), + ) + .into_compile_error() + .into(); + } + let ident = &input.ident; let core = derive_core_path(&input); - let impl_block = match input.data { + let mut visitor = GenericsVisitor::new(&input.generics); + + let impl_block = match &input.data { syn::Data::Struct(data) => match &data.fields { Fields::Named(fields_named) => { let field_impls = fields_named.named.iter().map(|field| { + visitor.visit_field(field); + let field_ident = &field.ident; let field_type = &field.ty; @@ -218,6 +260,8 @@ pub fn derive_decode(input: TokenStream) -> TokenStream { } Fields::Unnamed(fields_unnamed) => { let field_impls = fields_unnamed.unnamed.iter().map(|field| { + visitor.visit_field(field); + let field_type = &field.ty; quote! { <#field_type as #core::codec::Decode>::decode_iter(iter)? @@ -248,6 +292,8 @@ pub fn derive_decode(input: TokenStream) -> TokenStream { let decode_impl = match &variant.fields { Fields::Named(fields_named) => { let field_impls = fields_named.named.iter().map(|field| { + visitor.visit_field(field); + let field_ident = field.ident.as_ref().unwrap(); let field_type = &field.ty; @@ -265,6 +311,8 @@ pub fn derive_decode(input: TokenStream) -> TokenStream { } Fields::Unnamed(fields_unnamed) => { let field_impls = fields_unnamed.unnamed.iter().map(|field| { + visitor.visit_field(field); + let field_type = &field.ty; quote! { @@ -303,12 +351,35 @@ pub fn derive_decode(input: TokenStream) -> TokenStream { syn::Data::Union(_) => panic!("union type not supported"), }; + let decode_path: Path = parse_quote!(#core::codec::Decode); + let de_lifetime = Lifetime::new(DECODE_LIFETIME_IDENT, Span::call_site()); + + visitor.extend_where_clause( + input.generics.make_where_clause(), + &parse_quote!(#decode_path<#de_lifetime>), + ); + + let generics_without_decode_lt = input.generics.clone(); + let (_, ty_generics, where_clause) = generics_without_decode_lt.split_for_impl(); + + input + .generics + .params + .push(syn::GenericParam::Lifetime(LifetimeParam { + attrs: vec![], + lifetime: de_lifetime.clone(), + bounds: Punctuated::new(), + colon_token: None, + })); + + let (impl_generics, _, _) = input.generics.split_for_impl(); + quote! { #[automatically_derived] - impl<'a> #core::codec::Decode<'a> for #ident { - fn decode_iter(iter: &mut T) -> ::core::result::Result + impl #impl_generics #decode_path<#de_lifetime> for #ident #ty_generics #where_clause { + fn decode_iter<__T>(iter: &mut __T) -> ::core::result::Result where - T: core::iter::Iterator + __T: core::iter::Iterator { #impl_block } @@ -318,7 +389,7 @@ pub fn derive_decode(input: TokenStream) -> TokenStream { } /// Determines the path to the `starknet-core` crate root. -fn derive_core_path(input: &DeriveInput) -> proc_macro2::TokenStream { +fn derive_core_path(input: &DeriveInput) -> Path { let mut attr_args = Args::default(); for attr in &input.attrs { @@ -342,14 +413,14 @@ fn derive_core_path(input: &DeriveInput) -> proc_macro2::TokenStream { attr_args.core.map_or_else( || { #[cfg(not(feature = "import_from_starknet"))] - quote! { + parse_quote! { ::starknet_core } // This feature is enabled by the `starknet` crate. When using `starknet` it's assumed // that users would not have imported `starknet-core` directly. #[cfg(feature = "import_from_starknet")] - quote! { + parse_quote! { ::starknet::core } }, @@ -358,7 +429,7 @@ fn derive_core_path(input: &DeriveInput) -> proc_macro2::TokenStream { } /// Turns an integer into an optimal `TokenStream` that constructs a `Felt` with the same value. -fn int_to_felt(int: usize, core: &proc_macro2::TokenStream) -> proc_macro2::TokenStream { +fn int_to_felt(int: usize, core: &Path) -> proc_macro2::TokenStream { match int { 0 => quote! { #core::types::Felt::ZERO }, 1 => quote! { #core::types::Felt::ONE }, diff --git a/starknet-core/src/codec.rs b/starknet-core/src/codec.rs index 08d32164..e6348afe 100644 --- a/starknet-core/src/codec.rs +++ b/starknet-core/src/codec.rs @@ -747,6 +747,80 @@ mod tests { ); } + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + fn test_derive_encode_struct_named_generic() { + #[derive(Encode)] + #[starknet(core = "crate")] + struct CairoType + where + B: Eq, + { + a: A::Item, + b: B, + } + + let mut serialized = Vec::new(); + CairoType::, _> { a: 10u128, b: true } + .encode(&mut serialized) + .unwrap(); + assert_eq!( + serialized, + vec![Felt::from_str("10").unwrap(), Felt::from_str("1").unwrap(),] + ); + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + fn test_derive_encode_struct_tuple_generic() { + #[derive(Encode)] + #[starknet(core = "crate")] + struct CairoType(A::Item, B) + where + B: Eq; + + let mut serialized = Vec::new(); + CairoType::, _>(10u128, true) + .encode(&mut serialized) + .unwrap(); + assert_eq!( + serialized, + vec![Felt::from_str("10").unwrap(), Felt::from_str("1").unwrap(),] + ); + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + fn test_derive_encode_enum_generic() { + #[derive(Encode)] + #[starknet(core = "crate")] + enum CairoType + where + B: Eq, + { + A(A::Item), + B(B), + } + + let mut serialized = Vec::::new(); + CairoType::, bool>::A(10u128) + .encode(&mut serialized) + .unwrap(); + assert_eq!( + serialized, + vec![Felt::from_str("0").unwrap(), Felt::from_str("10").unwrap()] + ); + + serialized.clear(); + CairoType::, _>::B(true) + .encode(&mut serialized) + .unwrap(); + assert_eq!( + serialized, + vec![Felt::from_str("1").unwrap(), Felt::from_str("1").unwrap()] + ); + } + #[test] #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] fn test_derive_encode_struct_named() { @@ -1158,4 +1232,66 @@ mod tests { .unwrap() ); } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + fn test_derive_decode_struct_named_generic() { + #[derive(Debug, PartialEq, Eq, Decode)] + #[starknet(core = "crate")] + struct CairoType + where + B: Eq, + { + a: A::Item, + b: B, + } + + assert_eq!( + CairoType::, _> { a: 10u128, b: true }, + CairoType::decode(&[Felt::from_str("10").unwrap(), Felt::from_str("1").unwrap()]) + .unwrap() + ); + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + fn test_derive_decode_struct_tuple_generic() { + #[derive(Debug, PartialEq, Eq, Decode)] + #[starknet(core = "crate")] + struct CairoType(A::Item, B) + where + B: Eq; + + assert_eq!( + CairoType::, _>(10u128, true), + CairoType::decode(&[Felt::from_str("10").unwrap(), Felt::from_str("1").unwrap(),]) + .unwrap() + ); + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + fn test_derive_decode_enum_generic() { + #[derive(Debug, PartialEq, Eq, Decode)] + #[starknet(core = "crate")] + enum CairoType + where + B: Eq, + { + A(A::Item), + B { b: B }, + } + + assert_eq!( + CairoType::, bool>::A(10u128), + CairoType::decode(&[Felt::from_str("0").unwrap(), Felt::from_str("10").unwrap(),]) + .unwrap() + ); + + assert_eq!( + CairoType::, _>::B { b: true }, + CairoType::decode(&[Felt::from_str("1").unwrap(), Felt::from_str("1").unwrap(),]) + .unwrap() + ); + } }