|
1 | 1 | extern crate proc_macro;
|
2 | 2 |
|
3 | 3 | use proc_macro::TokenStream;
|
| 4 | +use proc_macro2::Span; |
4 | 5 | use quote::{format_ident, quote};
|
5 | 6 | use syn::{parse_macro_input, Data, DeriveInput, GenericParam, Ident, TypeParamBound};
|
6 | 7 |
|
@@ -339,3 +340,219 @@ fn bounded_by_trait<'p>(param: &'p GenericParam, name: &str) -> Option<&'p Ident
|
339 | 340 | _ => None,
|
340 | 341 | }
|
341 | 342 | }
|
| 343 | + |
| 344 | +/// Derives Visit for structs and enums for which one of the following is true: |
| 345 | +/// - It has a `#[has_interner(TheInterner)]` attribute |
| 346 | +/// - There is a single parameter `T: HasInterner` (does not have to be named `T`) |
| 347 | +/// - There is a single parameter `I: Interner` (does not have to be named `I`) |
| 348 | +#[proc_macro_derive(Visit, attributes(has_interner))] |
| 349 | +pub fn derive_visit(item: TokenStream) -> TokenStream { |
| 350 | + let trait_name = Ident::new("Visit", Span::call_site()); |
| 351 | + let method_name = Ident::new("visit_with", Span::call_site()); |
| 352 | + derive_any_visit(item, trait_name, method_name) |
| 353 | +} |
| 354 | + |
| 355 | +/// Same as Visit, but derives SuperVisit instead |
| 356 | +#[proc_macro_derive(SuperVisit, attributes(has_interner))] |
| 357 | +pub fn derive_super_visit(item: TokenStream) -> TokenStream { |
| 358 | + let trait_name = Ident::new("SuperVisit", Span::call_site()); |
| 359 | + let method_name = Ident::new("super_visit_with", Span::call_site()); |
| 360 | + derive_any_visit(item, trait_name, method_name) |
| 361 | +} |
| 362 | + |
| 363 | +fn derive_any_visit(item: TokenStream, trait_name: Ident, method_name: Ident) -> TokenStream { |
| 364 | + let input = parse_macro_input!(item as DeriveInput); |
| 365 | + let (impl_generics, ty_generics, where_clause_ref) = input.generics.split_for_impl(); |
| 366 | + |
| 367 | + let type_name = input.ident; |
| 368 | + let body = derive_visit_body(&type_name, input.data); |
| 369 | + |
| 370 | + if let Some(attr) = input.attrs.iter().find(|a| a.path.is_ident("has_interner")) { |
| 371 | + // Hardcoded interner: |
| 372 | + // |
| 373 | + // impl Visit<ChalkIr> for Type { |
| 374 | + // |
| 375 | + // } |
| 376 | + let arg = attr |
| 377 | + .parse_args::<proc_macro2::TokenStream>() |
| 378 | + .expect("Expected has_interner argument"); |
| 379 | + |
| 380 | + return TokenStream::from(quote! { |
| 381 | + impl #impl_generics #trait_name < #arg > for #type_name #ty_generics #where_clause_ref { |
| 382 | + fn #method_name <'i, R: VisitResult>( |
| 383 | + &self, |
| 384 | + visitor: &mut dyn Visitor < 'i, #arg, Result = R >, |
| 385 | + outer_binder: DebruijnIndex, |
| 386 | + ) -> R |
| 387 | + where |
| 388 | + I: 'i |
| 389 | + { |
| 390 | + #body |
| 391 | + } |
| 392 | + } |
| 393 | + }); |
| 394 | + } |
| 395 | + |
| 396 | + match input.generics.params.len() { |
| 397 | + 1 => {} |
| 398 | + |
| 399 | + 0 => { |
| 400 | + panic!("Visit derive requires a single type parameter or a `#[has_interner]` attr"); |
| 401 | + } |
| 402 | + |
| 403 | + _ => { |
| 404 | + panic!("Visit derive only works with a single type parameter"); |
| 405 | + } |
| 406 | + }; |
| 407 | + |
| 408 | + let generic_param0 = &input.generics.params[0]; |
| 409 | + |
| 410 | + if let Some(param) = has_interner(&generic_param0) { |
| 411 | + // HasInterner bound: |
| 412 | + // |
| 413 | + // Example: |
| 414 | + // |
| 415 | + // impl<T, _I> Visit<_I> for Binders<T> |
| 416 | + // where |
| 417 | + // T: HasInterner<Interner = _I>, |
| 418 | + // { |
| 419 | + // } |
| 420 | + |
| 421 | + let mut impl_generics = input.generics.clone(); |
| 422 | + impl_generics.params.extend(vec![GenericParam::Type( |
| 423 | + syn::parse(quote! { _I: Interner }.into()).unwrap(), |
| 424 | + )]); |
| 425 | + |
| 426 | + let mut where_clause = where_clause_ref |
| 427 | + .cloned() |
| 428 | + .unwrap_or_else(|| syn::parse2(quote![where]).unwrap()); |
| 429 | + where_clause |
| 430 | + .predicates |
| 431 | + .push(syn::parse2(quote! { #param: HasInterner<Interner = _I> }).unwrap()); |
| 432 | + where_clause |
| 433 | + .predicates |
| 434 | + .push(syn::parse2(quote! { #param: Visit<_I> }).unwrap()); |
| 435 | + |
| 436 | + return TokenStream::from(quote! { |
| 437 | + impl #impl_generics #trait_name < _I > for #type_name < #param > |
| 438 | + #where_clause |
| 439 | + { |
| 440 | + fn #method_name <'i, R: VisitResult>( |
| 441 | + &self, |
| 442 | + visitor: &mut dyn Visitor < 'i, _I, Result = R >, |
| 443 | + outer_binder: DebruijnIndex, |
| 444 | + ) -> R |
| 445 | + where |
| 446 | + _I: 'i |
| 447 | + { |
| 448 | + #body |
| 449 | + } |
| 450 | + } |
| 451 | + }); |
| 452 | + } |
| 453 | + |
| 454 | + // Interner bound: |
| 455 | + // |
| 456 | + // Example: |
| 457 | + // |
| 458 | + // impl<I> Visit<I> for Foo<I> |
| 459 | + // where |
| 460 | + // I: Interner, |
| 461 | + // { |
| 462 | + // } |
| 463 | + |
| 464 | + if let Some(i) = is_interner(&generic_param0) { |
| 465 | + let impl_generics = &input.generics; |
| 466 | + |
| 467 | + return TokenStream::from(quote! { |
| 468 | + impl #impl_generics #trait_name < #i > for #type_name < #i > |
| 469 | + #where_clause_ref |
| 470 | + { |
| 471 | + fn #method_name <'i, R: VisitResult>( |
| 472 | + &self, |
| 473 | + visitor: &mut dyn Visitor < 'i, #i, Result = R >, |
| 474 | + outer_binder: DebruijnIndex, |
| 475 | + ) -> R |
| 476 | + where |
| 477 | + I: 'i |
| 478 | + { |
| 479 | + #body |
| 480 | + } |
| 481 | + } |
| 482 | + }); |
| 483 | + } |
| 484 | + |
| 485 | + panic!( |
| 486 | + "derive({}) requires a parameter that implements HasInterner or Interner", |
| 487 | + trait_name |
| 488 | + ); |
| 489 | +} |
| 490 | + |
| 491 | +/// Generates the body of the Visit impl |
| 492 | +fn derive_visit_body(type_name: &Ident, data: Data) -> proc_macro2::TokenStream { |
| 493 | + match data { |
| 494 | + Data::Struct(s) => { |
| 495 | + let fields = s.fields.into_iter().map(|f| { |
| 496 | + let name = f.ident.as_ref().expect("Unnamed field in a struct"); |
| 497 | + quote! { |
| 498 | + result = result.combine(self.#name.visit_with(visitor, outer_binder)); |
| 499 | + if result.return_early() { return result; } |
| 500 | + } |
| 501 | + }); |
| 502 | + quote! { |
| 503 | + let mut result = R::new(); |
| 504 | + #(#fields)* |
| 505 | + |
| 506 | + result |
| 507 | + } |
| 508 | + } |
| 509 | + Data::Enum(e) => { |
| 510 | + let matches = e.variants.into_iter().map(|v| { |
| 511 | + let variant = v.ident; |
| 512 | + match &v.fields { |
| 513 | + syn::Fields::Named(fields) => { |
| 514 | + let fnames: &Vec<_> = &fields.named.iter().map(|f| &f.ident).collect(); |
| 515 | + quote! { |
| 516 | + #type_name :: #variant { #(#fnames),* } => { |
| 517 | + let mut result = R::new(); |
| 518 | + #( |
| 519 | + result = result.combine(#fnames.visit_with(visitor, outer_binder)); |
| 520 | + if result.return_early() { return result; } |
| 521 | + )* |
| 522 | + result |
| 523 | + } |
| 524 | + } |
| 525 | + } |
| 526 | + |
| 527 | + syn::Fields::Unnamed(_fields) => { |
| 528 | + let names: Vec<_> = (0..v.fields.iter().count()) |
| 529 | + .map(|index| format_ident!("a{}", index)) |
| 530 | + .collect(); |
| 531 | + quote! { |
| 532 | + #type_name::#variant( #(ref #names),* ) => { |
| 533 | + let mut result = R::new(); |
| 534 | + #( |
| 535 | + result = result.combine(#names.visit_with(visitor, outer_binder)); |
| 536 | + if result.return_early() { return result; } |
| 537 | + )* |
| 538 | + result |
| 539 | + } |
| 540 | + } |
| 541 | + } |
| 542 | + |
| 543 | + syn::Fields::Unit => { |
| 544 | + quote! { |
| 545 | + #type_name::#variant => R::new(), |
| 546 | + } |
| 547 | + } |
| 548 | + } |
| 549 | + }); |
| 550 | + quote! { |
| 551 | + match *self { |
| 552 | + #(#matches)* |
| 553 | + } |
| 554 | + } |
| 555 | + } |
| 556 | + Data::Union(..) => panic!("Visit can not be derived for unions"), |
| 557 | + } |
| 558 | +} |
0 commit comments