Skip to content

Commit 76e9804

Browse files
committed
Enable skipping of derived traversals, with reason
1 parent aa6761c commit 76e9804

File tree

3 files changed

+464
-88
lines changed

3 files changed

+464
-88
lines changed

compiler/rustc_macros/src/lib.rs

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,14 +74,24 @@ decl_derive!([TyEncodable] => serialize::type_encodable_derive);
7474
decl_derive!([MetadataDecodable] => serialize::meta_decodable_derive);
7575
decl_derive!([MetadataEncodable] => serialize::meta_encodable_derive);
7676
decl_derive!(
77-
[TypeFoldable] =>
77+
[TypeFoldable, attributes(skip_traversal)] =>
7878
/// Derives `TypeFoldable` for the annotated `struct` or `enum` (`union` is not supported).
7979
///
8080
/// Folds will produce a value of the same struct or enum variant as the input, with each field
8181
/// respectively folded (in definition order) using the `TypeFoldable` implementation for its
8282
/// type if it has one. Fields of non-generic types that do not contain anything that may be of
8383
/// interest to folders will automatically be left unchanged whether the type implements
84-
/// `TypeFoldable` or not.
84+
/// `TypeFoldable` or not; the same behaviour can be achieved for fields of generic types by
85+
/// applying `#[skip_traversal(because_boring)]` to the field definition (or even to a variant
86+
/// definition if it should apply to all fields therein), but the derived implementation will
87+
/// only be applicable to concrete types where such annotated fields do not contain anything
88+
/// that may be of interest to folders (thus preventing fields from being left unchanged
89+
/// erroneously).
90+
///
91+
/// In some rare situations, it may be desirable to skip folding of an item or field (or
92+
/// variant) that might otherwise be of interest to folders: **this is dangerous and could lead
93+
/// to miscompilation if user expectations are not met!** Nevertheless, such can be achieved
94+
/// via a `#[skip_traversal(despite_potential_miscompilation_because = "<reason>"]` attribute.
8595
///
8696
/// If the annotated type has a `'tcx` lifetime parameter, then that will be used as the
8797
/// lifetime for the type context/interner; otherwise the lifetime of the type context/interner
@@ -97,13 +107,23 @@ decl_derive!(
97107
traversable::traversable_derive::<traversable::Foldable>
98108
);
99109
decl_derive!(
100-
[TypeVisitable] =>
110+
[TypeVisitable, attributes(skip_traversal)] =>
101111
/// Derives `TypeVisitable` for the annotated `struct` or `enum` (`union` is not supported).
102112
///
103113
/// Each field of the struct or enum variant will be visited (in definition order) using the
104114
/// `TypeVisitable` implementation for its type if it has one. Fields of non-generic types that
105115
/// do not contain anything that may be of interest to visitors will automatically be skipped
106-
/// whether the type implements `TypeVisitable` or not.
116+
/// whether the type implements `TypeVisitable` or not; the same behaviour can be achieved for
117+
/// fields of generic types by applying `#[skip_traversal(because_boring)]` to the field
118+
/// definition (or even to a variant definition if it should apply to all fields therein), but
119+
/// the derived implementation will only be applicable to concrete types where such annotated
120+
/// fields do not contain anything that may be of interest to visitors (thus preventing fields
121+
/// from being so skipped erroneously).
122+
///
123+
/// In some rare situations, it may be desirable to skip visiting of an item or field (or
124+
/// variant) that might otherwise be of interest to visitors: **this is dangerous and could lead
125+
/// to miscompilation if user expectations are not met!** Nevertheless, such can be achieved
126+
/// via a `#[skip_traversal(despite_potential_miscompilation_because = "<reason>"]` attribute.
107127
///
108128
/// If the annotated type has a `'tcx` lifetime parameter, then that will be used as the
109129
/// lifetime for the type context/interner; otherwise the lifetime of the type context/interner

compiler/rustc_macros/src/traversable.rs

Lines changed: 191 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,15 @@
11
use proc_macro2::{Ident, Span, TokenStream};
22
use quote::{quote, ToTokens};
3-
use std::collections::HashSet;
4-
use syn::{parse_quote, visit, Field, Generics, Lifetime};
3+
use std::{
4+
collections::{hash_map::Entry, HashMap},
5+
marker::PhantomData,
6+
};
7+
use syn::{
8+
parse::{Error, Parse, ParseStream},
9+
parse_quote,
10+
spanned::Spanned,
11+
visit, Attribute, Field, Generics, Lifetime, LitStr, Token,
12+
};
513

614
#[cfg(test)]
715
mod tests;
@@ -16,6 +24,50 @@ fn gen_param(suffix: impl ToString, existing: &Generics) -> Ident {
1624
Ident::new(&suffix, Span::call_site())
1725
}
1826

27+
mod kw {
28+
syn::custom_keyword!(because_boring); // only applicable on internable fields
29+
syn::custom_keyword!(despite_potential_miscompilation_because); // always applicable on internable types and fields
30+
}
31+
32+
trait SkipReasonKeyword: Default + Parse + ToTokens {
33+
const HAS_EXPLANATION: bool;
34+
}
35+
impl SkipReasonKeyword for kw::because_boring {
36+
const HAS_EXPLANATION: bool = false;
37+
}
38+
impl SkipReasonKeyword for kw::despite_potential_miscompilation_because {
39+
const HAS_EXPLANATION: bool = true;
40+
}
41+
42+
struct SkipReason<K>(PhantomData<K>);
43+
44+
impl<K: SkipReasonKeyword> Parse for SkipReason<K> {
45+
fn parse(input: ParseStream<'_>) -> Result<Self, Error> {
46+
input.parse::<K>()?;
47+
if K::HAS_EXPLANATION {
48+
input.parse::<Token![=]>()?;
49+
let reason = input.parse::<LitStr>()?;
50+
if reason.value().trim().is_empty() {
51+
return Err(Error::new_spanned(reason, "skip reason must be a non-empty string"));
52+
}
53+
}
54+
Ok(Self(PhantomData))
55+
}
56+
}
57+
58+
#[derive(Clone, Copy)]
59+
enum WhenToSkip {
60+
Forced,
61+
Never,
62+
Always(Span),
63+
}
64+
65+
impl WhenToSkip {
66+
fn is_skipped(&self) -> bool {
67+
!matches!(self, WhenToSkip::Never)
68+
}
69+
}
70+
1971
#[derive(Clone, Copy, PartialEq)]
2072
enum Type {
2173
/// Describes a type that is not parameterised by the interner, and therefore cannot
@@ -31,6 +83,36 @@ enum Type {
3183
}
3284
use Type::*;
3385

86+
impl WhenToSkip {
87+
fn find<const IS_TYPE: bool>(attrs: &[Attribute], ty: Type) -> Result<WhenToSkip, Error> {
88+
let mut iter = attrs.iter().filter(|&attr| attr.path().is_ident("skip_traversal"));
89+
let Some(attr) = iter.next() else {
90+
return Ok(Self::Never);
91+
};
92+
if let Some(next) = iter.next() {
93+
return Err(Error::new_spanned(
94+
next,
95+
"at most one skip_traversal attribute is supported",
96+
));
97+
}
98+
99+
attr.meta
100+
.require_list()
101+
.and_then(|list| match (IS_TYPE, ty) {
102+
(true, Boring) => Err(Error::new_spanned(attr, "boring types are always skipped, so this attribute is superfluous")),
103+
(true, _) | (false, NotGeneric) => list.parse_args::<SkipReason<kw::despite_potential_miscompilation_because>>().and(Ok(Self::Forced)),
104+
(false, Generic) => list.parse_args::<SkipReason<kw::because_boring>>().and_then(|_| Ok(Self::Always(attr.span())))
105+
.or_else(|_| list.parse_args::<SkipReason<kw::despite_potential_miscompilation_because>>().and(Ok(Self::Forced)))
106+
.or(Err(Error::new_spanned(attr, "\
107+
Justification must be provided for skipping this potentially interesting field, by specifying EITHER:\n\
108+
`because_boring` if concrete instances never actually contain anything of interest (enforced by the compiler); OR\n\
109+
`despite_potential_miscompilation_because = \"<reason>\"` if this field should always be skipped regardless\
110+
"))),
111+
(false, Boring) => Err(Error::new_spanned(attr, "boring fields are always skipped, so this attribute is superfluous")),
112+
})
113+
}
114+
}
115+
34116
pub struct Interner<'a>(Option<&'a Lifetime>);
35117

36118
impl<'a> Interner<'a> {
@@ -67,13 +149,13 @@ pub trait Traversable {
67149
/// Any supertraits that this trait is required to implement.
68150
fn supertraits(interner: &Interner<'_>) -> TokenStream;
69151

70-
/// A traversal of this trait upon the `bind` expression.
71-
fn traverse(bind: TokenStream) -> TokenStream;
152+
/// A (`noop`) traversal of this trait upon the `bind` expression.
153+
fn traverse(bind: TokenStream, noop: bool) -> TokenStream;
72154

73155
/// A `match` arm for `variant`, where `f` generates the tokens for each binding.
74156
fn arm(
75157
variant: &synstructure::VariantInfo<'_>,
76-
f: impl FnMut(&synstructure::BindingInfo<'_>) -> TokenStream,
158+
f: impl FnMut(&synstructure::BindingInfo<'_>) -> Result<TokenStream, Error>,
77159
) -> TokenStream;
78160

79161
/// The body of an implementation given the `interner`, `traverser` and match expression `body`.
@@ -91,15 +173,19 @@ impl Traversable for Foldable {
91173
fn supertraits(interner: &Interner<'_>) -> TokenStream {
92174
Visitable::traversable(interner)
93175
}
94-
fn traverse(bind: TokenStream) -> TokenStream {
95-
quote! { ::rustc_middle::ty::noop_traversal_if_boring!(#bind.try_fold_with(folder))? }
176+
fn traverse(bind: TokenStream, noop: bool) -> TokenStream {
177+
if noop {
178+
bind
179+
} else {
180+
quote! { ::rustc_middle::ty::noop_traversal_if_boring!(#bind.try_fold_with(folder))? }
181+
}
96182
}
97183
fn arm(
98184
variant: &synstructure::VariantInfo<'_>,
99-
mut f: impl FnMut(&synstructure::BindingInfo<'_>) -> TokenStream,
185+
mut f: impl FnMut(&synstructure::BindingInfo<'_>) -> Result<TokenStream, Error>,
100186
) -> TokenStream {
101187
let bindings = variant.bindings();
102-
variant.construct(|_, index| f(&bindings[index]))
188+
variant.construct(|_, index| f(&bindings[index]).unwrap_or_else(Error::into_compile_error))
103189
}
104190
fn impl_body(
105191
interner: Interner<'_>,
@@ -124,14 +210,23 @@ impl Traversable for Visitable {
124210
fn supertraits(_: &Interner<'_>) -> TokenStream {
125211
quote! { ::core::clone::Clone + ::core::fmt::Debug }
126212
}
127-
fn traverse(bind: TokenStream) -> TokenStream {
128-
quote! { ::rustc_middle::ty::noop_traversal_if_boring!(#bind.visit_with(visitor))?; }
213+
fn traverse(bind: TokenStream, noop: bool) -> TokenStream {
214+
if noop {
215+
quote! {}
216+
} else {
217+
quote! { ::rustc_middle::ty::noop_traversal_if_boring!(#bind.visit_with(visitor))?; }
218+
}
129219
}
130220
fn arm(
131221
variant: &synstructure::VariantInfo<'_>,
132-
f: impl FnMut(&synstructure::BindingInfo<'_>) -> TokenStream,
222+
f: impl FnMut(&synstructure::BindingInfo<'_>) -> Result<TokenStream, Error>,
133223
) -> TokenStream {
134-
variant.bindings().iter().map(f).collect()
224+
variant
225+
.bindings()
226+
.iter()
227+
.map(f)
228+
.collect::<Result<_, _>>()
229+
.unwrap_or_else(Error::into_compile_error)
135230
}
136231
fn impl_body(
137232
interner: Interner<'_>,
@@ -158,12 +253,14 @@ impl Interner<'_> {
158253
referenced_ty_params: &[&Ident],
159254
fields: impl IntoIterator<Item = &'a Field>,
160255
) -> Type {
256+
use visit::Visit;
257+
161258
struct Visitor<'a> {
162259
interner: &'a Lifetime,
163260
contains_interner: bool,
164261
}
165262

166-
impl visit::Visit<'_> for Visitor<'_> {
263+
impl Visit<'_> for Visitor<'_> {
167264
fn visit_lifetime(&mut self, i: &Lifetime) {
168265
if i == self.interner {
169266
self.contains_interner = true;
@@ -180,7 +277,7 @@ impl Interner<'_> {
180277
Some(interner)
181278
if fields.into_iter().any(|field| {
182279
let mut visitor = Visitor { interner, contains_interner: false };
183-
visit::visit_type(&mut visitor, &field.ty);
280+
visitor.visit_type(&field.ty);
184281
visitor.contains_interner
185282
}) =>
186283
{
@@ -194,7 +291,11 @@ impl Interner<'_> {
194291

195292
pub fn traversable_derive<T: Traversable>(
196293
mut structure: synstructure::Structure<'_>,
197-
) -> TokenStream {
294+
) -> Result<TokenStream, Error> {
295+
use WhenToSkip::*;
296+
297+
let skip_traversal = quote! { ::rustc_middle::ty::BoringTraversable };
298+
198299
let ast = structure.ast();
199300

200301
let interner = Interner::resolve(&ast.generics);
@@ -205,41 +306,88 @@ pub fn traversable_derive<T: Traversable>(
205306
structure.add_bounds(synstructure::AddBounds::None);
206307
structure.bind_with(|_| synstructure::BindStyle::Move);
207308

208-
if interner.0.is_none() {
309+
let not_generic = if interner.0.is_none() {
209310
structure.add_impl_generic(parse_quote! { 'tcx });
210-
}
311+
Boring
312+
} else {
313+
NotGeneric
314+
};
211315

212316
// If our derived implementation will be generic over the traversable type, then we must
213317
// constrain it to only those generic combinations that satisfy the traversable trait's
214318
// supertraits.
215-
if ast.generics.type_params().next().is_some() {
319+
let ty = if ast.generics.type_params().next().is_some() {
216320
let supertraits = T::supertraits(&interner);
217321
structure.add_where_predicate(parse_quote! { Self: #supertraits });
218-
}
322+
Generic
323+
} else {
324+
not_generic
325+
};
219326

220-
// We add predicates to each generic field type, rather than to our generic type parameters.
221-
// This results in a "perfect derive", but it can result in trait solver cycles if any type
222-
// parameters are involved in recursive type definitions; fortunately that is not the case (yet).
223-
let mut predicates = HashSet::new();
224-
let arms = structure.each_variant(|variant| {
225-
let variant_ty = interner.type_of(&variant.referenced_ty_params(), variant.ast().fields);
226-
T::arm(variant, |bind| {
227-
if variant_ty == Generic {
327+
let when_to_skip = WhenToSkip::find::<true>(&ast.attrs, ty)?;
328+
let body = if when_to_skip.is_skipped() {
329+
if let Always(_) = when_to_skip {
330+
structure.add_where_predicate(parse_quote! { Self: #skip_traversal });
331+
}
332+
T::traverse(quote! { self }, true)
333+
} else {
334+
// We add predicates to each generic field type, rather than to our generic type parameters.
335+
// This results in a "perfect derive" that avoids having to propagate `#[skip_traversal]` annotations
336+
// into wrapping types, but it can result in trait solver cycles if any type parameters are involved
337+
// in recursive type definitions; fortunately that is not the case (yet).
338+
let mut predicates = HashMap::<_, (_, _)>::new();
339+
340+
let arms = structure.each_variant(|variant| {
341+
let variant_ty = interner.type_of(&variant.referenced_ty_params(), variant.ast().fields);
342+
let skipped_variant = match WhenToSkip::find::<false>(variant.ast().attrs, variant_ty) {
343+
Ok(skip) => skip,
344+
Err(err) => return err.into_compile_error(),
345+
};
346+
T::arm(variant, |bind| {
228347
let ast = bind.ast();
229-
let field_ty = interner.type_of(&bind.referenced_ty_params(), [ast]);
230-
if field_ty == Generic {
231-
predicates.insert(ast.ty.clone());
232-
}
233-
}
234-
T::traverse(bind.into_token_stream())
235-
})
236-
});
237-
// the order in which `where` predicates appear in rust source is irrelevant
238-
#[allow(rustc::potential_query_instability)]
239-
for ty in predicates {
240-
structure.add_where_predicate(parse_quote! { #ty: #traversable });
241-
}
242-
let body = quote! { match self { #arms } };
348+
let is_skipped = variant_ty != Type::Boring && {
349+
let field_ty = interner.type_of(&bind.referenced_ty_params(), [ast]);
350+
field_ty != Type::Boring && {
351+
let skipped_field = if skipped_variant.is_skipped() {
352+
skipped_variant
353+
} else {
354+
WhenToSkip::find::<false>(&ast.attrs, field_ty)?
355+
};
356+
357+
match predicates.entry(ast.ty.clone()) {
358+
Entry::Occupied(existing) => match (&mut existing.into_mut().0, skipped_field) {
359+
(Never, Never) | (Never, Forced) | (Forced, Forced) | (Always(_), Always(_)) => (),
360+
(existing @ Forced, Never) => *existing = Never,
361+
(&mut Always(span), _) | (_, Always(span)) => return Err(Error::new(span, format!("\
362+
This annotation only makes sense if all fields of type `{0}` are annotated identically.\n\
363+
In particular, the derived impl will only be applicable when `{0}: BoringTraversable` and therefore all traversals of `{0}` will be no-ops;\n\
364+
accordingly it makes no sense for other fields of type `{0}` to omit `#[skip_traversal]` or to include `despite_potential_miscompilation_because`.\
365+
", ast.ty.to_token_stream()))),
366+
},
367+
Entry::Vacant(entry) => { entry.insert((skipped_field, bind.referenced_ty_params().is_empty())); }
368+
}
369+
370+
skipped_field.is_skipped()
371+
}
372+
};
373+
374+
Ok(T::traverse(bind.into_token_stream(), is_skipped))
375+
})
376+
});
377+
378+
// the order in which `where` predicates appear in rust source is irrelevant
379+
#[allow(rustc::potential_query_instability)]
380+
for (ty, (when_to_skip, ignore)) in predicates {
381+
let bound = match when_to_skip {
382+
Always(_) => &skip_traversal,
383+
// we only need to add traversable predicate for generic types
384+
Never if !ignore => &traversable,
385+
_ => continue,
386+
};
387+
structure.add_where_predicate(parse_quote! { #ty: #bound });
388+
}
389+
quote! { match self { #arms } }
390+
};
243391

244-
structure.bound_impl(traversable, T::impl_body(interner, traverser, body))
392+
Ok(structure.bound_impl(traversable, T::impl_body(interner, traverser, body)))
245393
}

0 commit comments

Comments
 (0)