Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion enum-iterator-derive/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "enum-iterator-derive"
version = "1.5.0"
version = "1.6.0"
authors = ["Stephane Raux <[email protected]>"]
edition = "2021"
description = "Procedural macro to derive Sequence"
Expand Down
157 changes: 117 additions & 40 deletions enum-iterator-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,13 @@ fn derive(input: proc_macro::TokenStream) -> Result<TokenStream, syn::Error> {
#[derive(Debug)]
struct DeriveOptions {
crate_path: Path,
use_default: bool,
}

impl DeriveOptions {
fn parse(attrs: &[syn::Attribute]) -> Result<Self, syn::Error> {
let mut crate_path = None;
let mut use_default = false;
attrs
.iter()
.filter(|attr| attr.path().is_ident("enum_iterator"))
Expand All @@ -63,6 +65,9 @@ impl DeriveOptions {
} else {
Err(meta.error("duplicate crate key"))
}
} else if meta.path.is_ident("use_default") {
use_default = true;
Ok(())
} else {
Err(meta.error(format!("unknown key {}", meta.path.to_token_stream())))
}
Expand All @@ -78,6 +83,7 @@ impl DeriveOptions {
.into_iter()
.collect(),
}),
use_default,
})
}
}
Expand All @@ -87,7 +93,12 @@ fn derive_for_ast(ast: DeriveInput) -> Result<TokenStream, syn::Error> {
let generics = &ast.generics;
let options = DeriveOptions::parse(&ast.attrs)?;
match &ast.data {
syn::Data::Struct(s) => derive_for_struct(&options, ty, generics, &s.fields),
syn::Data::Struct(s) => {
if options.use_default {
return Err(Error::DefaultForStruct.with_tokens(ast));
}
derive_for_struct(&options, ty, generics, &s.fields)
}
syn::Data::Enum(e) => derive_for_enum(&options, ty, generics, &e.variants),
syn::Data::Union(_) => Err(Error::UnsupportedUnion.with_tokens(&ast)),
}
Expand All @@ -100,9 +111,9 @@ fn derive_for_struct(
fields: &Fields,
) -> Result<TokenStream, syn::Error> {
let crate_path = &options.crate_path;
let cardinality = tuple_cardinality(&options.crate_path, fields);
let first = init_value(&options.crate_path, ty, None, fields, Direction::Forward);
let last = init_value(&options.crate_path, ty, None, fields, Direction::Backward);
let cardinality = tuple_cardinality(options, fields);
let first = init_value(options, ty, None, fields, Direction::Forward);
let last = init_value(options, ty, None, fields, Direction::Backward);
let next_body = advance_struct(&options.crate_path, ty, fields, Direction::Forward);
let previous_body = advance_struct(&options.crate_path, ty, fields, Direction::Backward);
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
Expand Down Expand Up @@ -153,9 +164,9 @@ fn derive_for_enum(
generics: &Generics,
variants: &Punctuated<Variant, Comma>,
) -> Result<TokenStream, syn::Error> {
let cardinality = enum_cardinality(&options.crate_path, variants);
let next_body = advance_enum(&options.crate_path, ty, variants, Direction::Forward);
let previous_body = advance_enum(&options.crate_path, ty, variants, Direction::Backward);
let cardinality = enum_cardinality(options, variants);
let next_body = advance_enum(options, ty, variants, Direction::Forward);
let previous_body = advance_enum(options, ty, variants, Direction::Backward);
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
let where_clause = if generics.params.is_empty() {
where_clause.cloned()
Expand All @@ -175,9 +186,8 @@ fn derive_for_enum(
);
Some(clause)
};
let next_variant_body = next_variant(&options.crate_path, ty, variants, Direction::Forward);
let previous_variant_body =
next_variant(&options.crate_path, ty, variants, Direction::Backward);
let next_variant_body = next_variant(options, ty, variants, Direction::Forward);
let previous_variant_body = next_variant(options, ty, variants, Direction::Backward);
let (first, last) = if variants.is_empty() {
(
quote! { ::core::option::Option::None },
Expand Down Expand Up @@ -231,16 +241,21 @@ fn derive_for_enum(
Ok(tokens)
}

fn enum_cardinality(crate_path: &Path, variants: &Punctuated<Variant, Comma>) -> TokenStream {
fn enum_cardinality(options: &DeriveOptions, variants: &Punctuated<Variant, Comma>) -> TokenStream {
let terms = variants
.iter()
.map(|variant| tuple_cardinality(crate_path, &variant.fields));
.map(|variant| tuple_cardinality(options, &variant.fields));
quote! {
#((#terms) +)* 0
}
}

fn tuple_cardinality(crate_path: &Path, fields: &Fields) -> TokenStream {
fn tuple_cardinality(options: &DeriveOptions, fields: &Fields) -> TokenStream {
if options.use_default {
return quote! { 1 };
}

let crate_path = &options.crate_path;
let factors = fields.iter().map(|field| {
let ty = &field.ty;
quote! {
Expand All @@ -260,7 +275,7 @@ fn field_id(field: &Field, index: usize) -> Member {
}

fn init_value(
crate_path: &Path,
options: &DeriveOptions,
ty: &Ident,
variant: Option<&Ident>,
fields: &Fields,
Expand All @@ -273,22 +288,29 @@ fn init_value(
}
} else {
let reset = direction.reset();
let initialization = repeat_n(quote! { #crate_path::Sequence::#reset() }, fields.len());
let assignments = field_assignments(fields);
let bindings = bindings().take(fields.len());
quote! {{
match (#(#initialization,)*) {
(#(::core::option::Option::Some(#bindings),)*) => {
::core::option::Option::Some(#id { #assignments })
let crate_path = &options.crate_path;

if options.use_default {
let assignments = default_field_assignments(fields);
quote! { ::core::option::Option::Some(#id { #assignments }) }
} else {
let initialization = repeat_n(quote! { #crate_path::Sequence::#reset() }, fields.len());
let assignments = field_assignments(fields);
let bindings = bindings().take(fields.len());
quote! {{
match (#(#initialization,)*) {
(#(::core::option::Option::Some(#bindings),)*) => {
::core::option::Option::Some(#id { #assignments })
}
_ => ::core::option::Option::None,
}
_ => ::core::option::Option::None,
}
}}
}}
}
}
}

fn next_variant(
crate_path: &Path,
options: &DeriveOptions,
ty: &Ident,
variants: &Punctuated<Variant, Comma>,
direction: Direction,
Expand All @@ -306,7 +328,7 @@ fn next_variant(
};
let arms = variants.iter().enumerate().map(|(i, v)| {
let id = &v.ident;
let init = init_value(crate_path, ty, Some(id), &v.fields, direction);
let init = init_value(options, ty, Some(id), &v.fields, direction);
quote! {
#i => #init
}
Expand Down Expand Up @@ -342,7 +364,7 @@ fn advance_struct(
}

fn advance_enum(
crate_path: &Path,
options: &DeriveOptions,
ty: &Ident,
variants: &Punctuated<Variant, Comma>,
direction: Direction,
Expand All @@ -351,13 +373,13 @@ fn advance_enum(
Direction::Forward => variants
.iter()
.enumerate()
.map(|(i, variant)| advance_enum_arm(crate_path, ty, direction, i, variant))
.map(|(i, variant)| advance_enum_arm(options, ty, direction, i, variant))
.collect(),
Direction::Backward => variants
.iter()
.enumerate()
.rev()
.map(|(i, variant)| advance_enum_arm(crate_path, ty, direction, i, variant))
.map(|(i, variant)| advance_enum_arm(options, ty, direction, i, variant))
.collect(),
};
quote! {
Expand All @@ -368,7 +390,7 @@ fn advance_enum(
}

fn advance_enum_arm(
crate_path: &Path,
options: &DeriveOptions,
ty: &Ident,
direction: Direction,
i: usize,
Expand All @@ -391,17 +413,23 @@ fn advance_enum_arm(
}
} else {
let destructuring = field_bindings(&variant.fields);
let assignments = field_assignments(&variant.fields);
let bindings = bindings().take(variant.fields.len()).collect::<Vec<_>>();
let tuple = advance_tuple(crate_path, &bindings, direction);
quote! {
#ty::#id { #destructuring } => {
let y = #tuple;
match y {
::core::option::Option::Some((#(#bindings,)*)) => {
::core::option::Option::Some(#ty::#id { #assignments })
if options.use_default {
quote! {
#ty::#id { #destructuring } => #next
}
} else {
let assignments = field_assignments(&variant.fields);
let bindings = bindings().take(variant.fields.len()).collect::<Vec<_>>();
let tuple = advance_tuple(&options.crate_path, &bindings, direction);
quote! {
#ty::#id { #destructuring } => {
let y = #tuple;
match y {
::core::option::Option::Some((#(#bindings,)*)) => {
::core::option::Option::Some(#ty::#id { #assignments })
}
::core::option::Option::None => #next,
}
::core::option::Option::None => #next,
}
}
}
Expand Down Expand Up @@ -475,6 +503,31 @@ where
.collect()
}

fn default_field_assignments<'a, I>(fields: I) -> TokenStream
where
I: IntoIterator<Item = &'a Field>,
{
fields
.into_iter()
.enumerate()
.map(|(i, field)| {
let field_id = field_id(field, i);
quote! { #field_id: ::core::default::Default::default(), }
})
.collect()
}

/// Creates a token stream for destructuring a tuple, given a sequence of fields.
///
/// This function takes an iterator of fields, and returns a token stream
/// that can be used to destructure a tuple. The token stream will contain
/// assignments of the form `#field_id: ref #binding`, where `#field_id`
/// is the identifier of the field, and `#binding` is the identifier of the
/// binding.
///
/// # Example
///
///
fn field_bindings<'a, I>(fields: I) -> TokenStream
where
I: IntoIterator<Item = &'a Field>,
Expand Down Expand Up @@ -510,6 +563,7 @@ where
.map(|req| match req {
TypeRequirement::Clone => clone_trait_path(),
TypeRequirement::Sequence => trait_path(&crate_path),
TypeRequirement::Default => default_trait_path(),
})
.map(trait_bound)
.collect(),
Expand Down Expand Up @@ -545,6 +599,19 @@ fn clone_trait_path() -> Path {
}
}

fn default_trait_path() -> Path {
Path {
leading_colon: Some(Default::default()),
segments: [
PathSegment::from(Ident::new("core", Span::call_site())),
Ident::new("default", Span::call_site()).into(),
Ident::new("Default", Span::call_site()).into(),
]
.into_iter()
.collect(),
}
}

fn tuple_type_requirements() -> impl Iterator<Item = TypeRequirements> {
once([TypeRequirement::Sequence].into()).chain(repeat(
[TypeRequirement::Sequence, TypeRequirement::Clone].into(),
Expand Down Expand Up @@ -575,6 +642,7 @@ where
enum TypeRequirement {
Sequence,
Clone,
Default,
}

#[derive(Clone, Debug, Default, PartialEq)]
Expand All @@ -583,6 +651,7 @@ struct TypeRequirements(u8);
impl TypeRequirements {
const SEQUENCE: u8 = 0x1;
const CLONE: u8 = 0x2;
const DEFAULT: u8 = 0x3;

fn new() -> Self {
Self::default()
Expand All @@ -601,6 +670,9 @@ impl TypeRequirements {
} else if n & Self::CLONE != 0 {
n &= !Self::CLONE;
Some(TypeRequirement::Clone)
} else if n & Self::DEFAULT != 0 {
n &= !Self::DEFAULT;
Some(TypeRequirement::Default)
} else {
None
}
Expand All @@ -615,6 +687,7 @@ impl TypeRequirements {
match req {
TypeRequirement::Sequence => Self::SEQUENCE,
TypeRequirement::Clone => Self::CLONE,
TypeRequirement::Default => Self::DEFAULT,
}
}
}
Expand Down Expand Up @@ -656,6 +729,7 @@ impl Direction {
#[derive(Debug)]
enum Error {
UnsupportedUnion,
DefaultForStruct,
}

impl Error {
Expand All @@ -668,6 +742,9 @@ impl Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Error::UnsupportedUnion => f.write_str("Sequence cannot be derived for union types"),
Error::DefaultForStruct => {
f.write_str("Sequence over default values cannot be derived for struct types ")
}
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion enum-iterator/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "enum-iterator"
version = "2.3.0"
version = "2.4.0"
authors = ["Stephane Raux <[email protected]>"]
edition = "2021"
description = "Tools to iterate over all values of a type (e.g. all variants of an enumeration)"
Expand Down
18 changes: 18 additions & 0 deletions enum-iterator/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,24 @@ assert_eq!(first::<Foo>(), Some(Foo { a: false, b: 0 }));
assert_eq!(last::<Foo>(), Some(Foo { a: true, b: 255 }));
```

```rust
use enum_iterator::{Sequence};
use core::default;

#[derive(Debug, PartialEq, Sequence)]
#[enum_iterator(use_default)]
enum Id {
Name { first: String, last: String },
Empty,
Alias(String),
}

let number = Id::Empty;

assert_eq!(number.next(), Some(Id::Alias(default::Default())));
assert_eq!(number.previous(), Some(Id::Name {first: default::Default(), last: default::Default()}));
```

# Rust version
This crate tracks stable Rust. Minor releases may require a newer Rust version. Patch releases
must not require a newer Rust version.
Expand Down
Loading