Skip to content

Commit 617ec10

Browse files
authored
Merge pull request #188 from sivizius/skip-variants
feat(derive): add variant-attribute `#[arbitrary(skip)]`
2 parents 84e6920 + 38be8f8 commit 617ec10

File tree

5 files changed

+156
-71
lines changed

5 files changed

+156
-71
lines changed

derive/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ rust-version = "1.63.0"
2121
[dependencies]
2222
proc-macro2 = "1.0"
2323
quote = "1.0"
24-
syn = { version = "2", features = ['derive', 'parsing'] }
24+
syn = { version = "2", features = ['derive', 'parsing', 'extra-traits'] }
2525

2626
[lib]
2727
proc_macro = true

derive/src/lib.rs

Lines changed: 103 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,14 @@ use syn::*;
66

77
mod container_attributes;
88
mod field_attributes;
9+
mod variant_attributes;
10+
911
use container_attributes::ContainerAttributes;
1012
use field_attributes::{determine_field_constructor, FieldConstructor};
13+
use variant_attributes::not_skipped;
1114

12-
static ARBITRARY_ATTRIBUTE_NAME: &str = "arbitrary";
13-
static ARBITRARY_LIFETIME_NAME: &str = "'arbitrary";
15+
const ARBITRARY_ATTRIBUTE_NAME: &str = "arbitrary";
16+
const ARBITRARY_LIFETIME_NAME: &str = "'arbitrary";
1417

1518
#[proc_macro_derive(Arbitrary, attributes(arbitrary))]
1619
pub fn derive_arbitrary(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
@@ -201,81 +204,107 @@ fn gen_arbitrary_method(
201204
})
202205
}
203206

