Skip to content

Commit 3fbc2d7

Browse files
authored
Merge pull request #228 from nnethercote/fieldless-enums-no-recursion
Avoid recursive count guard for fieldless enums
2 parents 9029685 + 88bb8e2 commit 3fbc2d7

File tree

1 file changed

+127
-60
lines changed

1 file changed

+127
-60
lines changed

derive/src/lib.rs

Lines changed: 127 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,15 @@ fn expand_derive_arbitrary(input: syn::DeriveInput) -> Result<TokenStream> {
2929
let (lifetime_without_bounds, lifetime_with_bounds) =
3030
build_arbitrary_lifetime(input.generics.clone());
3131

32+
// This won't be used if `needs_recursive_count` ends up false.
3233
let recursive_count = syn::Ident::new(
3334
&format!("RECURSIVE_COUNT_{}", input.ident),
3435
Span::call_site(),
3536
);
3637

37-
let arbitrary_method =
38+
let (arbitrary_method, needs_recursive_count) =
3839
gen_arbitrary_method(&input, lifetime_without_bounds.clone(), &recursive_count)?;
39-
let size_hint_method = gen_size_hint_method(&input)?;
40+
let size_hint_method = gen_size_hint_method(&input, needs_recursive_count)?;
4041
let name = input.ident;
4142

4243
// Apply user-supplied bounds or automatic `T: ArbitraryBounds`.
@@ -56,17 +57,25 @@ fn expand_derive_arbitrary(input: syn::DeriveInput) -> Result<TokenStream> {
5657
// Build TypeGenerics and WhereClause without a lifetime
5758
let (_, ty_generics, where_clause) = generics.split_for_impl();
5859

59-
Ok(quote! {
60-
const _: () = {
60+
let recursive_count = needs_recursive_count.then(|| {
61+
Some(quote! {
6162
::std::thread_local! {
6263
#[allow(non_upper_case_globals)]
6364
static #recursive_count: ::core::cell::Cell<u32> = const {
6465
::core::cell::Cell::new(0)
6566
};
6667
}
68+
})
69+
});
70+
71+
Ok(quote! {
72+
const _: () = {
73+
#recursive_count
6774

6875
#[automatically_derived]
69-
impl #impl_generics arbitrary::Arbitrary<#lifetime_without_bounds> for #name #ty_generics #where_clause {
76+
impl #impl_generics arbitrary::Arbitrary<#lifetime_without_bounds>
77+
for #name #ty_generics #where_clause
78+
{
7079
#arbitrary_method
7180
#size_hint_method
7281
}
@@ -149,10 +158,7 @@ fn add_trait_bounds(mut generics: Generics, lifetime: LifetimeParam) -> Generics
149158
generics
150159
}
151160

152-
fn with_recursive_count_guard(
153-
recursive_count: &syn::Ident,
154-
expr: impl quote::ToTokens,
155-
) -> impl quote::ToTokens {
161+
fn with_recursive_count_guard(recursive_count: &syn::Ident, expr: TokenStream) -> TokenStream {
156162
quote! {
157163
let guard_against_recursion = u.is_empty();
158164
if guard_against_recursion {
@@ -181,7 +187,7 @@ fn gen_arbitrary_method(
181187
input: &DeriveInput,
182188
lifetime: LifetimeParam,
183189
recursive_count: &syn::Ident,
184-
) -> Result<TokenStream> {
190+
) -> Result<(TokenStream, bool)> {
185191
fn arbitrary_structlike(
186192
fields: &Fields,
187193
ident: &syn::Ident,
@@ -219,28 +225,36 @@ fn gen_arbitrary_method(
219225
recursive_count: &syn::Ident,
220226
unstructured: TokenStream,
221227
variants: &[TokenStream],
222-
) -> impl quote::ToTokens {
228+
needs_recursive_count: bool,
229+
) -> TokenStream {
223230
let count = variants.len() as u64;
224-
with_recursive_count_guard(
225-
recursive_count,
226-
quote! {
227-
// Use a multiply + shift to generate a ranged random number
228-
// with slight bias. For details, see:
229-
// https://lemire.me/blog/2016/06/30/fast-random-shuffling
230-
Ok(match (u64::from(<u32 as arbitrary::Arbitrary>::arbitrary(#unstructured)?) * #count) >> 32 {
231-
#(#variants,)*
232-
_ => unreachable!()
233-
})
234-
},
235-
)
231+
232+
let do_variants = quote! {
233+
// Use a multiply + shift to generate a ranged random number
234+
// with slight bias. For details, see:
235+
// https://lemire.me/blog/2016/06/30/fast-random-shuffling
236+
Ok(match (
237+
u64::from(<u32 as arbitrary::Arbitrary>::arbitrary(#unstructured)?) * #count
238+
) >> 32
239+
{
240+
#(#variants,)*
241+
_ => unreachable!()
242+
})
243+
};
244+
245+
if needs_recursive_count {
246+
with_recursive_count_guard(recursive_count, do_variants)
247+
} else {
248+
do_variants
249+
}
236250
}
237251

238252
fn arbitrary_enum(
239253
DataEnum { variants, .. }: &DataEnum,
240254
enum_name: &Ident,
241255
lifetime: LifetimeParam,
242256
recursive_count: &syn::Ident,
243-
) -> Result<TokenStream> {
257+
) -> Result<(TokenStream, bool)> {
244258
let filtered_variants = variants.iter().filter(not_skipped);
245259

246260
// Check attributes of all variants:
@@ -254,11 +268,16 @@ fn gen_arbitrary_method(
254268
.map(|(index, variant)| (index as u64, variant));
255269

256270
// Construct `match`-arms for the `arbitrary` method.
271+
let mut needs_recursive_count = false;
257272
let variants = enumerated_variants
258273
.clone()
259274
.map(|(index, Variant { fields, ident, .. })| {
260-
construct(fields, |_, field| gen_constructor_for_field(field))
261-
.map(|ctor| arbitrary_variant(index, enum_name, ident, ctor))
275+
construct(fields, |_, field| gen_constructor_for_field(field)).map(|ctor| {
276+
if !ctor.is_empty() {
277+
needs_recursive_count = true;
278+
}
279+
arbitrary_variant(index, enum_name, ident, ctor)
280+
})
262281
})
263282
.collect::<Result<Vec<TokenStream>>>()?;
264283

@@ -277,34 +296,56 @@ fn gen_arbitrary_method(
277296
(!variants.is_empty())
278297
.then(|| {
279298
// TODO: Improve dealing with `u` vs. `&mut u`.
280-
let arbitrary = arbitrary_enum_method(recursive_count, quote! { u }, &variants);
281-
let arbitrary_take_rest = arbitrary_enum_method(recursive_count, quote! { &mut u }, &variants_take_rest);
282-
283-
quote! {
284-
fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
285-
#arbitrary
286-
}
299+
let arbitrary = arbitrary_enum_method(
300+
recursive_count,
301+
quote! { u },
302+
&variants,
303+
needs_recursive_count,
304+
);
305+
let arbitrary_take_rest = arbitrary_enum_method(
306+
recursive_count,
307+
quote! { &mut u },
308+
&variants_take_rest,
309+
needs_recursive_count,
310+
);
311+
312+
(
313+
quote! {
314+
fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>)
315+
-> arbitrary::Result<Self>
316+
{
317+
#arbitrary
318+
}
287319

288-
fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
289-
#arbitrary_take_rest
290-
}
291-
}
320+
fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>)
321+
-> arbitrary::Result<Self>
322+
{
323+
#arbitrary_take_rest
324+
}
325+
},
326+
needs_recursive_count,
327+
)
328+
})
329+
.ok_or_else(|| {
330+
Error::new_spanned(
331+
enum_name,
332+
"Enum must have at least one variant, that is not skipped",
333+
)
292334
})
293-
.ok_or_else(|| Error::new_spanned(
294-
enum_name,
295-
"Enum must have at least one variant, that is not skipped"
296-
))
297335
}
298336

299337
let ident = &input.ident;
338+
let needs_recursive_count = true;
300339
match &input.data {
301-
Data::Struct(data) => arbitrary_structlike(&data.fields, ident, lifetime, recursive_count),
340+
Data::Struct(data) => arbitrary_structlike(&data.fields, ident, lifetime, recursive_count)
341+
.map(|ts| (ts, needs_recursive_count)),
302342
Data::Union(data) => arbitrary_structlike(
303343
&Fields::Named(data.fields.clone()),
304344
ident,
305345
lifetime,
306346
recursive_count,
307-
),
347+
)
348+
.map(|ts| (ts, needs_recursive_count)),
308349
Data::Enum(data) => arbitrary_enum(data, ident, lifetime, recursive_count),
309350
}
310351
}
@@ -357,7 +398,7 @@ fn construct_take_rest(fields: &Fields) -> Result<TokenStream> {
357398
})
358399
}
359400

360-
fn gen_size_hint_method(input: &DeriveInput) -> Result<TokenStream> {
401+
fn gen_size_hint_method(input: &DeriveInput, needs_recursive_count: bool) -> Result<TokenStream> {
361402
let size_hint_fields = |fields: &Fields| {
362403
fields
363404
.iter()
@@ -372,9 +413,9 @@ fn gen_size_hint_method(input: &DeriveInput) -> Result<TokenStream> {
372413
quote! { <#ty as arbitrary::Arbitrary>::try_size_hint(depth) }
373414
}
374415

375-
// Note that in this case it's hard to determine what size_hint must be, so size_of::<T>() is
376-
// just an educated guess, although it's gonna be inaccurate for dynamically
377-
// allocated types (Vec, HashMap, etc.).
416+
// Note that in this case it's hard to determine what size_hint must be, so
417+
// size_of::<T>() is just an educated guess, although it's gonna be
418+
// inaccurate for dynamically allocated types (Vec, HashMap, etc.).
378419
FieldConstructor::With(_) => {
379420
quote! { Ok((::core::mem::size_of::<#ty>(), None)) }
380421
}
@@ -391,6 +432,7 @@ fn gen_size_hint_method(input: &DeriveInput) -> Result<TokenStream> {
391432
})
392433
};
393434
let size_hint_structlike = |fields: &Fields| {
435+
assert!(needs_recursive_count);
394436
size_hint_fields(fields).map(|hint| {
395437
quote! {
396438
#[inline]
@@ -399,7 +441,12 @@ fn gen_size_hint_method(input: &DeriveInput) -> Result<TokenStream> {
399441
}
400442

401443
#[inline]
402-
fn try_size_hint(depth: usize) -> ::core::result::Result<(usize, ::core::option::Option<usize>), arbitrary::MaxRecursionReached> {
444+
fn try_size_hint(depth: usize)
445+
-> ::core::result::Result<
446+
(usize, ::core::option::Option<usize>),
447+
arbitrary::MaxRecursionReached,
448+
>
449+
{
403450
arbitrary::size_hint::try_recursion_guard(depth, |depth| #hint)
404451
}
405452
}
@@ -413,24 +460,44 @@ fn gen_size_hint_method(input: &DeriveInput) -> Result<TokenStream> {
413460
.iter()
414461
.filter(not_skipped)
415462
.map(|Variant { fields, .. }| {
463+
if !needs_recursive_count {
464+
assert!(fields.is_empty());
465+
}
416466
// The attributes of all variants are checked in `gen_arbitrary_method` above
417-
// and can therefore assume that they are valid.
467+
// and can therefore assume that they are valid.
418468
size_hint_fields(fields)
419469
})
420470
.collect::<Result<Vec<TokenStream>>>()
421471
.map(|variants| {
422-
quote! {
423-
fn size_hint(depth: usize) -> (usize, ::core::option::Option<usize>) {
424-
Self::try_size_hint(depth).unwrap_or_default()
472+
if needs_recursive_count {
473+
// The enum might be recursive: `try_size_hint` is the primary one, and
474+
// `size_hint` is defined in terms of it.
475+
quote! {
476+
fn size_hint(depth: usize) -> (usize, ::core::option::Option<usize>) {
477+
Self::try_size_hint(depth).unwrap_or_default()
478+
}
479+
#[inline]
480+
fn try_size_hint(depth: usize)
481+
-> ::core::result::Result<
482+
(usize, ::core::option::Option<usize>),
483+
arbitrary::MaxRecursionReached,
484+
>
485+
{
486+
Ok(arbitrary::size_hint::and(
487+
<u32 as arbitrary::Arbitrary>::size_hint(depth),
488+
arbitrary::size_hint::try_recursion_guard(depth, |depth| {
489+
Ok(arbitrary::size_hint::or_all(&[ #( #variants? ),* ]))
490+
})?,
491+
))
492+
}
425493
}
426-
#[inline]
427-
fn try_size_hint(depth: usize) -> ::core::result::Result<(usize, ::core::option::Option<usize>), arbitrary::MaxRecursionReached> {
428-
Ok(arbitrary::size_hint::and(
429-
<u32 as arbitrary::Arbitrary>::try_size_hint(depth)?,
430-
arbitrary::size_hint::try_recursion_guard(depth, |depth| {
431-
Ok(arbitrary::size_hint::or_all(&[ #( #variants? ),* ]))
432-
})?,
433-
))
494+
} else {
495+
// The enum is guaranteed non-recursive, i.e. fieldless: `size_hint` is the
496+
// primary one, and the default `try_size_hint` is good enough.
497+
quote! {
498+
fn size_hint(depth: usize) -> (usize, ::core::option::Option<usize>) {
499+
<u32 as arbitrary::Arbitrary>::size_hint(depth)
500+
}
434501
}
435502
}
436503
}),

0 commit comments

Comments
 (0)