1
1
use proc_macro2:: { Ident , Span , TokenStream } ;
2
2
use quote:: { quote, ToTokens } ;
3
- use std:: collections:: HashSet ;
4
- use syn:: { parse_quote, visit, Field , Generics , Lifetime } ;
3
+ use std:: {
4
+ collections:: { hash_map:: Entry , HashMap } ,
5
+ marker:: PhantomData ,
6
+ } ;
7
+ use syn:: {
8
+ parse:: { Error , Parse , ParseStream } ,
9
+ parse_quote,
10
+ spanned:: Spanned ,
11
+ visit, Attribute , Field , Generics , Lifetime , LitStr , Token ,
12
+ } ;
5
13
6
14
#[ cfg( test) ]
7
15
mod tests;
@@ -16,6 +24,50 @@ fn gen_param(suffix: impl ToString, existing: &Generics) -> Ident {
16
24
Ident :: new ( & suffix, Span :: call_site ( ) )
17
25
}
18
26
27
+ mod kw {
28
+ syn:: custom_keyword!( because_boring) ; // only applicable on internable fields
29
+ syn:: custom_keyword!( despite_potential_miscompilation_because) ; // always applicable on internable types and fields
30
+ }
31
+
32
+ trait SkipReasonKeyword : Default + Parse + ToTokens {
33
+ const HAS_EXPLANATION : bool ;
34
+ }
35
+ impl SkipReasonKeyword for kw:: because_boring {
36
+ const HAS_EXPLANATION : bool = false ;
37
+ }
38
+ impl SkipReasonKeyword for kw:: despite_potential_miscompilation_because {
39
+ const HAS_EXPLANATION : bool = true ;
40
+ }
41
+
42
+ struct SkipReason < K > ( PhantomData < K > ) ;
43
+
44
+ impl < K : SkipReasonKeyword > Parse for SkipReason < K > {
45
+ fn parse ( input : ParseStream < ' _ > ) -> Result < Self , Error > {
46
+ input. parse :: < K > ( ) ?;
47
+ if K :: HAS_EXPLANATION {
48
+ input. parse :: < Token ! [ =] > ( ) ?;
49
+ let reason = input. parse :: < LitStr > ( ) ?;
50
+ if reason. value ( ) . trim ( ) . is_empty ( ) {
51
+ return Err ( Error :: new_spanned ( reason, "skip reason must be a non-empty string" ) ) ;
52
+ }
53
+ }
54
+ Ok ( Self ( PhantomData ) )
55
+ }
56
+ }
57
+
58
+ #[ derive( Clone , Copy ) ]
59
+ enum WhenToSkip {
60
+ Forced ,
61
+ Never ,
62
+ Always ( Span ) ,
63
+ }
64
+
65
+ impl WhenToSkip {
66
+ fn is_skipped ( & self ) -> bool {
67
+ !matches ! ( self , WhenToSkip :: Never )
68
+ }
69
+ }
70
+
19
71
#[ derive( Clone , Copy , PartialEq ) ]
20
72
enum Type {
21
73
/// Describes a type that is not parameterised by the interner, and therefore cannot
@@ -31,6 +83,36 @@ enum Type {
31
83
}
32
84
use Type :: * ;
33
85
86
+ impl WhenToSkip {
87
+ fn find < const IS_TYPE : bool > ( attrs : & [ Attribute ] , ty : Type ) -> Result < WhenToSkip , Error > {
88
+ let mut iter = attrs. iter ( ) . filter ( |& attr| attr. path ( ) . is_ident ( "skip_traversal" ) ) ;
89
+ let Some ( attr) = iter. next ( ) else {
90
+ return Ok ( Self :: Never ) ;
91
+ } ;
92
+ if let Some ( next) = iter. next ( ) {
93
+ return Err ( Error :: new_spanned (
94
+ next,
95
+ "at most one skip_traversal attribute is supported" ,
96
+ ) ) ;
97
+ }
98
+
99
+ attr. meta
100
+ . require_list ( )
101
+ . and_then ( |list| match ( IS_TYPE , ty) {
102
+ ( true , Boring ) => Err ( Error :: new_spanned ( attr, "boring types are always skipped, so this attribute is superfluous" ) ) ,
103
+ ( true , _) | ( false , NotGeneric ) => list. parse_args :: < SkipReason < kw:: despite_potential_miscompilation_because > > ( ) . and ( Ok ( Self :: Forced ) ) ,
104
+ ( false , Generic ) => list. parse_args :: < SkipReason < kw:: because_boring > > ( ) . and_then ( |_| Ok ( Self :: Always ( attr. span ( ) ) ) )
105
+ . or_else ( |_| list. parse_args :: < SkipReason < kw:: despite_potential_miscompilation_because > > ( ) . and ( Ok ( Self :: Forced ) ) )
106
+ . or ( Err ( Error :: new_spanned ( attr, "\
107
+ Justification must be provided for skipping this potentially interesting field, by specifying EITHER:\n \
108
+ `because_boring` if concrete instances never actually contain anything of interest (enforced by the compiler); OR\n \
109
+ `despite_potential_miscompilation_because = \" <reason>\" ` if this field should always be skipped regardless\
110
+ ") ) ) ,
111
+ ( false , Boring ) => Err ( Error :: new_spanned ( attr, "boring fields are always skipped, so this attribute is superfluous" ) ) ,
112
+ } )
113
+ }
114
+ }
115
+
34
116
pub struct Interner < ' a > ( Option < & ' a Lifetime > ) ;
35
117
36
118
impl < ' a > Interner < ' a > {
@@ -67,13 +149,13 @@ pub trait Traversable {
67
149
/// Any supertraits that this trait is required to implement.
68
150
fn supertraits ( interner : & Interner < ' _ > ) -> TokenStream ;
69
151
70
- /// A traversal of this trait upon the `bind` expression.
71
- fn traverse ( bind : TokenStream ) -> TokenStream ;
152
+ /// A (`noop`) traversal of this trait upon the `bind` expression.
153
+ fn traverse ( bind : TokenStream , noop : bool ) -> TokenStream ;
72
154
73
155
/// A `match` arm for `variant`, where `f` generates the tokens for each binding.
74
156
fn arm (
75
157
variant : & synstructure:: VariantInfo < ' _ > ,
76
- f : impl FnMut ( & synstructure:: BindingInfo < ' _ > ) -> TokenStream ,
158
+ f : impl FnMut ( & synstructure:: BindingInfo < ' _ > ) -> Result < TokenStream , Error > ,
77
159
) -> TokenStream ;
78
160
79
161
/// The body of an implementation given the `interner`, `traverser` and match expression `body`.
@@ -91,15 +173,19 @@ impl Traversable for Foldable {
91
173
fn supertraits ( interner : & Interner < ' _ > ) -> TokenStream {
92
174
Visitable :: traversable ( interner)
93
175
}
94
- fn traverse ( bind : TokenStream ) -> TokenStream {
95
- quote ! { :: rustc_middle:: ty:: noop_traversal_if_boring!( #bind. try_fold_with( folder) ) ? }
176
+ fn traverse ( bind : TokenStream , noop : bool ) -> TokenStream {
177
+ if noop {
178
+ bind
179
+ } else {
180
+ quote ! { :: rustc_middle:: ty:: noop_traversal_if_boring!( #bind. try_fold_with( folder) ) ? }
181
+ }
96
182
}
97
183
fn arm (
98
184
variant : & synstructure:: VariantInfo < ' _ > ,
99
- mut f : impl FnMut ( & synstructure:: BindingInfo < ' _ > ) -> TokenStream ,
185
+ mut f : impl FnMut ( & synstructure:: BindingInfo < ' _ > ) -> Result < TokenStream , Error > ,
100
186
) -> TokenStream {
101
187
let bindings = variant. bindings ( ) ;
102
- variant. construct ( |_, index| f ( & bindings[ index] ) )
188
+ variant. construct ( |_, index| f ( & bindings[ index] ) . unwrap_or_else ( Error :: into_compile_error ) )
103
189
}
104
190
fn impl_body (
105
191
interner : Interner < ' _ > ,
@@ -124,14 +210,23 @@ impl Traversable for Visitable {
124
210
fn supertraits ( _: & Interner < ' _ > ) -> TokenStream {
125
211
quote ! { :: core:: clone:: Clone + :: core:: fmt:: Debug }
126
212
}
127
- fn traverse ( bind : TokenStream ) -> TokenStream {
128
- quote ! { :: rustc_middle:: ty:: noop_traversal_if_boring!( #bind. visit_with( visitor) ) ?; }
213
+ fn traverse ( bind : TokenStream , noop : bool ) -> TokenStream {
214
+ if noop {
215
+ quote ! { }
216
+ } else {
217
+ quote ! { :: rustc_middle:: ty:: noop_traversal_if_boring!( #bind. visit_with( visitor) ) ?; }
218
+ }
129
219
}
130
220
fn arm (
131
221
variant : & synstructure:: VariantInfo < ' _ > ,
132
- f : impl FnMut ( & synstructure:: BindingInfo < ' _ > ) -> TokenStream ,
222
+ f : impl FnMut ( & synstructure:: BindingInfo < ' _ > ) -> Result < TokenStream , Error > ,
133
223
) -> TokenStream {
134
- variant. bindings ( ) . iter ( ) . map ( f) . collect ( )
224
+ variant
225
+ . bindings ( )
226
+ . iter ( )
227
+ . map ( f)
228
+ . collect :: < Result < _ , _ > > ( )
229
+ . unwrap_or_else ( Error :: into_compile_error)
135
230
}
136
231
fn impl_body (
137
232
interner : Interner < ' _ > ,
@@ -158,12 +253,14 @@ impl Interner<'_> {
158
253
referenced_ty_params : & [ & Ident ] ,
159
254
fields : impl IntoIterator < Item = & ' a Field > ,
160
255
) -> Type {
256
+ use visit:: Visit ;
257
+
161
258
struct Visitor < ' a > {
162
259
interner : & ' a Lifetime ,
163
260
contains_interner : bool ,
164
261
}
165
262
166
- impl visit :: Visit < ' _ > for Visitor < ' _ > {
263
+ impl Visit < ' _ > for Visitor < ' _ > {
167
264
fn visit_lifetime ( & mut self , i : & Lifetime ) {
168
265
if i == self . interner {
169
266
self . contains_interner = true ;
@@ -180,7 +277,7 @@ impl Interner<'_> {
180
277
Some ( interner)
181
278
if fields. into_iter ( ) . any ( |field| {
182
279
let mut visitor = Visitor { interner, contains_interner : false } ;
183
- visit :: visit_type ( & mut visitor , & field. ty ) ;
280
+ visitor . visit_type ( & field. ty ) ;
184
281
visitor. contains_interner
185
282
} ) =>
186
283
{
@@ -194,7 +291,11 @@ impl Interner<'_> {
194
291
195
292
pub fn traversable_derive < T : Traversable > (
196
293
mut structure : synstructure:: Structure < ' _ > ,
197
- ) -> TokenStream {
294
+ ) -> Result < TokenStream , Error > {
295
+ use WhenToSkip :: * ;
296
+
297
+ let skip_traversal = quote ! { :: rustc_middle:: ty:: BoringTraversable } ;
298
+
198
299
let ast = structure. ast ( ) ;
199
300
200
301
let interner = Interner :: resolve ( & ast. generics ) ;
@@ -205,41 +306,88 @@ pub fn traversable_derive<T: Traversable>(
205
306
structure. add_bounds ( synstructure:: AddBounds :: None ) ;
206
307
structure. bind_with ( |_| synstructure:: BindStyle :: Move ) ;
207
308
208
- if interner. 0 . is_none ( ) {
309
+ let not_generic = if interner. 0 . is_none ( ) {
209
310
structure. add_impl_generic ( parse_quote ! { ' tcx } ) ;
210
- }
311
+ Boring
312
+ } else {
313
+ NotGeneric
314
+ } ;
211
315
212
316
// If our derived implementation will be generic over the traversable type, then we must
213
317
// constrain it to only those generic combinations that satisfy the traversable trait's
214
318
// supertraits.
215
- if ast. generics . type_params ( ) . next ( ) . is_some ( ) {
319
+ let ty = if ast. generics . type_params ( ) . next ( ) . is_some ( ) {
216
320
let supertraits = T :: supertraits ( & interner) ;
217
321
structure. add_where_predicate ( parse_quote ! { Self : #supertraits } ) ;
218
- }
322
+ Generic
323
+ } else {
324
+ not_generic
325
+ } ;
219
326
220
- // We add predicates to each generic field type, rather than to our generic type parameters.
221
- // This results in a "perfect derive", but it can result in trait solver cycles if any type
222
- // parameters are involved in recursive type definitions; fortunately that is not the case (yet).
223
- let mut predicates = HashSet :: new ( ) ;
224
- let arms = structure. each_variant ( |variant| {
225
- let variant_ty = interner. type_of ( & variant. referenced_ty_params ( ) , variant. ast ( ) . fields ) ;
226
- T :: arm ( variant, |bind| {
227
- if variant_ty == Generic {
327
+ let when_to_skip = WhenToSkip :: find :: < true > ( & ast. attrs , ty) ?;
328
+ let body = if when_to_skip. is_skipped ( ) {
329
+ if let Always ( _) = when_to_skip {
330
+ structure. add_where_predicate ( parse_quote ! { Self : #skip_traversal } ) ;
331
+ }
332
+ T :: traverse ( quote ! { self } , true )
333
+ } else {
334
+ // We add predicates to each generic field type, rather than to our generic type parameters.
335
+ // This results in a "perfect derive" that avoids having to propagate `#[skip_traversal]` annotations
336
+ // into wrapping types, but it can result in trait solver cycles if any type parameters are involved
337
+ // in recursive type definitions; fortunately that is not the case (yet).
338
+ let mut predicates = HashMap :: < _ , ( _ , _ ) > :: new ( ) ;
339
+
340
+ let arms = structure. each_variant ( |variant| {
341
+ let variant_ty = interner. type_of ( & variant. referenced_ty_params ( ) , variant. ast ( ) . fields ) ;
342
+ let skipped_variant = match WhenToSkip :: find :: < false > ( variant. ast ( ) . attrs , variant_ty) {
343
+ Ok ( skip) => skip,
344
+ Err ( err) => return err. into_compile_error ( ) ,
345
+ } ;
346
+ T :: arm ( variant, |bind| {
228
347
let ast = bind. ast ( ) ;
229
- let field_ty = interner. type_of ( & bind. referenced_ty_params ( ) , [ ast] ) ;
230
- if field_ty == Generic {
231
- predicates. insert ( ast. ty . clone ( ) ) ;
232
- }
233
- }
234
- T :: traverse ( bind. into_token_stream ( ) )
235
- } )
236
- } ) ;
237
- // the order in which `where` predicates appear in rust source is irrelevant
238
- #[ allow( rustc:: potential_query_instability) ]
239
- for ty in predicates {
240
- structure. add_where_predicate ( parse_quote ! { #ty: #traversable } ) ;
241
- }
242
- let body = quote ! { match self { #arms } } ;
348
+ let is_skipped = variant_ty != Type :: Boring && {
349
+ let field_ty = interner. type_of ( & bind. referenced_ty_params ( ) , [ ast] ) ;
350
+ field_ty != Type :: Boring && {
351
+ let skipped_field = if skipped_variant. is_skipped ( ) {
352
+ skipped_variant
353
+ } else {
354
+ WhenToSkip :: find :: < false > ( & ast. attrs , field_ty) ?
355
+ } ;
356
+
357
+ match predicates. entry ( ast. ty . clone ( ) ) {
358
+ Entry :: Occupied ( existing) => match ( & mut existing. into_mut ( ) . 0 , skipped_field) {
359
+ ( Never , Never ) | ( Never , Forced ) | ( Forced , Forced ) | ( Always ( _) , Always ( _) ) => ( ) ,
360
+ ( existing @ Forced , Never ) => * existing = Never ,
361
+ ( & mut Always ( span) , _) | ( _, Always ( span) ) => return Err ( Error :: new ( span, format ! ( "\
362
+ This annotation only makes sense if all fields of type `{0}` are annotated identically.\n \
363
+ In particular, the derived impl will only be applicable when `{0}: BoringTraversable` and therefore all traversals of `{0}` will be no-ops;\n \
364
+ accordingly it makes no sense for other fields of type `{0}` to omit `#[skip_traversal]` or to include `despite_potential_miscompilation_because`.\
365
+ ", ast. ty. to_token_stream( ) ) ) ) ,
366
+ } ,
367
+ Entry :: Vacant ( entry) => { entry. insert ( ( skipped_field, bind. referenced_ty_params ( ) . is_empty ( ) ) ) ; }
368
+ }
369
+
370
+ skipped_field. is_skipped ( )
371
+ }
372
+ } ;
373
+
374
+ Ok ( T :: traverse ( bind. into_token_stream ( ) , is_skipped) )
375
+ } )
376
+ } ) ;
377
+
378
+ // the order in which `where` predicates appear in rust source is irrelevant
379
+ #[ allow( rustc:: potential_query_instability) ]
380
+ for ( ty, ( when_to_skip, ignore) ) in predicates {
381
+ let bound = match when_to_skip {
382
+ Always ( _) => & skip_traversal,
383
+ // we only need to add traversable predicate for generic types
384
+ Never if !ignore => & traversable,
385
+ _ => continue ,
386
+ } ;
387
+ structure. add_where_predicate ( parse_quote ! { #ty: #bound } ) ;
388
+ }
389
+ quote ! { match self { #arms } }
390
+ } ;
243
391
244
- structure. bound_impl ( traversable, T :: impl_body ( interner, traverser, body) )
392
+ Ok ( structure. bound_impl ( traversable, T :: impl_body ( interner, traverser, body) ) )
245
393
}
0 commit comments