Skip to content

Commit da974a8

Browse files
committed
Add support for generic enums
1 parent 70406ae commit da974a8

File tree

4 files changed

+124
-16
lines changed

4 files changed

+124
-16
lines changed

macros/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,4 @@ proc-macro = true
2020
[dependencies]
2121
proc-macro2 = { version = "1.0.81", features = ["span-locations"] }
2222
quote = "1.0.36"
23-
syn = "2.0.60"
23+
syn = { version = "2.0.60", features = ["derive", "visit"] }

macros/src/enum_deriver.rs

Lines changed: 71 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
use proc_macro2::TokenStream as TokenStream2;
22
use quote::quote;
3-
use syn::{parse_quote_spanned, DataEnum, DeriveInput, Fields, Type, Variant};
3+
use syn::{parse_quote_spanned, visit::Visit as _, DataEnum, DeriveInput, Fields, Type, Variant};
44

55
use crate::{
66
config_for_enum_with_attrs, config_for_variant, macro_name, position_of_selected_field,
7+
TypeVisitor,
78
};
89

910
pub(crate) struct EnumDeriver {
@@ -43,6 +44,8 @@ impl EnumDeriver {
4344
let outer = enum_ident;
4445
let outer_ty: Type = parse_quote_spanned! { outer.span() => #outer };
4546

47+
let (impl_generics, type_generics, where_clause) = self.input.generics.split_for_impl();
48+
4649
let mut impls: Vec<TokenStream2> = vec![];
4750

4851
for variant in self.variants()? {
@@ -65,6 +68,10 @@ impl EnumDeriver {
6568
let inner_field = fields[selection_index];
6669
let inner_ty = &inner_field.ty;
6770

71+
if self.uses_generic_const_or_type(inner_ty) {
72+
continue;
73+
}
74+
6875
let field_expressions: Vec<_> = fields
6976
.iter()
7077
.enumerate()
@@ -92,7 +99,7 @@ impl EnumDeriver {
9299
};
93100

94101
impls.push(quote! {
95-
impl ::core::convert::From<#inner_ty> for #outer_ty {
102+
impl #impl_generics ::core::convert::From<#inner_ty> for #outer_ty #type_generics #where_clause {
96103
fn from(inner: #inner_ty) -> Self {
97104
#expression
98105
}
@@ -116,6 +123,8 @@ impl EnumDeriver {
116123
let outer = enum_ident;
117124
let outer_ty: Type = parse_quote_spanned! { outer.span() => #outer };
118125

126+
let (impl_generics, type_generics, where_clause) = self.input.generics.split_for_impl();
127+
119128
let mut impls: Vec<TokenStream2> = vec![];
120129

121130
for variant in self.variants()? {
@@ -139,6 +148,10 @@ impl EnumDeriver {
139148
let inner_ident = inner_field.ident.as_ref();
140149
let inner_ty = &inner_field.ty;
141150

151+
if self.uses_generic_const_or_type(inner_ty) {
152+
continue;
153+
}
154+
142155
let pattern = match &variant.fields {
143156
Fields::Named(_) => {
144157
let field = inner_ident;
@@ -154,10 +167,10 @@ impl EnumDeriver {
154167
};
155168

156169
impls.push(quote! {
157-
impl ::core::convert::TryFrom<#outer_ty> for #inner_ty {
158-
type Error = #outer_ty;
170+
impl #impl_generics ::core::convert::TryFrom<#outer_ty #type_generics> for #inner_ty #where_clause {
171+
type Error = #outer_ty #type_generics;
159172

160-
fn try_from(outer: #outer_ty) -> Result<Self, Self::Error> {
173+
fn try_from(outer: #outer_ty #type_generics) -> Result<Self, Self::Error> {
161174
match outer {
162175
#pattern => Ok(inner),
163176
err => Err(err)
@@ -183,6 +196,8 @@ impl EnumDeriver {
183196
let outer = enum_ident;
184197
let outer_ty: Type = parse_quote_spanned! { outer.span() => #outer };
185198

199+
let (impl_generics, type_generics, where_clause) = self.input.generics.split_for_impl();
200+
186201
let mut impls: Vec<TokenStream2> = vec![];
187202

188203
for variant in self.variants()? {
@@ -205,6 +220,10 @@ impl EnumDeriver {
205220
let inner_field = fields[selection_index];
206221
let inner_ty = &inner_field.ty;
207222

223+
if self.uses_generic_const_or_type(inner_ty) {
224+
continue;
225+
}
226+
208227
let field_expressions: Vec<_> = fields
209228
.iter()
210229
.enumerate()
@@ -232,7 +251,7 @@ impl EnumDeriver {
232251
};
233252

234253
impls.push(quote! {
235-
impl ::enumcapsulate::FromVariant<#inner_ty> for #outer_ty {
254+
impl #impl_generics ::enumcapsulate::FromVariant<#inner_ty> for #outer_ty #type_generics #where_clause {
236255
fn from_variant(inner: #inner_ty) -> Self {
237256
#expression
238257
}
@@ -256,6 +275,8 @@ impl EnumDeriver {
256275
let outer = enum_ident;
257276
let outer_ty: Type = parse_quote_spanned! { outer.span() => #outer };
258277

278+
let (impl_generics, type_generics, where_clause) = self.input.generics.split_for_impl();
279+
259280
let mut impls: Vec<TokenStream2> = vec![];
260281

261282
for variant in self.variants()? {
@@ -279,6 +300,10 @@ impl EnumDeriver {
279300
let inner_ident = inner_field.ident.as_ref();
280301
let inner_ty = &inner_field.ty;
281302

303+
if self.uses_generic_const_or_type(inner_ty) {
304+
continue;
305+
}
306+
282307
let pattern = match &variant.fields {
283308
Fields::Named(_) => {
284309
let field = inner_ident;
@@ -293,11 +318,13 @@ impl EnumDeriver {
293318
Fields::Unit => continue,
294319
};
295320

321+
let where_clause = match where_clause {
322+
Some(where_clause) => quote! { #where_clause #inner_ty: Clone },
323+
None => quote! { where #inner_ty: Clone },
324+
};
325+
296326
impls.push(quote! {
297-
impl ::enumcapsulate::AsVariant<#inner_ty> for #outer_ty
298-
where
299-
#inner_ty: Clone
300-
{
327+
impl #impl_generics ::enumcapsulate::AsVariant<#inner_ty> for #outer_ty #type_generics #where_clause {
301328
fn as_variant(&self) -> Option<#inner_ty> {
302329
match self {
303330
#pattern => Some(inner.clone()),
@@ -324,6 +351,8 @@ impl EnumDeriver {
324351
let outer = enum_ident;
325352
let outer_ty: Type = parse_quote_spanned! { outer.span() => #outer };
326353

354+
let (impl_generics, type_generics, where_clause) = self.input.generics.split_for_impl();
355+
327356
let mut impls: Vec<TokenStream2> = vec![];
328357

329358
for variant in self.variants()? {
@@ -347,6 +376,9 @@ impl EnumDeriver {
347376
let inner_ident = inner_field.ident.as_ref();
348377
let inner_ty = &inner_field.ty;
349378

379+
if self.uses_generic_const_or_type(inner_ty) {
380+
continue;
381+
}
350382
let pattern = match &variant.fields {
351383
Fields::Named(_) => {
352384
let field = inner_ident;
@@ -362,7 +394,7 @@ impl EnumDeriver {
362394
};
363395

364396
impls.push(quote! {
365-
impl ::enumcapsulate::AsVariantRef<#inner_ty> for #outer_ty {
397+
impl #impl_generics ::enumcapsulate::AsVariantRef<#inner_ty> for #outer_ty #type_generics #where_clause {
366398
fn as_variant_ref(&self) -> Option<&#inner_ty> {
367399
match self {
368400
#pattern => Some(inner),
@@ -389,6 +421,8 @@ impl EnumDeriver {
389421
let outer = enum_ident;
390422
let outer_ty: Type = parse_quote_spanned! { outer.span() => #outer };
391423

424+
let (impl_generics, type_generics, where_clause) = self.input.generics.split_for_impl();
425+
392426
let mut impls: Vec<TokenStream2> = vec![];
393427

394428
for variant in self.variants()? {
@@ -412,6 +446,10 @@ impl EnumDeriver {
412446
let inner_ident = inner_field.ident.as_ref();
413447
let inner_ty = &inner_field.ty;
414448

449+
if self.uses_generic_const_or_type(inner_ty) {
450+
continue;
451+
}
452+
415453
let pattern = match &variant.fields {
416454
Fields::Named(_) => {
417455
let field = inner_ident;
@@ -427,7 +465,7 @@ impl EnumDeriver {
427465
};
428466

429467
impls.push(quote! {
430-
impl ::enumcapsulate::AsVariantMut<#inner_ty> for #outer_ty {
468+
impl #impl_generics ::enumcapsulate::AsVariantMut<#inner_ty> for #outer_ty #type_generics #where_clause {
431469
fn as_variant_mut(&mut self) -> Option<&mut #inner_ty> {
432470
match self {
433471
#pattern => Some(inner),
@@ -454,6 +492,8 @@ impl EnumDeriver {
454492
let outer = enum_ident;
455493
let outer_ty: Type = parse_quote_spanned! { outer.span() => #outer };
456494

495+
let (impl_generics, type_generics, where_clause) = self.input.generics.split_for_impl();
496+
457497
let mut impls: Vec<TokenStream2> = vec![];
458498

459499
for variant in self.variants()? {
@@ -477,6 +517,10 @@ impl EnumDeriver {
477517
let inner_ident = inner_field.ident.as_ref();
478518
let inner_ty = &inner_field.ty;
479519

520+
if self.uses_generic_const_or_type(inner_ty) {
521+
continue;
522+
}
523+
480524
let pattern = match &variant.fields {
481525
Fields::Named(_) => {
482526
let field = inner_ident;
@@ -492,7 +536,7 @@ impl EnumDeriver {
492536
};
493537

494538
impls.push(quote! {
495-
impl ::enumcapsulate::IntoVariant<#inner_ty> for #outer_ty {
539+
impl #impl_generics ::enumcapsulate::IntoVariant<#inner_ty> for #outer_ty #type_generics #where_clause {
496540
fn into_variant(self) -> Result<#inner_ty, Self> {
497541
match self {
498542
#pattern => Ok(inner),
@@ -523,8 +567,10 @@ impl EnumDeriver {
523567
let outer = enum_ident;
524568
let outer_ty: Type = parse_quote_spanned! { outer.span() => #outer };
525569

570+
let (impl_generics, type_generics, where_clause) = self.input.generics.split_for_impl();
571+
526572
let tokens = quote! {
527-
impl ::enumcapsulate::VariantDowncast for #outer_ty {}
573+
impl #impl_generics ::enumcapsulate::VariantDowncast for #outer_ty #type_generics #where_clause {}
528574
};
529575

530576
Ok(tokens)
@@ -545,6 +591,8 @@ impl EnumDeriver {
545591
let outer = enum_ident;
546592
let outer_ty: Type = parse_quote_spanned! { outer.span() => #outer };
547593

594+
let (impl_generics, type_generics, where_clause) = self.input.generics.split_for_impl();
595+
548596
let variants = self.variants()?;
549597

550598
let discriminant_ident = quote::format_ident!("{outer}Discriminant");
@@ -592,7 +640,7 @@ impl EnumDeriver {
592640
Ok(quote! {
593641
#discriminant_enum
594642

595-
impl ::enumcapsulate::VariantDiscriminant for #outer_ty {
643+
impl #impl_generics ::enumcapsulate::VariantDiscriminant for #outer_ty #type_generics #where_clause {
596644
type Discriminant = #discriminant_ident;
597645

598646
fn variant_discriminant(&self) -> Self::Discriminant {
@@ -604,4 +652,12 @@ impl EnumDeriver {
604652
}
605653
})
606654
}
655+
656+
fn uses_generic_const_or_type(&self, ty: &syn::Type) -> bool {
657+
let mut visitor = TypeVisitor::new(&self.input.generics);
658+
659+
visitor.visit_type(ty);
660+
661+
visitor.uses_const_or_type_param()
662+
}
607663
}

macros/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@ use crate::utils::tokenstream;
55

66
mod config;
77
mod enum_deriver;
8+
mod type_visitor;
89
mod utils;
910

1011
use config::*;
1112
use enum_deriver::*;
13+
use type_visitor::*;
1214
use utils::*;
1315

1416
/// Derive macro generating an impl of the trait `From<T>`.

macros/src/type_visitor.rs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
use std::collections::HashSet;
2+
3+
use syn::visit::Visit;
4+
5+
pub struct TypeVisitor<'ast> {
6+
const_param_idents: HashSet<&'ast syn::Ident>,
7+
type_param_idents: HashSet<&'ast syn::Ident>,
8+
9+
uses_const_param: bool,
10+
uses_type_param: bool,
11+
}
12+
13+
impl<'ast> TypeVisitor<'ast> {
14+
pub fn new(generics: &'ast syn::Generics) -> Self {
15+
Self {
16+
const_param_idents: generics.const_params().map(|param| &param.ident).collect(),
17+
type_param_idents: generics.type_params().map(|param| &param.ident).collect(),
18+
uses_const_param: false,
19+
uses_type_param: false,
20+
}
21+
}
22+
23+
pub fn uses_const_or_type_param(self) -> bool {
24+
self.uses_const_param || self.uses_type_param
25+
}
26+
}
27+
28+
impl<'ast> Visit<'ast> for TypeVisitor<'ast> {
29+
fn visit_type_path(&mut self, node: &'ast syn::TypePath) {
30+
if node.qself.is_none() {
31+
let path_segments = &node.path.segments;
32+
let first_capitalized_segment = path_segments.iter().find(|&segment| {
33+
let ident_name = segment.ident.to_string();
34+
let first_char: char = ident_name.chars().next().unwrap();
35+
first_char.is_uppercase()
36+
});
37+
38+
if let Some(path_segment) = first_capitalized_segment {
39+
let ident = &path_segment.ident;
40+
41+
if self.type_param_idents.contains(ident) {
42+
self.uses_type_param = true;
43+
} else if self.const_param_idents.contains(ident) {
44+
self.uses_const_param = true;
45+
}
46+
}
47+
}
48+
syn::visit::visit_type_path(self, node);
49+
}
50+
}

0 commit comments

Comments
 (0)