Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 15 additions & 29 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,31 +158,6 @@ fn add_trait_bounds(mut generics: Generics, lifetime: LifetimeParam) -> Generics
generics
}

fn with_recursive_count_guard(recursive_count: &syn::Ident, expr: TokenStream) -> TokenStream {
quote! {
let guard_against_recursion = u.is_empty();
if guard_against_recursion {
#recursive_count.with(|count| {
if count.get() > 0 {
return Err(arbitrary::Error::NotEnoughData);
}
count.set(count.get() + 1);
Ok(())
})?;
}

let result = (|| { #expr })();

if guard_against_recursion {
#recursive_count.with(|count| {
count.set(count.get() - 1);
});
}

result
}
}

fn gen_arbitrary_method(
input: &DeriveInput,
lifetime: LifetimeParam,
Expand All @@ -195,11 +170,18 @@ fn gen_arbitrary_method(
recursive_count: &syn::Ident,
) -> Result<TokenStream> {
let arbitrary = construct(fields, |_idx, field| gen_constructor_for_field(field))?;
let body = with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary) });
let body = quote! {
arbitrary::details::with_recursive_count(u, &#recursive_count, |mut u| {
Ok(#ident #arbitrary)
})
};

let arbitrary_take_rest = construct_take_rest(fields)?;
let take_rest_body =
with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary_take_rest) });
let take_rest_body = quote! {
arbitrary::details::with_recursive_count(u, &#recursive_count, |mut u| {
Ok(#ident #arbitrary_take_rest)
})
};

Ok(quote! {
fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
Expand Down Expand Up @@ -243,7 +225,11 @@ fn gen_arbitrary_method(
};

if needs_recursive_count {
with_recursive_count_guard(recursive_count, do_variants)
quote! {
arbitrary::details::with_recursive_count(u, &#recursive_count, |mut u| {
#do_variants
})
}
} else {
do_variants
}
Expand Down
55 changes: 55 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -551,3 +551,58 @@ mod test {
/// ```
#[cfg(all(doctest, feature = "derive"))]
pub struct CompileFailTests;

// Support for `#[derive(Arbitrary)]`.
#[doc(hidden)]
#[cfg(feature = "derive")]
pub mod details {
use super::*;

// Hidden trait that papers over the difference between `&mut Unstructured` and
// `Unstructured` arguments so that `with_recursive_count` can be used for both
// `arbitrary` and `arbitrary_take_rest`.
pub trait IsEmpty {
fn is_empty(&self) -> bool;
}

impl IsEmpty for Unstructured<'_> {
fn is_empty(&self) -> bool {
Unstructured::is_empty(self)
}
}

impl IsEmpty for &mut Unstructured<'_> {
fn is_empty(&self) -> bool {
Unstructured::is_empty(self)
}
}

// Calls `f` with a recursive count guard.
#[inline]
pub fn with_recursive_count<U: IsEmpty, R>(
u: U,
recursive_count: &'static std::thread::LocalKey<std::cell::Cell<u32>>,
f: impl FnOnce(U) -> Result<R>,
) -> Result<R> {
let guard_against_recursion = u.is_empty();
if guard_against_recursion {
recursive_count.with(|count| {
if count.get() > 0 {
return Err(Error::NotEnoughData);
}
count.set(count.get() + 1);
Ok(())
})?;
}

let result = f(u);

if guard_against_recursion {
recursive_count.with(|count| {
count.set(count.get() - 1);
});
}

result
}
}