Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ rust-version = "1.63.0" # Keep in sync with version documented in the README.md
derive_arbitrary = { version = "~1.4.0", path = "./derive", optional = true }

[features]
default = ["std"]
default = ["std", "recursive_count"]
# Turn this feature on to enable support for `#[derive(Arbitrary)]`.
derive = ["derive_arbitrary"]
# Enables support for the `std` crate.
Expand All @@ -36,6 +36,8 @@ core_error = []
core_net = []
# Enables using `alloc::ffi::CString` when `std` is disabled. Increases MSRV to at least 1.64.0
alloc_ffi_cstring = ["alloc"]
# Checks for unbounded recursion at runtime. This does nothing without the "derive" feature.
recursive_count = ["derive_arbitrary?/recursive_count"]

[[example]]
name = "derive_enum"
Expand Down
4 changes: 4 additions & 0 deletions derive/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,7 @@ syn = { version = "2", features = ['derive', 'parsing', 'extra-traits'] }

[lib]
proc-macro = true

[features]
# Checks for unbounded recursion at runtime. Requires `std`.
recursive_count = []
88 changes: 68 additions & 20 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,18 @@ fn expand_derive_arbitrary(input: syn::DeriveInput) -> Result<TokenStream> {
let (lifetime_without_bounds, lifetime_with_bounds) =
build_arbitrary_lifetime(input.generics.clone());

#[cfg(feature = "recursive_count")]
let recursive_count = syn::Ident::new(
&format!("RECURSIVE_COUNT_{}", input.ident),
Span::call_site(),
);

let arbitrary_method =
gen_arbitrary_method(&input, lifetime_without_bounds.clone(), &recursive_count)?;
let arbitrary_method = gen_arbitrary_method(
&input,
lifetime_without_bounds.clone(),
#[cfg(feature = "recursive_count")]
&recursive_count,
)?;
let size_hint_method = gen_size_hint_method(&input)?;
let name = input.ident;

Expand All @@ -56,14 +61,22 @@ fn expand_derive_arbitrary(input: syn::DeriveInput) -> Result<TokenStream> {
// Build TypeGenerics and WhereClause without a lifetime
let (_, ty_generics, where_clause) = generics.split_for_impl();

#[cfg(feature = "recursive_count")]
let recursive_count_preamble = quote! {
extern crate std;

::std::thread_local! {
#[allow(non_upper_case_globals)]
static #recursive_count: ::core::cell::Cell<u32> = ::core::cell::Cell::new(0);
}
};

#[cfg(not(feature = "recursive_count"))]
let recursive_count_preamble = TokenStream::new();

Ok(quote! {
const _: () = {
extern crate std;

::std::thread_local! {
#[allow(non_upper_case_globals)]
static #recursive_count: ::core::cell::Cell<u32> = ::core::cell::Cell::new(0);
}
#recursive_count_preamble

#[automatically_derived]
impl #impl_generics arbitrary::Arbitrary<#lifetime_without_bounds> for #name #ty_generics #where_clause {
Expand Down Expand Up @@ -150,9 +163,10 @@ fn add_trait_bounds(mut generics: Generics, lifetime: LifetimeParam) -> Generics
}

fn with_recursive_count_guard(
recursive_count: &syn::Ident,
#[cfg(feature = "recursive_count")] recursive_count: &syn::Ident,
expr: impl quote::ToTokens,
) -> impl quote::ToTokens {
#[cfg(feature = "recursive_count")]
quote! {
let guard_against_recursion = u.is_empty();
if guard_against_recursion {
Expand All @@ -175,25 +189,35 @@ fn with_recursive_count_guard(

result
}

#[cfg(not(feature = "recursive_count"))]
quote! { (|| { #expr })() }
}

fn gen_arbitrary_method(
input: &DeriveInput,
lifetime: LifetimeParam,
recursive_count: &syn::Ident,
#[cfg(feature = "recursive_count")] recursive_count: &syn::Ident,
) -> Result<TokenStream> {
fn arbitrary_structlike(
fields: &Fields,
ident: &syn::Ident,
lifetime: LifetimeParam,
recursive_count: &syn::Ident,
#[cfg(feature = "recursive_count")] recursive_count: &syn::Ident,
) -> Result<TokenStream> {
let arbitrary = construct(fields, |_idx, field| gen_constructor_for_field(field))?;
let body = with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary) });
let body = with_recursive_count_guard(
#[cfg(feature = "recursive_count")]
recursive_count,
quote! { Ok(#ident #arbitrary) },
);

let arbitrary_take_rest = construct_take_rest(fields)?;
let take_rest_body =
with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary_take_rest) });
let take_rest_body = with_recursive_count_guard(
#[cfg(feature = "recursive_count")]
recursive_count,
quote! { Ok(#ident #arbitrary_take_rest) },
);

Ok(quote! {
fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
Expand All @@ -216,12 +240,13 @@ fn gen_arbitrary_method(
}

fn arbitrary_enum_method(
recursive_count: &syn::Ident,
unstructured: TokenStream,
variants: &[TokenStream],
#[cfg(feature = "recursive_count")] recursive_count: &syn::Ident,
) -> impl quote::ToTokens {
let count = variants.len() as u64;
with_recursive_count_guard(
#[cfg(feature = "recursive_count")]
recursive_count,
quote! {
// Use a multiply + shift to generate a ranged random number
Expand All @@ -239,7 +264,7 @@ fn gen_arbitrary_method(
DataEnum { variants, .. }: &DataEnum,
enum_name: &Ident,
lifetime: LifetimeParam,
recursive_count: &syn::Ident,
#[cfg(feature = "recursive_count")] recursive_count: &syn::Ident,
) -> Result<TokenStream> {
let filtered_variants = variants.iter().filter(not_skipped);

Expand Down Expand Up @@ -277,8 +302,18 @@ fn gen_arbitrary_method(
(!variants.is_empty())
.then(|| {
// TODO: Improve dealing with `u` vs. `&mut u`.
let arbitrary = arbitrary_enum_method(recursive_count, quote! { u }, &variants);
let arbitrary_take_rest = arbitrary_enum_method(recursive_count, quote! { &mut u }, &variants_take_rest);
let arbitrary = arbitrary_enum_method(
quote! { u },
&variants,
#[cfg(feature = "recursive_count")]
recursive_count,
);
let arbitrary_take_rest = arbitrary_enum_method(
quote! { &mut u },
&variants_take_rest,
#[cfg(feature = "recursive_count")]
recursive_count,
);

quote! {
fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
Expand All @@ -298,14 +333,27 @@ fn gen_arbitrary_method(

let ident = &input.ident;
match &input.data {
Data::Struct(data) => arbitrary_structlike(&data.fields, ident, lifetime, recursive_count),
Data::Struct(data) => arbitrary_structlike(
&data.fields,
ident,
lifetime,
#[cfg(feature = "recursive_count")]
recursive_count,
),
Data::Union(data) => arbitrary_structlike(
&Fields::Named(data.fields.clone()),
ident,
lifetime,
#[cfg(feature = "recursive_count")]
recursive_count,
),
Data::Enum(data) => arbitrary_enum(
data,
ident,
lifetime,
#[cfg(feature = "recursive_count")]
recursive_count,
),
Data::Enum(data) => arbitrary_enum(data, ident, lifetime, recursive_count),
}
}

Expand Down
3 changes: 2 additions & 1 deletion src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@ use super::{Arbitrary, Result, Unstructured};

#[cfg(feature = "std")]
use {
alloc::vec,
core::{fmt::Debug, hash::Hash},
std::collections::HashSet,
};

#[cfg(feature = "alloc")]
use alloc::{boxed::Box, rc::Rc, string::String, vec, vec::Vec};
use alloc::{boxed::Box, rc::Rc, string::String, vec::Vec};

#[cfg(all(feature = "alloc", target_has_atomic = "ptr"))]
use alloc::sync::Arc;
Expand Down
1 change: 1 addition & 0 deletions tests/derive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,7 @@ fn two_lifetimes() {
assert_eq!(upper, None);
}

#[cfg(feature = "recursive_count")]
#[test]
fn recursive_and_empty_input() {
// None of the following derives should result in a stack overflow. See
Expand Down