@@ -6,11 +6,14 @@ use syn::*;
6
6
7
7
mod container_attributes;
8
8
mod field_attributes;
9
+ mod variant_attributes;
10
+
9
11
use container_attributes:: ContainerAttributes ;
10
12
use field_attributes:: { determine_field_constructor, FieldConstructor } ;
13
+ use variant_attributes:: not_skipped;
11
14
12
- static ARBITRARY_ATTRIBUTE_NAME : & str = "arbitrary" ;
13
- static ARBITRARY_LIFETIME_NAME : & str = "'arbitrary" ;
15
+ const ARBITRARY_ATTRIBUTE_NAME : & str = "arbitrary" ;
16
+ const ARBITRARY_LIFETIME_NAME : & str = "'arbitrary" ;
14
17
15
18
#[ proc_macro_derive( Arbitrary , attributes( arbitrary) ) ]
16
19
pub fn derive_arbitrary ( tokens : proc_macro:: TokenStream ) -> proc_macro:: TokenStream {
@@ -201,81 +204,107 @@ fn gen_arbitrary_method(
201
204
} )
202
205
}
203
206
204
- let ident = & input. ident ;
205
- let output = match & input. data {
206
- Data :: Struct ( data) => arbitrary_structlike ( & data. fields , ident, lifetime, recursive_count) ?,
207
- Data :: Union ( data) => arbitrary_structlike (
208
- & Fields :: Named ( data. fields . clone ( ) ) ,
209
- ident,
210
- lifetime,
207
+ fn arbitrary_variant (
208
+ index : u64 ,
209
+ enum_name : & Ident ,
210
+ variant_name : & Ident ,
211
+ ctor : TokenStream ,
212
+ ) -> TokenStream {
213
+ quote ! { #index => #enum_name:: #variant_name #ctor }
214
+ }
215
+
216
+ fn arbitrary_enum_method (
217
+ recursive_count : & syn:: Ident ,
218
+ unstructured : TokenStream ,
219
+ variants : & [ TokenStream ] ,
220
+ ) -> impl quote:: ToTokens {
221
+ let count = variants. len ( ) as u64 ;
222
+ with_recursive_count_guard (
211
223
recursive_count,
212
- ) ?,
213
- Data :: Enum ( data) => {
214
- let variants: Vec < TokenStream > = data
215
- . variants
216
- . iter ( )
217
- . enumerate ( )
218
- . map ( |( i, variant) | {
219
- check_variant_attrs ( variant) ?;
220
- let idx = i as u64 ;
221
- let variant_name = & variant. ident ;
222
- construct ( & variant. fields , |_, field| gen_constructor_for_field ( field) )
223
- . map ( |ctor| quote ! { #idx => #ident:: #variant_name #ctor } )
224
+ quote ! {
225
+ // Use a multiply + shift to generate a ranged random number
226
+ // with slight bias. For details, see:
227
+ // https://lemire.me/blog/2016/06/30/fast-random-shuffling
228
+ Ok ( match ( u64 :: from( <u32 as arbitrary:: Arbitrary >:: arbitrary( #unstructured) ?) * #count) >> 32 {
229
+ #( #variants, ) *
230
+ _ => unreachable!( )
224
231
} )
225
- . collect :: < Result < _ > > ( ) ?;
232
+ } ,
233
+ )
234
+ }
226
235
227
- let variants_take_rest: Vec < TokenStream > = data
228
- . variants
229
- . iter ( )
230
- . enumerate ( )
231
- . map ( |( i, variant) | {
232
- let idx = i as u64 ;
233
- let variant_name = & variant. ident ;
234
- construct_take_rest ( & variant. fields )
235
- . map ( |ctor| quote ! { #idx => #ident:: #variant_name #ctor } )
236
- } )
237
- . collect :: < Result < _ > > ( ) ?;
236
+ fn arbitrary_enum (
237
+ DataEnum { variants, .. } : & DataEnum ,
238
+ enum_name : & Ident ,
239
+ lifetime : LifetimeParam ,
240
+ recursive_count : & syn:: Ident ,
241
+ ) -> Result < TokenStream > {
242
+ let filtered_variants = variants. iter ( ) . filter ( not_skipped) ;
243
+
244
+ // Check attributes of all variants:
245
+ filtered_variants
246
+ . clone ( )
247
+ . try_for_each ( check_variant_attrs) ?;
248
+
249
+ // From here on, we can assume that the attributes of all variants were checked.
250
+ let enumerated_variants = filtered_variants
251
+ . enumerate ( )
252
+ . map ( |( index, variant) | ( index as u64 , variant) ) ;
253
+
254
+ // Construct `match`-arms for the `arbitrary` method.
255
+ let variants = enumerated_variants
256
+ . clone ( )
257
+ . map ( |( index, Variant { fields, ident, .. } ) | {
258
+ construct ( fields, |_, field| gen_constructor_for_field ( field) )
259
+ . map ( |ctor| arbitrary_variant ( index, enum_name, ident, ctor) )
260
+ } )
261
+ . collect :: < Result < Vec < TokenStream > > > ( ) ?;
238
262
239
- let count = data. variants . len ( ) as u64 ;
263
+ // Construct `match`-arms for the `arbitrary_take_rest` method.
264
+ let variants_take_rest = enumerated_variants
265
+ . map ( |( index, Variant { fields, ident, .. } ) | {
266
+ construct_take_rest ( fields)
267
+ . map ( |ctor| arbitrary_variant ( index, enum_name, ident, ctor) )
268
+ } )
269
+ . collect :: < Result < Vec < TokenStream > > > ( ) ?;
270
+
271
+ // Most of the time, `variants` is not empty (the happy path),
272
+ // thus `variants_take_rest` will be used,
273
+ // so no need to move this check before constructing `variants_take_rest`.
274
+ // If `variants` is empty, this will emit a compiler-error.
275
+ ( !variants. is_empty ( ) )
276
+ . then ( || {
277
+ // TODO: Improve dealing with `u` vs. `&mut u`.
278
+ let arbitrary = arbitrary_enum_method ( recursive_count, quote ! { u } , & variants) ;
279
+ let arbitrary_take_rest = arbitrary_enum_method ( recursive_count, quote ! { & mut u } , & variants_take_rest) ;
240
280
241
- let arbitrary = with_recursive_count_guard (
242
- recursive_count,
243
- quote ! {
244
- // Use a multiply + shift to generate a ranged random number
245
- // with slight bias. For details, see:
246
- // https://lemire.me/blog/2016/06/30/fast-random-shuffling
247
- Ok ( match ( u64 :: from( <u32 as arbitrary:: Arbitrary >:: arbitrary( u) ?) * #count) >> 32 {
248
- #( #variants, ) *
249
- _ => unreachable!( )
250
- } )
251
- } ,
252
- ) ;
253
-
254
- let arbitrary_take_rest = with_recursive_count_guard (
255
- recursive_count,
256
281
quote ! {
257
- // Use a multiply + shift to generate a ranged random number
258
- // with slight bias. For details, see:
259
- // https://lemire.me/blog/2016/06/30/fast-random-shuffling
260
- Ok ( match ( u64 :: from( <u32 as arbitrary:: Arbitrary >:: arbitrary( & mut u) ?) * #count) >> 32 {
261
- #( #variants_take_rest, ) *
262
- _ => unreachable!( )
263
- } )
264
- } ,
265
- ) ;
282
+ fn arbitrary( u: & mut arbitrary:: Unstructured <#lifetime>) -> arbitrary:: Result <Self > {
283
+ #arbitrary
284
+ }
266
285
267
- quote ! {
268
- fn arbitrary ( u : & mut arbitrary :: Unstructured <#lifetime> ) -> arbitrary :: Result < Self > {
269
- #arbitrary
286
+ fn arbitrary_take_rest ( mut u : arbitrary :: Unstructured <#lifetime> ) -> arbitrary :: Result < Self > {
287
+ #arbitrary_take_rest
288
+ }
270
289
}
290
+ } )
291
+ . ok_or_else ( || Error :: new_spanned (
292
+ enum_name,
293
+ "Enum must have at least one variant, that is not skipped"
294
+ ) )
295
+ }
271
296
272
- fn arbitrary_take_rest( mut u: arbitrary:: Unstructured <#lifetime>) -> arbitrary:: Result <Self > {
273
- #arbitrary_take_rest
274
- }
275
- }
276
- }
277
- } ;
278
- Ok ( output)
297
+ let ident = & input. ident ;
298
+ match & input. data {
299
+ Data :: Struct ( data) => arbitrary_structlike ( & data. fields , ident, lifetime, recursive_count) ,
300
+ Data :: Union ( data) => arbitrary_structlike (
301
+ & Fields :: Named ( data. fields . clone ( ) ) ,
302
+ ident,
303
+ lifetime,
304
+ recursive_count,
305
+ ) ,
306
+ Data :: Enum ( data) => arbitrary_enum ( data, ident, lifetime, recursive_count) ,
307
+ }
279
308
}
280
309
281
310
fn construct (
@@ -375,7 +404,12 @@ fn gen_size_hint_method(input: &DeriveInput) -> Result<TokenStream> {
375
404
Data :: Enum ( data) => data
376
405
. variants
377
406
. iter ( )
378
- . map ( |v| size_hint_fields ( & v. fields ) )
407
+ . filter ( not_skipped)
408
+ . map ( |Variant { fields, .. } | {
409
+ // The attributes of all variants are checked in `gen_arbitrary_method` above
410
+ // and can therefore assume that they are valid.
411
+ size_hint_fields ( fields)
412
+ } )
379
413
. collect :: < Result < Vec < TokenStream > > > ( )
380
414
. map ( |variants| {
381
415
quote ! {
0 commit comments