204-
let ident = &input.ident;
205-
let output = match &input.data {
206-
Data::Struct(data) => arbitrary_structlike(&data.fields, ident, lifetime, recursive_count)?,
207-
Data::Union(data) => arbitrary_structlike(
208-
&Fields::Named(data.fields.clone()),
209-
ident,
210-
lifetime,
207+
fn arbitrary_variant(
208+
index: u64,
209+
enum_name: &Ident,
210+
variant_name: &Ident,
211+
ctor: TokenStream,
212+
) -> TokenStream {
213+
quote! { #index => #enum_name::#variant_name #ctor }
214+
}
215+
216+
fn arbitrary_enum_method(
217+
recursive_count: &syn::Ident,
218+
unstructured: TokenStream,
219+
variants: &[TokenStream],
220+
) -> impl quote::ToTokens {
221+
let count = variants.len() as u64;
222+
with_recursive_count_guard(
211223
recursive_count,
212-
)?,
213-
Data::Enum(data) => {
214-
let variants: Vec<TokenStream> = data
215-
.variants
216-
.iter()
217-
.enumerate()
218-
.map(|(i, variant)| {
219-
check_variant_attrs(variant)?;
220-
let idx = i as u64;
221-
let variant_name = &variant.ident;
222-
construct(&variant.fields, |_, field| gen_constructor_for_field(field))
223-
.map(|ctor| quote! { #idx => #ident::#variant_name #ctor })
224+
quote! {
225+
// Use a multiply + shift to generate a ranged random number
226+
// with slight bias. For details, see:
227+
// https://lemire.me/blog/2016/06/30/fast-random-shuffling
228+
Ok(match (u64::from(<u32 as arbitrary::Arbitrary>::arbitrary(#unstructured)?) * #count) >> 32 {
229+
#(#variants,)*
230+
_ => unreachable!()
224231
})
225-
.collect::<Result<_>>()?;
232+
},
233+
)
234+
}
226235

227-
let variants_take_rest: Vec<TokenStream> = data
228-
.variants
229-
.iter()
230-
.enumerate()
231-
.map(|(i, variant)| {
232-
let idx = i as u64;
233-
let variant_name = &variant.ident;
234-
construct_take_rest(&variant.fields)
235-
.map(|ctor| quote! { #idx => #ident::#variant_name #ctor })
236-
})
237-
.collect::<Result<_>>()?;
236+
fn arbitrary_enum(
237+
DataEnum { variants, .. }: &DataEnum,
238+
enum_name: &Ident,
239+
lifetime: LifetimeParam,
240+
recursive_count: &syn::Ident,
241+
) -> Result<TokenStream> {
242+
let filtered_variants = variants.iter().filter(not_skipped);
243+
244+
// Check attributes of all variants:
245+
filtered_variants
246+
.clone()
247+
.try_for_each(check_variant_attrs)?;
248+
249+
// From here on, we can assume that the attributes of all variants were checked.
250+
let enumerated_variants = filtered_variants
251+
.enumerate()
252+
.map(|(index, variant)| (index as u64, variant));
253+
254+
// Construct `match`-arms for the `arbitrary` method.
255+
let variants = enumerated_variants
256+
.clone()
257+
.map(|(index, Variant { fields, ident, .. })| {
258+
construct(fields, |_, field| gen_constructor_for_field(field))
259+
.map(|ctor| arbitrary_variant(index, enum_name, ident, ctor))
260+
})
261+
.collect::<Result<Vec<TokenStream>>>()?;
238262

239-
let count = data.variants.len() as u64;
263+
// Construct `match`-arms for the `arbitrary_take_rest` method.
264+
let variants_take_rest = enumerated_variants
265+
.map(|(index, Variant { fields, ident, .. })| {
266+
construct_take_rest(fields)
267+
.map(|ctor| arbitrary_variant(index, enum_name, ident, ctor))
268+
})
269+
.collect::<Result<Vec<TokenStream>>>()?;
270+
271+
// Most of the time, `variants` is not empty (the happy path),
272+
// thus `variants_take_rest` will be used,
273+
// so no need to move this check before constructing `variants_take_rest`.
274+
// If `variants` is empty, this will emit a compiler-error.
275+
(!variants.is_empty())
276+
.then(|| {
277+
// TODO: Improve dealing with `u` vs. `&mut u`.
278+
let arbitrary = arbitrary_enum_method(recursive_count, quote! { u }, &variants);
279+
let arbitrary_take_rest = arbitrary_enum_method(recursive_count, quote! { &mut u }, &variants_take_rest);
240280

241-
let arbitrary = with_recursive_count_guard(
242-
recursive_count,
243-
quote! {
244-
// Use a multiply + shift to generate a ranged random number
245-
// with slight bias. For details, see:
246-
// https://lemire.me/blog/2016/06/30/fast-random-shuffling
247-
Ok(match (u64::from(<u32 as arbitrary::Arbitrary>::arbitrary(u)?) * #count) >> 32 {
248-
#(#variants,)*
249-
_ => unreachable!()
250-
})
251-
},
252-
);
253-
254-
let arbitrary_take_rest = with_recursive_count_guard(
255-
recursive_count,
256281
quote! {
257-
// Use a multiply + shift to generate a ranged random number
258-
// with slight bias. For details, see:
259-
// https://lemire.me/blog/2016/06/30/fast-random-shuffling
260-
Ok(match (u64::from(<u32 as arbitrary::Arbitrary>::arbitrary(&mut u)?) * #count) >> 32 {
261-
#(#variants_take_rest,)*
262-
_ => unreachable!()
263-
})
264-
},
265-
);
282+
fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
283+
#arbitrary
284+
}
266285

267-
quote! {
268-
fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
269-
#arbitrary
286+
fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
287+
#arbitrary_take_rest
288+
}
270289
}
290+
})
291+
.ok_or_else(|| Error::new_spanned(
292+
enum_name,
293+
"Enum must have at least one variant, that is not skipped"
294+
))
295+
}
271296

272-
fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
273-
#arbitrary_take_rest
274-
}
275-
}
276-
}
277-
};
278-
Ok(output)
297+
let ident = &input.ident;
298+
match &input.data {
299+
Data::Struct(data) => arbitrary_structlike(&data.fields, ident, lifetime, recursive_count),
300+
Data::Union(data) => arbitrary_structlike(
301+
&Fields::Named(data.fields.clone()),
302+
ident,
303+
lifetime,
304+
recursive_count,
305+
),
306+
Data::Enum(data) => arbitrary_enum(data, ident, lifetime, recursive_count),
307+
}
279308
}
280309

281310
fn construct(
@@ -375,7 +404,12 @@ fn gen_size_hint_method(input: &DeriveInput) -> Result<TokenStream> {
375404
Data::Enum(data) => data
376405
.variants
377406
.iter()
378-
.map(|v| size_hint_fields(&v.fields))
407+
.filter(not_skipped)
408+
.map(|Variant { fields, .. }| {
409+
// The attributes of all variants are checked in `gen_arbitrary_method` above
410+
// and can therefore assume that they are valid.
411+
size_hint_fields(fields)
412+
})
379413
.collect::<Result<Vec<TokenStream>>>()
380414
.map(|variants| {
381415
quote! {

derive/src/variant_attributes.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
use crate::ARBITRARY_ATTRIBUTE_NAME;
2+
use syn::*;
3+
4+
pub fn not_skipped(variant: &&Variant) -> bool {
5+
!should_skip(variant)
6+
}
7+
8+
fn should_skip(Variant { attrs, .. }: &Variant) -> bool {
9+
attrs
10+
.iter()
11+
.filter_map(|attr| {
12+
attr.path()
13+
.is_ident(ARBITRARY_ATTRIBUTE_NAME)
14+
.then(|| attr.parse_args::<Meta>())
15+
.and_then(Result::ok)
16+
})
17+
.any(|meta| match meta {
18+
Meta::Path(path) => path.is_ident("skip"),
19+
_ => false,
20+
})
21+
}

examples/derive_enum.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,13 @@ use arbitrary::{Arbitrary, Unstructured};
1212
enum MyEnum {
1313
UnitVariant,
1414
TupleVariant(bool, u32),
15-
StructVariant { x: i8, y: (u8, i32) },
15+
StructVariant {
16+
x: i8,
17+
y: (u8, i32),
18+
},
19+
20+
#[arbitrary(skip)]
21+
SkippedVariant(usize),
1622
}
1723

1824
fn main() {

tests/derive.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,30 @@ fn derive_enum() {
116116
assert_eq!((4, Some(17)), <MyEnum as Arbitrary>::size_hint(0));
117117
}
118118

119+
// This should result in a compiler-error:
120+
// #[derive(Arbitrary, Debug)]
121+
// enum Never {
122+
// #[arbitrary(skip)]
123+
// Nope,
124+
// }
125+
126+
#[derive(Arbitrary, Debug)]
127+
enum SkipVariant {
128+
Always,
129+
#[arbitrary(skip)]
130+
Never,
131+
}
132+
133+
#[test]
134+
fn test_skip_variant() {
135+
(0..=u8::MAX).for_each(|byte| {
136+
let buffer = [byte];
137+
let unstructured = Unstructured::new(&buffer);
138+
let skip_variant = SkipVariant::arbitrary_take_rest(unstructured).unwrap();
139+
assert!(!matches!(skip_variant, SkipVariant::Never));
140+
})
141+
}
142+
119143
#[derive(Arbitrary, Debug)]
120144
enum RecursiveTree {
121145
Leaf,

0 commit comments

Comments
 (0)