diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 86701ca..066e58c 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -29,14 +29,15 @@ fn expand_derive_arbitrary(input: syn::DeriveInput) -> Result { let (lifetime_without_bounds, lifetime_with_bounds) = build_arbitrary_lifetime(input.generics.clone()); + // This won't be used if `needs_recursive_count` ends up false. let recursive_count = syn::Ident::new( &format!("RECURSIVE_COUNT_{}", input.ident), Span::call_site(), ); - let arbitrary_method = + let (arbitrary_method, needs_recursive_count) = gen_arbitrary_method(&input, lifetime_without_bounds.clone(), &recursive_count)?; - let size_hint_method = gen_size_hint_method(&input)?; + let size_hint_method = gen_size_hint_method(&input, needs_recursive_count)?; let name = input.ident; // Apply user-supplied bounds or automatic `T: ArbitraryBounds`. @@ -56,17 +57,25 @@ fn expand_derive_arbitrary(input: syn::DeriveInput) -> Result { // Build TypeGenerics and WhereClause without a lifetime let (_, ty_generics, where_clause) = generics.split_for_impl(); - Ok(quote! { - const _: () = { + let recursive_count = needs_recursive_count.then(|| { + Some(quote! { ::std::thread_local! { #[allow(non_upper_case_globals)] static #recursive_count: ::core::cell::Cell = const { ::core::cell::Cell::new(0) }; } + }) + }); + + Ok(quote! { + const _: () = { + #recursive_count #[automatically_derived] - impl #impl_generics arbitrary::Arbitrary<#lifetime_without_bounds> for #name #ty_generics #where_clause { + impl #impl_generics arbitrary::Arbitrary<#lifetime_without_bounds> + for #name #ty_generics #where_clause + { #arbitrary_method #size_hint_method } @@ -149,10 +158,7 @@ fn add_trait_bounds(mut generics: Generics, lifetime: LifetimeParam) -> Generics generics } -fn with_recursive_count_guard( - recursive_count: &syn::Ident, - expr: impl quote::ToTokens, -) -> impl quote::ToTokens { +fn with_recursive_count_guard(recursive_count: &syn::Ident, expr: TokenStream) -> TokenStream { quote! { let guard_against_recursion = u.is_empty(); if guard_against_recursion { @@ -181,7 +187,7 @@ fn gen_arbitrary_method( input: &DeriveInput, lifetime: LifetimeParam, recursive_count: &syn::Ident, -) -> Result { +) -> Result<(TokenStream, bool)> { fn arbitrary_structlike( fields: &Fields, ident: &syn::Ident, @@ -219,20 +225,28 @@ fn gen_arbitrary_method( recursive_count: &syn::Ident, unstructured: TokenStream, variants: &[TokenStream], - ) -> impl quote::ToTokens { + needs_recursive_count: bool, + ) -> TokenStream { let count = variants.len() as u64; - with_recursive_count_guard( - recursive_count, - quote! { - // Use a multiply + shift to generate a ranged random number - // with slight bias. For details, see: - // https://lemire.me/blog/2016/06/30/fast-random-shuffling - Ok(match (u64::from(::arbitrary(#unstructured)?) * #count) >> 32 { - #(#variants,)* - _ => unreachable!() - }) - }, - ) + + let do_variants = quote! { + // Use a multiply + shift to generate a ranged random number + // with slight bias. For details, see: + // https://lemire.me/blog/2016/06/30/fast-random-shuffling + Ok(match ( + u64::from(::arbitrary(#unstructured)?) * #count + ) >> 32 + { + #(#variants,)* + _ => unreachable!() + }) + }; + + if needs_recursive_count { + with_recursive_count_guard(recursive_count, do_variants) + } else { + do_variants + } } fn arbitrary_enum( @@ -240,7 +254,7 @@ fn gen_arbitrary_method( enum_name: &Ident, lifetime: LifetimeParam, recursive_count: &syn::Ident, - ) -> Result { + ) -> Result<(TokenStream, bool)> { let filtered_variants = variants.iter().filter(not_skipped); // Check attributes of all variants: @@ -254,11 +268,16 @@ fn gen_arbitrary_method( .map(|(index, variant)| (index as u64, variant)); // Construct `match`-arms for the `arbitrary` method. + let mut needs_recursive_count = false; let variants = enumerated_variants .clone() .map(|(index, Variant { fields, ident, .. })| { - construct(fields, |_, field| gen_constructor_for_field(field)) - .map(|ctor| arbitrary_variant(index, enum_name, ident, ctor)) + construct(fields, |_, field| gen_constructor_for_field(field)).map(|ctor| { + if !ctor.is_empty() { + needs_recursive_count = true; + } + arbitrary_variant(index, enum_name, ident, ctor) + }) }) .collect::>>()?; @@ -277,34 +296,56 @@ fn gen_arbitrary_method( (!variants.is_empty()) .then(|| { // TODO: Improve dealing with `u` vs. `&mut u`. - let arbitrary = arbitrary_enum_method(recursive_count, quote! { u }, &variants); - let arbitrary_take_rest = arbitrary_enum_method(recursive_count, quote! { &mut u }, &variants_take_rest); - - quote! { - fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result { - #arbitrary - } + let arbitrary = arbitrary_enum_method( + recursive_count, + quote! { u }, + &variants, + needs_recursive_count, + ); + let arbitrary_take_rest = arbitrary_enum_method( + recursive_count, + quote! { &mut u }, + &variants_take_rest, + needs_recursive_count, + ); + + ( + quote! { + fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) + -> arbitrary::Result + { + #arbitrary + } - fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result { - #arbitrary_take_rest - } - } + fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) + -> arbitrary::Result + { + #arbitrary_take_rest + } + }, + needs_recursive_count, + ) + }) + .ok_or_else(|| { + Error::new_spanned( + enum_name, + "Enum must have at least one variant, that is not skipped", + ) }) - .ok_or_else(|| Error::new_spanned( - enum_name, - "Enum must have at least one variant, that is not skipped" - )) } let ident = &input.ident; + let needs_recursive_count = true; match &input.data { - Data::Struct(data) => arbitrary_structlike(&data.fields, ident, lifetime, recursive_count), + Data::Struct(data) => arbitrary_structlike(&data.fields, ident, lifetime, recursive_count) + .map(|ts| (ts, needs_recursive_count)), Data::Union(data) => arbitrary_structlike( &Fields::Named(data.fields.clone()), ident, lifetime, recursive_count, - ), + ) + .map(|ts| (ts, needs_recursive_count)), Data::Enum(data) => arbitrary_enum(data, ident, lifetime, recursive_count), } } @@ -357,7 +398,7 @@ fn construct_take_rest(fields: &Fields) -> Result { }) } -fn gen_size_hint_method(input: &DeriveInput) -> Result { +fn gen_size_hint_method(input: &DeriveInput, needs_recursive_count: bool) -> Result { let size_hint_fields = |fields: &Fields| { fields .iter() @@ -372,9 +413,9 @@ fn gen_size_hint_method(input: &DeriveInput) -> Result { quote! { <#ty as arbitrary::Arbitrary>::try_size_hint(depth) } } - // Note that in this case it's hard to determine what size_hint must be, so size_of::() is - // just an educated guess, although it's gonna be inaccurate for dynamically - // allocated types (Vec, HashMap, etc.). + // Note that in this case it's hard to determine what size_hint must be, so + // size_of::() is just an educated guess, although it's gonna be + // inaccurate for dynamically allocated types (Vec, HashMap, etc.). FieldConstructor::With(_) => { quote! { Ok((::core::mem::size_of::<#ty>(), None)) } } @@ -391,6 +432,7 @@ fn gen_size_hint_method(input: &DeriveInput) -> Result { }) }; let size_hint_structlike = |fields: &Fields| { + assert!(needs_recursive_count); size_hint_fields(fields).map(|hint| { quote! { #[inline] @@ -399,7 +441,12 @@ fn gen_size_hint_method(input: &DeriveInput) -> Result { } #[inline] - fn try_size_hint(depth: usize) -> ::core::result::Result<(usize, ::core::option::Option), arbitrary::MaxRecursionReached> { + fn try_size_hint(depth: usize) + -> ::core::result::Result< + (usize, ::core::option::Option), + arbitrary::MaxRecursionReached, + > + { arbitrary::size_hint::try_recursion_guard(depth, |depth| #hint) } } @@ -413,24 +460,44 @@ fn gen_size_hint_method(input: &DeriveInput) -> Result { .iter() .filter(not_skipped) .map(|Variant { fields, .. }| { + if !needs_recursive_count { + assert!(fields.is_empty()); + } // The attributes of all variants are checked in `gen_arbitrary_method` above - // and can therefore assume that they are valid. + // and can therefore assume that they are valid. size_hint_fields(fields) }) .collect::>>() .map(|variants| { - quote! { - fn size_hint(depth: usize) -> (usize, ::core::option::Option) { - Self::try_size_hint(depth).unwrap_or_default() + if needs_recursive_count { + // The enum might be recursive: `try_size_hint` is the primary one, and + // `size_hint` is defined in terms of it. + quote! { + fn size_hint(depth: usize) -> (usize, ::core::option::Option) { + Self::try_size_hint(depth).unwrap_or_default() + } + #[inline] + fn try_size_hint(depth: usize) + -> ::core::result::Result< + (usize, ::core::option::Option), + arbitrary::MaxRecursionReached, + > + { + Ok(arbitrary::size_hint::and( + ::size_hint(depth), + arbitrary::size_hint::try_recursion_guard(depth, |depth| { + Ok(arbitrary::size_hint::or_all(&[ #( #variants? ),* ])) + })?, + )) + } } - #[inline] - fn try_size_hint(depth: usize) -> ::core::result::Result<(usize, ::core::option::Option), arbitrary::MaxRecursionReached> { - Ok(arbitrary::size_hint::and( - ::try_size_hint(depth)?, - arbitrary::size_hint::try_recursion_guard(depth, |depth| { - Ok(arbitrary::size_hint::or_all(&[ #( #variants? ),* ])) - })?, - )) + } else { + // The enum is guaranteed non-recursive, i.e. fieldless: `size_hint` is the + // primary one, and the default `try_size_hint` is good enough. + quote! { + fn size_hint(depth: usize) -> (usize, ::core::option::Option) { + ::size_hint(depth) + } } } }),