1
- use proc_macro2:: TokenStream ;
1
+ use proc_macro2:: { Ident , Span , TokenStream } ;
2
2
use quote:: { quote, ToTokens } ;
3
- use syn:: parse_quote;
3
+ use std:: collections:: HashSet ;
4
+ use syn:: { parse_quote, visit, Field , Generics , Lifetime } ;
5
+
6
+ #[ cfg( test) ]
7
+ mod tests;
8
+
9
+ /// Generate a type parameter with the given `suffix` that does not conflict with
10
+ /// any of the `existing` generics.
11
+ fn gen_param ( suffix : impl ToString , existing : & Generics ) -> Ident {
12
+ let mut suffix = suffix. to_string ( ) ;
13
+ while existing. type_params ( ) . any ( |t| t. ident == suffix) {
14
+ suffix. insert ( 0 , '_' ) ;
15
+ }
16
+ Ident :: new ( & suffix, Span :: call_site ( ) )
17
+ }
18
+
19
+ #[ derive( Clone , Copy , PartialEq ) ]
20
+ enum Type {
21
+ /// Describes a type that is not parameterised by the interner, and therefore cannot
22
+ /// be of any interest to traversers.
23
+ Boring ,
24
+
25
+ /// Describes a type that is parameterised by the interner lifetime `'tcx` but that is
26
+ /// otherwise not generic.
27
+ NotGeneric ,
28
+
29
+ /// Describes a type that is generic.
30
+ Generic ,
31
+ }
32
+ use Type :: * ;
33
+
34
+ pub struct Interner < ' a > ( Option < & ' a Lifetime > ) ;
35
+
36
+ impl < ' a > Interner < ' a > {
37
+ /// Return the `TyCtxt` interner for the given `structure`.
38
+ ///
39
+ /// If the input represented by `structure` has a `'tcx` lifetime parameter, then that will be used
40
+ /// used as the lifetime of the `TyCtxt`. Otherwise a `'tcx` lifetime parameter that is unrelated
41
+ /// to the input will be used.
42
+ fn resolve ( generics : & ' a Generics ) -> Self {
43
+ Self (
44
+ generics
45
+ . lifetimes ( )
46
+ . find_map ( |def| ( def. lifetime . ident == "tcx" ) . then_some ( & def. lifetime ) ) ,
47
+ )
48
+ }
49
+ }
50
+
51
+ impl ToTokens for Interner < ' _ > {
52
+ fn to_tokens ( & self , tokens : & mut TokenStream ) {
53
+ let default = parse_quote ! { ' tcx } ;
54
+ let lt = self . 0 . unwrap_or ( & default) ;
55
+ tokens. extend ( quote ! { :: rustc_middle:: ty:: TyCtxt <#lt> } ) ;
56
+ }
57
+ }
4
58
5
59
pub struct Foldable ;
6
60
pub struct Visitable ;
7
61
8
62
/// An abstraction over traversable traits.
9
63
pub trait Traversable {
10
- /// The trait that this `Traversable` represents.
11
- fn traversable ( ) -> TokenStream ;
12
-
13
- /// The `match` arms for a traversal of this type.
14
- fn arms ( structure : & mut synstructure:: Structure < ' _ > ) -> TokenStream ;
15
-
16
- /// The body of an implementation given the match `arms`.
17
- fn impl_body ( arms : impl ToTokens ) -> TokenStream ;
64
+ /// The trait that this `Traversable` represents, parameterised by `interner`.
65
+ fn traversable ( interner : & Interner < ' _ > ) -> TokenStream ;
66
+
67
+ /// Any supertraits that this trait is required to implement.
68
+ fn supertraits ( interner : & Interner < ' _ > ) -> TokenStream ;
69
+
70
+ /// A traversal of this trait upon the `bind` expression.
71
+ fn traverse ( bind : TokenStream ) -> TokenStream ;
72
+
73
+ /// A `match` arm for `variant`, where `f` generates the tokens for each binding.
74
+ fn arm (
75
+ variant : & synstructure:: VariantInfo < ' _ > ,
76
+ f : impl FnMut ( & synstructure:: BindingInfo < ' _ > ) -> TokenStream ,
77
+ ) -> TokenStream ;
78
+
79
+ /// The body of an implementation given the `interner`, `traverser` and match expression `body`.
80
+ fn impl_body (
81
+ interner : Interner < ' _ > ,
82
+ traverser : impl ToTokens ,
83
+ body : impl ToTokens ,
84
+ ) -> TokenStream ;
18
85
}
19
86
20
87
impl Traversable for Foldable {
21
- fn traversable ( ) -> TokenStream {
22
- quote ! { :: rustc_middle:: ty:: fold:: TypeFoldable <:: rustc_middle:: ty:: TyCtxt <' tcx>> }
23
- }
24
- fn arms ( structure : & mut synstructure:: Structure < ' _ > ) -> TokenStream {
25
- structure. each_variant ( |vi| {
26
- let bindings = vi. bindings ( ) ;
27
- vi. construct ( |_, index| {
28
- let bind = & bindings[ index] ;
29
-
30
- let mut fixed = false ;
31
-
32
- // retain value of fields with #[type_foldable(identity)]
33
- bind. ast ( ) . attrs . iter ( ) . for_each ( |x| {
34
- if !x. path ( ) . is_ident ( "type_foldable" ) {
35
- return ;
36
- }
37
- let _ = x. parse_nested_meta ( |nested| {
38
- if nested. path . is_ident ( "identity" ) {
39
- fixed = true ;
40
- }
41
- Ok ( ( ) )
42
- } ) ;
43
- } ) ;
44
-
45
- if fixed {
46
- bind. to_token_stream ( )
47
- } else {
48
- quote ! {
49
- :: rustc_middle:: ty:: fold:: TypeFoldable :: try_fold_with( #bind, __folder) ?
50
- }
51
- }
52
- } )
53
- } )
88
+ fn traversable ( interner : & Interner < ' _ > ) -> TokenStream {
89
+ quote ! { :: rustc_middle:: ty:: fold:: TypeFoldable <#interner> }
90
+ }
91
+ fn supertraits ( interner : & Interner < ' _ > ) -> TokenStream {
92
+ Visitable :: traversable ( interner)
93
+ }
94
+ fn traverse ( bind : TokenStream ) -> TokenStream {
95
+ quote ! { :: rustc_middle:: ty:: noop_traversal_if_boring!( #bind. try_fold_with( folder) ) ? }
54
96
}
55
- fn impl_body ( arms : impl ToTokens ) -> TokenStream {
97
+ fn arm (
98
+ variant : & synstructure:: VariantInfo < ' _ > ,
99
+ mut f : impl FnMut ( & synstructure:: BindingInfo < ' _ > ) -> TokenStream ,
100
+ ) -> TokenStream {
101
+ let bindings = variant. bindings ( ) ;
102
+ variant. construct ( |_, index| f ( & bindings[ index] ) )
103
+ }
104
+ fn impl_body (
105
+ interner : Interner < ' _ > ,
106
+ traverser : impl ToTokens ,
107
+ body : impl ToTokens ,
108
+ ) -> TokenStream {
56
109
quote ! {
57
- fn try_fold_with<__F : :: rustc_middle:: ty:: fold:: FallibleTypeFolder <:: rustc_middle :: ty :: TyCtxt < ' tcx> >>(
110
+ fn try_fold_with<#traverser : :: rustc_middle:: ty:: fold:: FallibleTypeFolder <#interner >>(
58
111
self ,
59
- __folder : & mut __F
60
- ) -> :: core:: result:: Result <Self , __F :: Error > {
61
- :: core:: result:: Result :: Ok ( match self { #arms } )
112
+ folder : & mut #traverser
113
+ ) -> :: core:: result:: Result <Self , #traverser :: Error > {
114
+ :: core:: result:: Result :: Ok ( #body )
62
115
}
63
116
}
64
117
}
65
118
}
66
119
67
120
impl Traversable for Visitable {
68
- fn traversable ( ) -> TokenStream {
69
- quote ! { :: rustc_middle:: ty:: visit:: TypeVisitable <:: rustc_middle :: ty :: TyCtxt < ' tcx> > }
121
+ fn traversable ( interner : & Interner < ' _ > ) -> TokenStream {
122
+ quote ! { :: rustc_middle:: ty:: visit:: TypeVisitable <#interner > }
70
123
}
71
- fn arms ( structure : & mut synstructure:: Structure < ' _ > ) -> TokenStream {
72
- // ignore fields with #[type_visitable(ignore)]
73
- structure. filter ( |bi| {
74
- let mut ignored = false ;
75
-
76
- bi. ast ( ) . attrs . iter ( ) . for_each ( |attr| {
77
- if !attr. path ( ) . is_ident ( "type_visitable" ) {
78
- return ;
79
- }
80
- let _ = attr. parse_nested_meta ( |nested| {
81
- if nested. path . is_ident ( "ignore" ) {
82
- ignored = true ;
83
- }
84
- Ok ( ( ) )
85
- } ) ;
86
- } ) ;
87
-
88
- !ignored
89
- } ) ;
90
-
91
- structure. each ( |bind| {
92
- quote ! {
93
- :: rustc_middle:: ty:: visit:: TypeVisitable :: visit_with( #bind, __visitor) ?;
94
- }
95
- } )
124
+ fn supertraits ( _: & Interner < ' _ > ) -> TokenStream {
125
+ quote ! { :: core:: clone:: Clone + :: core:: fmt:: Debug }
126
+ }
127
+ fn traverse ( bind : TokenStream ) -> TokenStream {
128
+ quote ! { :: rustc_middle:: ty:: noop_traversal_if_boring!( #bind. visit_with( visitor) ) ?; }
96
129
}
97
- fn impl_body ( arms : impl ToTokens ) -> TokenStream {
130
+ fn arm (
131
+ variant : & synstructure:: VariantInfo < ' _ > ,
132
+ f : impl FnMut ( & synstructure:: BindingInfo < ' _ > ) -> TokenStream ,
133
+ ) -> TokenStream {
134
+ variant. bindings ( ) . iter ( ) . map ( f) . collect ( )
135
+ }
136
+ fn impl_body (
137
+ interner : Interner < ' _ > ,
138
+ traverser : impl ToTokens ,
139
+ body : impl ToTokens ,
140
+ ) -> TokenStream {
98
141
quote ! {
99
- fn visit_with<__V : :: rustc_middle:: ty:: visit:: TypeVisitor <:: rustc_middle :: ty :: TyCtxt < ' tcx> >>(
142
+ fn visit_with<#traverser : :: rustc_middle:: ty:: visit:: TypeVisitor <#interner >>(
100
143
& self ,
101
- __visitor: & mut __V
102
- ) -> :: std:: ops:: ControlFlow <__V:: BreakTy > {
103
- match self { #arms }
104
- :: std:: ops:: ControlFlow :: Continue ( ( ) )
144
+ visitor: & mut #traverser
145
+ ) -> :: core:: ops:: ControlFlow <#traverser:: BreakTy > {
146
+ #body
147
+ :: core:: ops:: ControlFlow :: Continue ( ( ) )
148
+ }
149
+ }
150
+ }
151
+ }
152
+
153
+ impl Interner < ' _ > {
154
+ /// We consider a type to be internable if it references either a generic type parameter or,
155
+ /// if the interner is `TyCtxt<'tcx>`, the `'tcx` lifetime.
156
+ fn type_of < ' a > (
157
+ & self ,
158
+ referenced_ty_params : & [ & Ident ] ,
159
+ fields : impl IntoIterator < Item = & ' a Field > ,
160
+ ) -> Type {
161
+ struct Visitor < ' a > {
162
+ interner : & ' a Lifetime ,
163
+ contains_interner : bool ,
164
+ }
165
+
166
+ impl visit:: Visit < ' _ > for Visitor < ' _ > {
167
+ fn visit_lifetime ( & mut self , i : & Lifetime ) {
168
+ if i == self . interner {
169
+ self . contains_interner = true ;
170
+ } else {
171
+ visit:: visit_lifetime ( self , i)
172
+ }
173
+ }
174
+ }
175
+
176
+ if !referenced_ty_params. is_empty ( ) {
177
+ Generic
178
+ } else {
179
+ match & self . 0 {
180
+ Some ( interner)
181
+ if fields. into_iter ( ) . any ( |field| {
182
+ let mut visitor = Visitor { interner, contains_interner : false } ;
183
+ visit:: visit_type ( & mut visitor, & field. ty ) ;
184
+ visitor. contains_interner
185
+ } ) =>
186
+ {
187
+ NotGeneric
188
+ }
189
+ _ => Boring ,
105
190
}
106
191
}
107
192
}
@@ -110,17 +195,51 @@ impl Traversable for Visitable {
110
195
pub fn traversable_derive < T : Traversable > (
111
196
mut structure : synstructure:: Structure < ' _ > ,
112
197
) -> TokenStream {
113
- if let syn:: Data :: Union ( _) = structure. ast ( ) . data {
114
- panic ! ( "cannot derive on union" )
115
- }
198
+ let ast = structure. ast ( ) ;
116
199
117
- structure. add_bounds ( synstructure:: AddBounds :: Generics ) ;
200
+ let interner = Interner :: resolve ( & ast. generics ) ;
201
+ let traverser = gen_param ( "T" , & ast. generics ) ;
202
+ let traversable = T :: traversable ( & interner) ;
203
+
204
+ structure. underscore_const ( true ) ;
205
+ structure. add_bounds ( synstructure:: AddBounds :: None ) ;
118
206
structure. bind_with ( |_| synstructure:: BindStyle :: Move ) ;
119
207
120
- if !structure . ast ( ) . generics . lifetimes ( ) . any ( |lt| lt . lifetime . ident == "tcx" ) {
208
+ if interner . 0 . is_none ( ) {
121
209
structure. add_impl_generic ( parse_quote ! { ' tcx } ) ;
122
210
}
123
211
124
- let arms = T :: arms ( & mut structure) ;
125
- structure. bound_impl ( T :: traversable ( ) , T :: impl_body ( arms) )
212
+ // If our derived implementation will be generic over the traversable type, then we must
213
+ // constrain it to only those generic combinations that satisfy the traversable trait's
214
+ // supertraits.
215
+ if ast. generics . type_params ( ) . next ( ) . is_some ( ) {
216
+ let supertraits = T :: supertraits ( & interner) ;
217
+ structure. add_where_predicate ( parse_quote ! { Self : #supertraits } ) ;
218
+ }
219
+
220
+ // We add predicates to each generic field type, rather than to our generic type parameters.
221
+ // This results in a "perfect derive", but it can result in trait solver cycles if any type
222
+ // parameters are involved in recursive type definitions; fortunately that is not the case (yet).
223
+ let mut predicates = HashSet :: new ( ) ;
224
+ let arms = structure. each_variant ( |variant| {
225
+ let variant_ty = interner. type_of ( & variant. referenced_ty_params ( ) , variant. ast ( ) . fields ) ;
226
+ T :: arm ( variant, |bind| {
227
+ if variant_ty == Generic {
228
+ let ast = bind. ast ( ) ;
229
+ let field_ty = interner. type_of ( & bind. referenced_ty_params ( ) , [ ast] ) ;
230
+ if field_ty == Generic {
231
+ predicates. insert ( ast. ty . clone ( ) ) ;
232
+ }
233
+ }
234
+ T :: traverse ( bind. into_token_stream ( ) )
235
+ } )
236
+ } ) ;
237
+ // the order in which `where` predicates appear in rust source is irrelevant
238
+ #[ allow( rustc:: potential_query_instability) ]
239
+ for ty in predicates {
240
+ structure. add_where_predicate ( parse_quote ! { #ty: #traversable } ) ;
241
+ }
242
+ let body = quote ! { match self { #arms } } ;
243
+
244
+ structure. bound_impl ( traversable, T :: impl_body ( interner, traverser, body) )
126
245
}
0 commit comments