diff --git a/enum-iterator-derive/Cargo.toml b/enum-iterator-derive/Cargo.toml index d0c8c51..548d3d1 100644 --- a/enum-iterator-derive/Cargo.toml +++ b/enum-iterator-derive/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "enum-iterator-derive" -version = "1.5.0" +version = "1.6.0" authors = ["Stephane Raux "] edition = "2021" description = "Procedural macro to derive Sequence" diff --git a/enum-iterator-derive/src/lib.rs b/enum-iterator-derive/src/lib.rs index 173e503..91778a8 100644 --- a/enum-iterator-derive/src/lib.rs +++ b/enum-iterator-derive/src/lib.rs @@ -45,11 +45,13 @@ fn derive(input: proc_macro::TokenStream) -> Result { #[derive(Debug)] struct DeriveOptions { crate_path: Path, + use_default: bool, } impl DeriveOptions { fn parse(attrs: &[syn::Attribute]) -> Result { let mut crate_path = None; + let mut use_default = false; attrs .iter() .filter(|attr| attr.path().is_ident("enum_iterator")) @@ -63,6 +65,9 @@ impl DeriveOptions { } else { Err(meta.error("duplicate crate key")) } + } else if meta.path.is_ident("use_default") { + use_default = true; + Ok(()) } else { Err(meta.error(format!("unknown key {}", meta.path.to_token_stream()))) } @@ -78,6 +83,7 @@ impl DeriveOptions { .into_iter() .collect(), }), + use_default, }) } } @@ -87,7 +93,12 @@ fn derive_for_ast(ast: DeriveInput) -> Result { let generics = &ast.generics; let options = DeriveOptions::parse(&ast.attrs)?; match &ast.data { - syn::Data::Struct(s) => derive_for_struct(&options, ty, generics, &s.fields), + syn::Data::Struct(s) => { + if options.use_default { + return Err(Error::DefaultForStruct.with_tokens(ast)); + } + derive_for_struct(&options, ty, generics, &s.fields) + } syn::Data::Enum(e) => derive_for_enum(&options, ty, generics, &e.variants), syn::Data::Union(_) => Err(Error::UnsupportedUnion.with_tokens(&ast)), } @@ -100,9 +111,9 @@ fn derive_for_struct( fields: &Fields, ) -> Result { let crate_path = &options.crate_path; - let cardinality = tuple_cardinality(&options.crate_path, fields); - let first = init_value(&options.crate_path, ty, None, fields, Direction::Forward); - let last = init_value(&options.crate_path, ty, None, fields, Direction::Backward); + let cardinality = tuple_cardinality(options, fields); + let first = init_value(options, ty, None, fields, Direction::Forward); + let last = init_value(options, ty, None, fields, Direction::Backward); let next_body = advance_struct(&options.crate_path, ty, fields, Direction::Forward); let previous_body = advance_struct(&options.crate_path, ty, fields, Direction::Backward); let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); @@ -153,9 +164,9 @@ fn derive_for_enum( generics: &Generics, variants: &Punctuated, ) -> Result { - let cardinality = enum_cardinality(&options.crate_path, variants); - let next_body = advance_enum(&options.crate_path, ty, variants, Direction::Forward); - let previous_body = advance_enum(&options.crate_path, ty, variants, Direction::Backward); + let cardinality = enum_cardinality(options, variants); + let next_body = advance_enum(options, ty, variants, Direction::Forward); + let previous_body = advance_enum(options, ty, variants, Direction::Backward); let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); let where_clause = if generics.params.is_empty() { where_clause.cloned() @@ -175,9 +186,8 @@ fn derive_for_enum( ); Some(clause) }; - let next_variant_body = next_variant(&options.crate_path, ty, variants, Direction::Forward); - let previous_variant_body = - next_variant(&options.crate_path, ty, variants, Direction::Backward); + let next_variant_body = next_variant(options, ty, variants, Direction::Forward); + let previous_variant_body = next_variant(options, ty, variants, Direction::Backward); let (first, last) = if variants.is_empty() { ( quote! { ::core::option::Option::None }, @@ -231,16 +241,21 @@ fn derive_for_enum( Ok(tokens) } -fn enum_cardinality(crate_path: &Path, variants: &Punctuated) -> TokenStream { +fn enum_cardinality(options: &DeriveOptions, variants: &Punctuated) -> TokenStream { let terms = variants .iter() - .map(|variant| tuple_cardinality(crate_path, &variant.fields)); + .map(|variant| tuple_cardinality(options, &variant.fields)); quote! { #((#terms) +)* 0 } } -fn tuple_cardinality(crate_path: &Path, fields: &Fields) -> TokenStream { +fn tuple_cardinality(options: &DeriveOptions, fields: &Fields) -> TokenStream { + if options.use_default { + return quote! { 1 }; + } + + let crate_path = &options.crate_path; let factors = fields.iter().map(|field| { let ty = &field.ty; quote! { @@ -260,7 +275,7 @@ fn field_id(field: &Field, index: usize) -> Member { } fn init_value( - crate_path: &Path, + options: &DeriveOptions, ty: &Ident, variant: Option<&Ident>, fields: &Fields, @@ -273,22 +288,29 @@ fn init_value( } } else { let reset = direction.reset(); - let initialization = repeat_n(quote! { #crate_path::Sequence::#reset() }, fields.len()); - let assignments = field_assignments(fields); - let bindings = bindings().take(fields.len()); - quote! {{ - match (#(#initialization,)*) { - (#(::core::option::Option::Some(#bindings),)*) => { - ::core::option::Option::Some(#id { #assignments }) + let crate_path = &options.crate_path; + + if options.use_default { + let assignments = default_field_assignments(fields); + quote! { ::core::option::Option::Some(#id { #assignments }) } + } else { + let initialization = repeat_n(quote! { #crate_path::Sequence::#reset() }, fields.len()); + let assignments = field_assignments(fields); + let bindings = bindings().take(fields.len()); + quote! {{ + match (#(#initialization,)*) { + (#(::core::option::Option::Some(#bindings),)*) => { + ::core::option::Option::Some(#id { #assignments }) + } + _ => ::core::option::Option::None, } - _ => ::core::option::Option::None, - } - }} + }} + } } } fn next_variant( - crate_path: &Path, + options: &DeriveOptions, ty: &Ident, variants: &Punctuated, direction: Direction, @@ -306,7 +328,7 @@ fn next_variant( }; let arms = variants.iter().enumerate().map(|(i, v)| { let id = &v.ident; - let init = init_value(crate_path, ty, Some(id), &v.fields, direction); + let init = init_value(options, ty, Some(id), &v.fields, direction); quote! { #i => #init } @@ -342,7 +364,7 @@ fn advance_struct( } fn advance_enum( - crate_path: &Path, + options: &DeriveOptions, ty: &Ident, variants: &Punctuated, direction: Direction, @@ -351,13 +373,13 @@ fn advance_enum( Direction::Forward => variants .iter() .enumerate() - .map(|(i, variant)| advance_enum_arm(crate_path, ty, direction, i, variant)) + .map(|(i, variant)| advance_enum_arm(options, ty, direction, i, variant)) .collect(), Direction::Backward => variants .iter() .enumerate() .rev() - .map(|(i, variant)| advance_enum_arm(crate_path, ty, direction, i, variant)) + .map(|(i, variant)| advance_enum_arm(options, ty, direction, i, variant)) .collect(), }; quote! { @@ -368,7 +390,7 @@ fn advance_enum( } fn advance_enum_arm( - crate_path: &Path, + options: &DeriveOptions, ty: &Ident, direction: Direction, i: usize, @@ -391,17 +413,23 @@ fn advance_enum_arm( } } else { let destructuring = field_bindings(&variant.fields); - let assignments = field_assignments(&variant.fields); - let bindings = bindings().take(variant.fields.len()).collect::>(); - let tuple = advance_tuple(crate_path, &bindings, direction); - quote! { - #ty::#id { #destructuring } => { - let y = #tuple; - match y { - ::core::option::Option::Some((#(#bindings,)*)) => { - ::core::option::Option::Some(#ty::#id { #assignments }) + if options.use_default { + quote! { + #ty::#id { #destructuring } => #next + } + } else { + let assignments = field_assignments(&variant.fields); + let bindings = bindings().take(variant.fields.len()).collect::>(); + let tuple = advance_tuple(&options.crate_path, &bindings, direction); + quote! { + #ty::#id { #destructuring } => { + let y = #tuple; + match y { + ::core::option::Option::Some((#(#bindings,)*)) => { + ::core::option::Option::Some(#ty::#id { #assignments }) + } + ::core::option::Option::None => #next, } - ::core::option::Option::None => #next, } } } @@ -475,6 +503,31 @@ where .collect() } +fn default_field_assignments<'a, I>(fields: I) -> TokenStream +where + I: IntoIterator, +{ + fields + .into_iter() + .enumerate() + .map(|(i, field)| { + let field_id = field_id(field, i); + quote! { #field_id: ::core::default::Default::default(), } + }) + .collect() +} + +/// Creates a token stream for destructuring a tuple, given a sequence of fields. +/// +/// This function takes an iterator of fields, and returns a token stream +/// that can be used to destructure a tuple. The token stream will contain +/// assignments of the form `#field_id: ref #binding`, where `#field_id` +/// is the identifier of the field, and `#binding` is the identifier of the +/// binding. +/// +/// # Example +/// +/// fn field_bindings<'a, I>(fields: I) -> TokenStream where I: IntoIterator, @@ -510,6 +563,7 @@ where .map(|req| match req { TypeRequirement::Clone => clone_trait_path(), TypeRequirement::Sequence => trait_path(&crate_path), + TypeRequirement::Default => default_trait_path(), }) .map(trait_bound) .collect(), @@ -545,6 +599,19 @@ fn clone_trait_path() -> Path { } } +fn default_trait_path() -> Path { + Path { + leading_colon: Some(Default::default()), + segments: [ + PathSegment::from(Ident::new("core", Span::call_site())), + Ident::new("default", Span::call_site()).into(), + Ident::new("Default", Span::call_site()).into(), + ] + .into_iter() + .collect(), + } +} + fn tuple_type_requirements() -> impl Iterator { once([TypeRequirement::Sequence].into()).chain(repeat( [TypeRequirement::Sequence, TypeRequirement::Clone].into(), @@ -575,6 +642,7 @@ where enum TypeRequirement { Sequence, Clone, + Default, } #[derive(Clone, Debug, Default, PartialEq)] @@ -583,6 +651,7 @@ struct TypeRequirements(u8); impl TypeRequirements { const SEQUENCE: u8 = 0x1; const CLONE: u8 = 0x2; + const DEFAULT: u8 = 0x3; fn new() -> Self { Self::default() @@ -601,6 +670,9 @@ impl TypeRequirements { } else if n & Self::CLONE != 0 { n &= !Self::CLONE; Some(TypeRequirement::Clone) + } else if n & Self::DEFAULT != 0 { + n &= !Self::DEFAULT; + Some(TypeRequirement::Default) } else { None } @@ -615,6 +687,7 @@ impl TypeRequirements { match req { TypeRequirement::Sequence => Self::SEQUENCE, TypeRequirement::Clone => Self::CLONE, + TypeRequirement::Default => Self::DEFAULT, } } } @@ -656,6 +729,7 @@ impl Direction { #[derive(Debug)] enum Error { UnsupportedUnion, + DefaultForStruct, } impl Error { @@ -668,6 +742,9 @@ impl Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Error::UnsupportedUnion => f.write_str("Sequence cannot be derived for union types"), + Error::DefaultForStruct => { + f.write_str("Sequence over default values cannot be derived for struct types ") + } } } } diff --git a/enum-iterator/Cargo.toml b/enum-iterator/Cargo.toml index 2a0d33e..ba55b65 100644 --- a/enum-iterator/Cargo.toml +++ b/enum-iterator/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "enum-iterator" -version = "2.3.0" +version = "2.4.0" authors = ["Stephane Raux "] edition = "2021" description = "Tools to iterate over all values of a type (e.g. all variants of an enumeration)" diff --git a/enum-iterator/README.md b/enum-iterator/README.md index 54b0c51..d2e5310 100644 --- a/enum-iterator/README.md +++ b/enum-iterator/README.md @@ -53,6 +53,24 @@ assert_eq!(first::(), Some(Foo { a: false, b: 0 })); assert_eq!(last::(), Some(Foo { a: true, b: 255 })); ``` +```rust +use enum_iterator::{Sequence}; +use core::default; + +#[derive(Debug, PartialEq, Sequence)] +#[enum_iterator(use_default)] +enum Id { + Name { first: String, last: String }, + Empty, + Alias(String), +} + +let number = Id::Empty; + +assert_eq!(number.next(), Some(Id::Alias(default::Default()))); +assert_eq!(number.previous(), Some(Id::Name {first: default::Default(), last: default::Default()})); +``` + # Rust version This crate tracks stable Rust. Minor releases may require a newer Rust version. Patch releases must not require a newer Rust version. diff --git a/enum-iterator/src/lib.rs b/enum-iterator/src/lib.rs index 93d8287..a32b3a4 100644 --- a/enum-iterator/src/lib.rs +++ b/enum-iterator/src/lib.rs @@ -64,6 +64,17 @@ //! #[enum_iterator(crate = enum_iterator)] //! struct Foo; //! ``` +//! +//! # Iterate over enum's default variant +//! +//! For enums with only unit variants or variants with fields implementing [`core::default::Default`], +//! is is possible to iterate over them using the `use_default` parameter. +//! +//! ``` +//! #[derive(enum_iterator::Sequence)] +//! #[enum_iterator(use_default)] +//! enum Foo {}; +//! ``` //! //! # Rust version //! This crate tracks stable Rust. Minor releases may require a newer Rust version. Patch releases diff --git a/enum-iterator/tests/derive.rs b/enum-iterator/tests/derive.rs index 7b73c49..4933306 100644 --- a/enum-iterator/tests/derive.rs +++ b/enum-iterator/tests/derive.rs @@ -181,3 +181,73 @@ fn all_values_of_unit_are_yielded() { fn all_values_of_unit_are_yielded_in_reverse() { assert_eq!(reverse_all::().collect::>(), vec![Unit]); } + +#[derive(Debug, PartialEq, Sequence)] +#[enum_iterator(use_default)] +enum Id { + Number(u32), + Alias(String), + Name { first: String, last: String }, + Empty, +} + +#[test] +fn all_values_of_defaults_are_yielded() { + assert_eq!(cardinality::(), 4); + assert_eq!( + all::().collect::>(), + vec![ + Id::Number(Default::default()), + Id::Alias(Default::default()), + Id::Name { + first: Default::default(), + last: Default::default() + }, + Id::Empty, + ] + ); +} + +#[test] +fn iterate_over_defaults_reverse() { + let id = Id::Empty; + + let name = id.previous().unwrap(); + assert_eq!( + name, + Id::Name { + first: Default::default(), + last: Default::default(), + } + ); + + let alias = name.previous().unwrap(); + assert_eq!(alias, Id::Alias(Default::default())); + + let number = alias.previous().unwrap(); + assert_eq!(number, Id::Number(Default::default())); + + assert!(number.previous().is_none()); +} + +#[test] +fn iterate_over_defaults() { + let id = Id::Number(5); + + let alias = id.next().unwrap(); + assert_eq!(alias, Id::Alias(Default::default())); + + let name = alias.next().unwrap(); + assert_eq!( + name, + Id::Name { + first: Default::default(), + last: Default::default(), + } + ); + + let empty = name.next().unwrap(); + assert_eq!(empty, Id::Empty); + + assert!(empty.next().is_none()); +}