@@ -29,14 +29,15 @@ fn expand_derive_arbitrary(input: syn::DeriveInput) -> Result<TokenStream> {
29
29
let ( lifetime_without_bounds, lifetime_with_bounds) =
30
30
build_arbitrary_lifetime ( input. generics . clone ( ) ) ;
31
31
32
+ // This won't be used if `needs_recursive_count` ends up false.
32
33
let recursive_count = syn:: Ident :: new (
33
34
& format ! ( "RECURSIVE_COUNT_{}" , input. ident) ,
34
35
Span :: call_site ( ) ,
35
36
) ;
36
37
37
- let arbitrary_method =
38
+ let ( arbitrary_method, needs_recursive_count ) =
38
39
gen_arbitrary_method ( & input, lifetime_without_bounds. clone ( ) , & recursive_count) ?;
39
- let size_hint_method = gen_size_hint_method ( & input) ?;
40
+ let size_hint_method = gen_size_hint_method ( & input, needs_recursive_count ) ?;
40
41
let name = input. ident ;
41
42
42
43
// Apply user-supplied bounds or automatic `T: ArbitraryBounds`.
@@ -56,17 +57,25 @@ fn expand_derive_arbitrary(input: syn::DeriveInput) -> Result<TokenStream> {
56
57
// Build TypeGenerics and WhereClause without a lifetime
57
58
let ( _, ty_generics, where_clause) = generics. split_for_impl ( ) ;
58
59
59
- Ok ( quote ! {
60
- const _ : ( ) = {
60
+ let recursive_count = needs_recursive_count . then ( || {
61
+ Some ( quote ! {
61
62
:: std:: thread_local! {
62
63
#[ allow( non_upper_case_globals) ]
63
64
static #recursive_count: :: core:: cell:: Cell <u32 > = const {
64
65
:: core:: cell:: Cell :: new( 0 )
65
66
} ;
66
67
}
68
+ } )
69
+ } ) ;
70
+
71
+ Ok ( quote ! {
72
+ const _: ( ) = {
73
+ #recursive_count
67
74
68
75
#[ automatically_derived]
69
- impl #impl_generics arbitrary:: Arbitrary <#lifetime_without_bounds> for #name #ty_generics #where_clause {
76
+ impl #impl_generics arbitrary:: Arbitrary <#lifetime_without_bounds>
77
+ for #name #ty_generics #where_clause
78
+ {
70
79
#arbitrary_method
71
80
#size_hint_method
72
81
}
@@ -149,10 +158,7 @@ fn add_trait_bounds(mut generics: Generics, lifetime: LifetimeParam) -> Generics
149
158
generics
150
159
}
151
160
152
- fn with_recursive_count_guard (
153
- recursive_count : & syn:: Ident ,
154
- expr : impl quote:: ToTokens ,
155
- ) -> impl quote:: ToTokens {
161
+ fn with_recursive_count_guard ( recursive_count : & syn:: Ident , expr : TokenStream ) -> TokenStream {
156
162
quote ! {
157
163
let guard_against_recursion = u. is_empty( ) ;
158
164
if guard_against_recursion {
@@ -181,7 +187,7 @@ fn gen_arbitrary_method(
181
187
input : & DeriveInput ,
182
188
lifetime : LifetimeParam ,
183
189
recursive_count : & syn:: Ident ,
184
- ) -> Result < TokenStream > {
190
+ ) -> Result < ( TokenStream , bool ) > {
185
191
fn arbitrary_structlike (
186
192
fields : & Fields ,
187
193
ident : & syn:: Ident ,
@@ -219,28 +225,36 @@ fn gen_arbitrary_method(
219
225
recursive_count : & syn:: Ident ,
220
226
unstructured : TokenStream ,
221
227
variants : & [ TokenStream ] ,
222
- ) -> impl quote:: ToTokens {
228
+ needs_recursive_count : bool ,
229
+ ) -> TokenStream {
223
230
let count = variants. len ( ) as u64 ;
224
- with_recursive_count_guard (
225
- recursive_count,
226
- quote ! {
227
- // Use a multiply + shift to generate a ranged random number
228
- // with slight bias. For details, see:
229
- // https://lemire.me/blog/2016/06/30/fast-random-shuffling
230
- Ok ( match ( u64 :: from( <u32 as arbitrary:: Arbitrary >:: arbitrary( #unstructured) ?) * #count) >> 32 {
231
- #( #variants, ) *
232
- _ => unreachable!( )
233
- } )
234
- } ,
235
- )
231
+
232
+ let do_variants = quote ! {
233
+ // Use a multiply + shift to generate a ranged random number
234
+ // with slight bias. For details, see:
235
+ // https://lemire.me/blog/2016/06/30/fast-random-shuffling
236
+ Ok ( match (
237
+ u64 :: from( <u32 as arbitrary:: Arbitrary >:: arbitrary( #unstructured) ?) * #count
238
+ ) >> 32
239
+ {
240
+ #( #variants, ) *
241
+ _ => unreachable!( )
242
+ } )
243
+ } ;
244
+
245
+ if needs_recursive_count {
246
+ with_recursive_count_guard ( recursive_count, do_variants)
247
+ } else {
248
+ do_variants
249
+ }
236
250
}
237
251
238
252
fn arbitrary_enum (
239
253
DataEnum { variants, .. } : & DataEnum ,
240
254
enum_name : & Ident ,
241
255
lifetime : LifetimeParam ,
242
256
recursive_count : & syn:: Ident ,
243
- ) -> Result < TokenStream > {
257
+ ) -> Result < ( TokenStream , bool ) > {
244
258
let filtered_variants = variants. iter ( ) . filter ( not_skipped) ;
245
259
246
260
// Check attributes of all variants:
@@ -254,11 +268,16 @@ fn gen_arbitrary_method(
254
268
. map ( |( index, variant) | ( index as u64 , variant) ) ;
255
269
256
270
// Construct `match`-arms for the `arbitrary` method.
271
+ let mut needs_recursive_count = false ;
257
272
let variants = enumerated_variants
258
273
. clone ( )
259
274
. map ( |( index, Variant { fields, ident, .. } ) | {
260
- construct ( fields, |_, field| gen_constructor_for_field ( field) )
261
- . map ( |ctor| arbitrary_variant ( index, enum_name, ident, ctor) )
275
+ construct ( fields, |_, field| gen_constructor_for_field ( field) ) . map ( |ctor| {
276
+ if !ctor. is_empty ( ) {
277
+ needs_recursive_count = true ;
278
+ }
279
+ arbitrary_variant ( index, enum_name, ident, ctor)
280
+ } )
262
281
} )
263
282
. collect :: < Result < Vec < TokenStream > > > ( ) ?;
264
283
@@ -277,34 +296,56 @@ fn gen_arbitrary_method(
277
296
( !variants. is_empty ( ) )
278
297
. then ( || {
279
298
// TODO: Improve dealing with `u` vs. `&mut u`.
280
- let arbitrary = arbitrary_enum_method ( recursive_count, quote ! { u } , & variants) ;
281
- let arbitrary_take_rest = arbitrary_enum_method ( recursive_count, quote ! { & mut u } , & variants_take_rest) ;
282
-
283
- quote ! {
284
- fn arbitrary( u: & mut arbitrary:: Unstructured <#lifetime>) -> arbitrary:: Result <Self > {
285
- #arbitrary
286
- }
299
+ let arbitrary = arbitrary_enum_method (
300
+ recursive_count,
301
+ quote ! { u } ,
302
+ & variants,
303
+ needs_recursive_count,
304
+ ) ;
305
+ let arbitrary_take_rest = arbitrary_enum_method (
306
+ recursive_count,
307
+ quote ! { & mut u } ,
308
+ & variants_take_rest,
309
+ needs_recursive_count,
310
+ ) ;
311
+
312
+ (
313
+ quote ! {
314
+ fn arbitrary( u: & mut arbitrary:: Unstructured <#lifetime>)
315
+ -> arbitrary:: Result <Self >
316
+ {
317
+ #arbitrary
318
+ }
287
319
288
- fn arbitrary_take_rest( mut u: arbitrary:: Unstructured <#lifetime>) -> arbitrary:: Result <Self > {
289
- #arbitrary_take_rest
290
- }
291
- }
320
+ fn arbitrary_take_rest( mut u: arbitrary:: Unstructured <#lifetime>)
321
+ -> arbitrary:: Result <Self >
322
+ {
323
+ #arbitrary_take_rest
324
+ }
325
+ } ,
326
+ needs_recursive_count,
327
+ )
328
+ } )
329
+ . ok_or_else ( || {
330
+ Error :: new_spanned (
331
+ enum_name,
332
+ "Enum must have at least one variant, that is not skipped" ,
333
+ )
292
334
} )
293
- . ok_or_else ( || Error :: new_spanned (
294
- enum_name,
295
- "Enum must have at least one variant, that is not skipped"
296
- ) )
297
335
}
298
336
299
337
let ident = & input. ident ;
338
+ let needs_recursive_count = true ;
300
339
match & input. data {
301
- Data :: Struct ( data) => arbitrary_structlike ( & data. fields , ident, lifetime, recursive_count) ,
340
+ Data :: Struct ( data) => arbitrary_structlike ( & data. fields , ident, lifetime, recursive_count)
341
+ . map ( |ts| ( ts, needs_recursive_count) ) ,
302
342
Data :: Union ( data) => arbitrary_structlike (
303
343
& Fields :: Named ( data. fields . clone ( ) ) ,
304
344
ident,
305
345
lifetime,
306
346
recursive_count,
307
- ) ,
347
+ )
348
+ . map ( |ts| ( ts, needs_recursive_count) ) ,
308
349
Data :: Enum ( data) => arbitrary_enum ( data, ident, lifetime, recursive_count) ,
309
350
}
310
351
}
@@ -357,7 +398,7 @@ fn construct_take_rest(fields: &Fields) -> Result<TokenStream> {
357
398
} )
358
399
}
359
400
360
- fn gen_size_hint_method ( input : & DeriveInput ) -> Result < TokenStream > {
401
+ fn gen_size_hint_method ( input : & DeriveInput , needs_recursive_count : bool ) -> Result < TokenStream > {
361
402
let size_hint_fields = |fields : & Fields | {
362
403
fields
363
404
. iter ( )
@@ -372,9 +413,9 @@ fn gen_size_hint_method(input: &DeriveInput) -> Result<TokenStream> {
372
413
quote ! { <#ty as arbitrary:: Arbitrary >:: try_size_hint( depth) }
373
414
}
374
415
375
- // Note that in this case it's hard to determine what size_hint must be, so size_of::<T>() is
376
- // just an educated guess, although it's gonna be inaccurate for dynamically
377
- // allocated types (Vec, HashMap, etc.).
416
+ // Note that in this case it's hard to determine what size_hint must be, so
417
+ // size_of::<T>() is just an educated guess, although it's gonna be
418
+ // inaccurate for dynamically allocated types (Vec, HashMap, etc.).
378
419
FieldConstructor :: With ( _) => {
379
420
quote ! { Ok ( ( :: core:: mem:: size_of:: <#ty>( ) , None ) ) }
380
421
}
@@ -391,6 +432,7 @@ fn gen_size_hint_method(input: &DeriveInput) -> Result<TokenStream> {
391
432
} )
392
433
} ;
393
434
let size_hint_structlike = |fields : & Fields | {
435
+ assert ! ( needs_recursive_count) ;
394
436
size_hint_fields ( fields) . map ( |hint| {
395
437
quote ! {
396
438
#[ inline]
@@ -399,7 +441,12 @@ fn gen_size_hint_method(input: &DeriveInput) -> Result<TokenStream> {
399
441
}
400
442
401
443
#[ inline]
402
- fn try_size_hint( depth: usize ) -> :: core:: result:: Result <( usize , :: core:: option:: Option <usize >) , arbitrary:: MaxRecursionReached > {
444
+ fn try_size_hint( depth: usize )
445
+ -> :: core:: result:: Result <
446
+ ( usize , :: core:: option:: Option <usize >) ,
447
+ arbitrary:: MaxRecursionReached ,
448
+ >
449
+ {
403
450
arbitrary:: size_hint:: try_recursion_guard( depth, |depth| #hint)
404
451
}
405
452
}
@@ -413,24 +460,44 @@ fn gen_size_hint_method(input: &DeriveInput) -> Result<TokenStream> {
413
460
. iter ( )
414
461
. filter ( not_skipped)
415
462
. map ( |Variant { fields, .. } | {
463
+ if !needs_recursive_count {
464
+ assert ! ( fields. is_empty( ) ) ;
465
+ }
416
466
// The attributes of all variants are checked in `gen_arbitrary_method` above
417
- // and can therefore assume that they are valid.
467
+ // and can therefore assume that they are valid.
418
468
size_hint_fields ( fields)
419
469
} )
420
470
. collect :: < Result < Vec < TokenStream > > > ( )
421
471
. map ( |variants| {
422
- quote ! {
423
- fn size_hint( depth: usize ) -> ( usize , :: core:: option:: Option <usize >) {
424
- Self :: try_size_hint( depth) . unwrap_or_default( )
472
+ if needs_recursive_count {
473
+ // The enum might be recursive: `try_size_hint` is the primary one, and
474
+ // `size_hint` is defined in terms of it.
475
+ quote ! {
476
+ fn size_hint( depth: usize ) -> ( usize , :: core:: option:: Option <usize >) {
477
+ Self :: try_size_hint( depth) . unwrap_or_default( )
478
+ }
479
+ #[ inline]
480
+ fn try_size_hint( depth: usize )
481
+ -> :: core:: result:: Result <
482
+ ( usize , :: core:: option:: Option <usize >) ,
483
+ arbitrary:: MaxRecursionReached ,
484
+ >
485
+ {
486
+ Ok ( arbitrary:: size_hint:: and(
487
+ <u32 as arbitrary:: Arbitrary >:: size_hint( depth) ,
488
+ arbitrary:: size_hint:: try_recursion_guard( depth, |depth| {
489
+ Ok ( arbitrary:: size_hint:: or_all( & [ #( #variants? ) , * ] ) )
490
+ } ) ?,
491
+ ) )
492
+ }
425
493
}
426
- #[ inline]
427
- fn try_size_hint( depth: usize ) -> :: core:: result:: Result <( usize , :: core:: option:: Option <usize >) , arbitrary:: MaxRecursionReached > {
428
- Ok ( arbitrary:: size_hint:: and(
429
- <u32 as arbitrary:: Arbitrary >:: try_size_hint( depth) ?,
430
- arbitrary:: size_hint:: try_recursion_guard( depth, |depth| {
431
- Ok ( arbitrary:: size_hint:: or_all( & [ #( #variants? ) , * ] ) )
432
- } ) ?,
433
- ) )
494
+ } else {
495
+ // The enum is guaranteed non-recursive, i.e. fieldless: `size_hint` is the
496
+ // primary one, and the default `try_size_hint` is good enough.
497
+ quote ! {
498
+ fn size_hint( depth: usize ) -> ( usize , :: core:: option:: Option <usize >) {
499
+ <u32 as arbitrary:: Arbitrary >:: size_hint( depth)
500
+ }
434
501
}
435
502
}
436
503
} ) ,
0 commit comments