Skip to content

Commit 2c72044

Browse files
authored
Merge pull request #373 from Areredify/visitor
add `Visitor` and friends
2 parents fbe7874 + 88a6bbb commit 2c72044

File tree

10 files changed

+1331
-199
lines changed

10 files changed

+1331
-199
lines changed

chalk-derive/src/lib.rs

Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
extern crate proc_macro;
22

33
use proc_macro::TokenStream;
4+
use proc_macro2::Span;
45
use quote::{format_ident, quote};
56
use syn::{parse_macro_input, Data, DeriveInput, GenericParam, Ident, TypeParamBound};
67

@@ -339,3 +340,219 @@ fn bounded_by_trait<'p>(param: &'p GenericParam, name: &str) -> Option<&'p Ident
339340
_ => None,
340341
}
341342
}
343+
344+
/// Derives Visit for structs and enums for which one of the following is true:
345+
/// - It has a `#[has_interner(TheInterner)]` attribute
346+
/// - There is a single parameter `T: HasInterner` (does not have to be named `T`)
347+
/// - There is a single parameter `I: Interner` (does not have to be named `I`)
348+
#[proc_macro_derive(Visit, attributes(has_interner))]
349+
pub fn derive_visit(item: TokenStream) -> TokenStream {
350+
let trait_name = Ident::new("Visit", Span::call_site());
351+
let method_name = Ident::new("visit_with", Span::call_site());
352+
derive_any_visit(item, trait_name, method_name)
353+
}
354+
355+
/// Same as Visit, but derives SuperVisit instead
356+
#[proc_macro_derive(SuperVisit, attributes(has_interner))]
357+
pub fn derive_super_visit(item: TokenStream) -> TokenStream {
358+
let trait_name = Ident::new("SuperVisit", Span::call_site());
359+
let method_name = Ident::new("super_visit_with", Span::call_site());
360+
derive_any_visit(item, trait_name, method_name)
361+
}
362+
363+
fn derive_any_visit(item: TokenStream, trait_name: Ident, method_name: Ident) -> TokenStream {
364+
let input = parse_macro_input!(item as DeriveInput);
365+
let (impl_generics, ty_generics, where_clause_ref) = input.generics.split_for_impl();
366+
367+
let type_name = input.ident;
368+
let body = derive_visit_body(&type_name, input.data);
369+
370+
if let Some(attr) = input.attrs.iter().find(|a| a.path.is_ident("has_interner")) {
371+
// Hardcoded interner:
372+
//
373+
// impl Visit<ChalkIr> for Type {
374+
//
375+
// }
376+
let arg = attr
377+
.parse_args::<proc_macro2::TokenStream>()
378+
.expect("Expected has_interner argument");
379+
380+
return TokenStream::from(quote! {
381+
impl #impl_generics #trait_name < #arg > for #type_name #ty_generics #where_clause_ref {
382+
fn #method_name <'i, R: VisitResult>(
383+
&self,
384+
visitor: &mut dyn Visitor < 'i, #arg, Result = R >,
385+
outer_binder: DebruijnIndex,
386+
) -> R
387+
where
388+
I: 'i
389+
{
390+
#body
391+
}
392+
}
393+
});
394+
}
395+
396+
match input.generics.params.len() {
397+
1 => {}
398+
399+
0 => {
400+
panic!("Visit derive requires a single type parameter or a `#[has_interner]` attr");
401+
}
402+
403+
_ => {
404+
panic!("Visit derive only works with a single type parameter");
405+
}
406+
};
407+
408+
let generic_param0 = &input.generics.params[0];
409+
410+
if let Some(param) = has_interner(&generic_param0) {
411+
// HasInterner bound:
412+
//
413+
// Example:
414+
//
415+
// impl<T, _I> Visit<_I> for Binders<T>
416+
// where
417+
// T: HasInterner<Interner = _I>,
418+
// {
419+
// }
420+
421+
let mut impl_generics = input.generics.clone();
422+
impl_generics.params.extend(vec![GenericParam::Type(
423+
syn::parse(quote! { _I: Interner }.into()).unwrap(),
424+
)]);
425+
426+
let mut where_clause = where_clause_ref
427+
.cloned()
428+
.unwrap_or_else(|| syn::parse2(quote![where]).unwrap());
429+
where_clause
430+
.predicates
431+
.push(syn::parse2(quote! { #param: HasInterner<Interner = _I> }).unwrap());
432+
where_clause
433+
.predicates
434+
.push(syn::parse2(quote! { #param: Visit<_I> }).unwrap());
435+
436+
return TokenStream::from(quote! {
437+
impl #impl_generics #trait_name < _I > for #type_name < #param >
438+
#where_clause
439+
{
440+
fn #method_name <'i, R: VisitResult>(
441+
&self,
442+
visitor: &mut dyn Visitor < 'i, _I, Result = R >,
443+
outer_binder: DebruijnIndex,
444+
) -> R
445+
where
446+
_I: 'i
447+
{
448+
#body
449+
}
450+
}
451+
});
452+
}
453+
454+
// Interner bound:
455+
//
456+
// Example:
457+
//
458+
// impl<I> Visit<I> for Foo<I>
459+
// where
460+
// I: Interner,
461+
// {
462+
// }
463+
464+
if let Some(i) = is_interner(&generic_param0) {
465+
let impl_generics = &input.generics;
466+
467+
return TokenStream::from(quote! {
468+
impl #impl_generics #trait_name < #i > for #type_name < #i >
469+
#where_clause_ref
470+
{
471+
fn #method_name <'i, R: VisitResult>(
472+
&self,
473+
visitor: &mut dyn Visitor < 'i, #i, Result = R >,
474+
outer_binder: DebruijnIndex,
475+
) -> R
476+
where
477+
I: 'i
478+
{
479+
#body
480+
}
481+
}
482+
});
483+
}
484+
485+
panic!(
486+
"derive({}) requires a parameter that implements HasInterner or Interner",
487+
trait_name
488+
);
489+
}
490+
491+
/// Generates the body of the Visit impl
492+
fn derive_visit_body(type_name: &Ident, data: Data) -> proc_macro2::TokenStream {
493+
match data {
494+
Data::Struct(s) => {
495+
let fields = s.fields.into_iter().map(|f| {
496+
let name = f.ident.as_ref().expect("Unnamed field in a struct");
497+
quote! {
498+
result = result.combine(self.#name.visit_with(visitor, outer_binder));
499+
if result.return_early() { return result; }
500+
}
501+
});
502+
quote! {
503+
let mut result = R::new();
504+
#(#fields)*
505+
506+
result
507+
}
508+
}
509+
Data::Enum(e) => {
510+
let matches = e.variants.into_iter().map(|v| {
511+
let variant = v.ident;
512+
match &v.fields {
513+
syn::Fields::Named(fields) => {
514+
let fnames: &Vec<_> = &fields.named.iter().map(|f| &f.ident).collect();
515+
quote! {
516+
#type_name :: #variant { #(#fnames),* } => {
517+
let mut result = R::new();
518+
#(
519+
result = result.combine(#fnames.visit_with(visitor, outer_binder));
520+
if result.return_early() { return result; }
521+
)*
522+
result
523+
}
524+
}
525+
}
526+
527+
syn::Fields::Unnamed(_fields) => {
528+
let names: Vec<_> = (0..v.fields.iter().count())
529+
.map(|index| format_ident!("a{}", index))
530+
.collect();
531+
quote! {
532+
#type_name::#variant( #(ref #names),* ) => {
533+
let mut result = R::new();
534+
#(
535+
result = result.combine(#names.visit_with(visitor, outer_binder));
536+
if result.return_early() { return result; }
537+
)*
538+
result
539+
}
540+
}
541+
}
542+
543+
syn::Fields::Unit => {
544+
quote! {
545+
#type_name::#variant => R::new(),
546+
}
547+
}
548+
}
549+
});
550+
quote! {
551+
match *self {
552+
#(#matches)*
553+
}
554+
}
555+
}
556+
Data::Union(..) => panic!("Visit can not be derived for unions"),
557+
}
558+
}

0 commit comments

Comments
 (0)