Skip to content

Commit 2c76f24

Browse files
chore[fuzz]: add scalar_at action (#5043)
Signed-off-by: Joe Isaacs <[email protected]>
1 parent da3d63d commit 2c76f24

File tree

3 files changed

+123
-6
lines changed

3 files changed

+123
-6
lines changed

fuzz/fuzz_targets/array_ops.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,13 @@ fuzz_target!(|fuzz_action: FuzzArrayAction| -> Corpus {
103103
current_array = mask(&current_array, &mask_val).vortex_unwrap();
104104
assert_array_eq(&expected.array(), &current_array, i).unwrap();
105105
}
106+
Action::ScalarAt(indices) => {
107+
let expected_scalars = expected.scalar_vec();
108+
for (j, &idx) in indices.iter().enumerate() {
109+
let scalar = current_array.scalar_at(idx);
110+
assert_scalar_eq(&expected_scalars[j], &scalar);
111+
}
112+
}
106113
}
107114
}
108115
Corpus::Keep

fuzz/src/array/mod.rs

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ mod fill_null;
77
mod filter;
88
mod mask;
99
mod min_max;
10+
mod scalar_at;
1011
mod search_sorted;
1112
mod slice;
1213
mod sort;
@@ -24,6 +25,7 @@ use libfuzzer_sys::arbitrary::Error::EmptyChoose;
2425
use libfuzzer_sys::arbitrary::{Arbitrary, Unstructured};
2526
pub(crate) use mask::*;
2627
pub(crate) use min_max::*;
28+
pub(crate) use scalar_at::*;
2729
pub(crate) use search_sorted::*;
2830
pub(crate) use slice::*;
2931
pub use sort::sort_canonical_array;
@@ -80,6 +82,8 @@ pub enum Action {
8082
MinMax,
8183
FillNull(Scalar),
8284
Mask(Mask),
85+
// Carries several indices so that a single action checks `scalar_at` at multiple positions.
86+
ScalarAt(Vec<usize>),
8387
}
8488

8589
#[derive(Debug)]
@@ -88,6 +92,7 @@ pub enum ExpectedValue {
8892
Search(SearchResult),
8993
Scalar(Scalar),
9094
MinMax(Option<MinMaxResult>),
95+
ScalarVec(Vec<Scalar>),
9196
}
9297

9398
impl ExpectedValue {
@@ -118,6 +123,13 @@ impl ExpectedValue {
118123
_ => vortex_panic!("expected min_max"),
119124
}
120125
}
126+
127+
pub fn scalar_vec(self) -> Vec<Scalar> {
128+
match self {
129+
ExpectedValue::ScalarVec(v) => v,
130+
_ => vortex_panic!("expected scalar_vec"),
131+
}
132+
}
121133
}
122134

123135
impl<'a> Arbitrary<'a> for FuzzArrayAction {
@@ -131,7 +143,7 @@ impl<'a> Arbitrary<'a> for FuzzArrayAction {
131143
valid_actions.sort_unstable_by_key(|a| *a as usize);
132144

133145
let mut actions = Vec::new();
134-
let action_count = u.int_in_range(1..=4)?;
146+
let action_count = u.int_in_range(1..=4.min(valid_actions.len()))?;
135147
for _ in 0..action_count {
136148
let action_type = random_action_from_list(u, valid_actions.as_slice())?;
137149

@@ -313,6 +325,35 @@ impl<'a> Arbitrary<'a> for FuzzArrayAction {
313325
ExpectedValue::Array(expected_result),
314326
)
315327
}
328+
ActionType::ScalarAt => {
329+
if current_array.is_empty() {
330+
return Err(EmptyChoose);
331+
}
332+
333+
let num_indices = u.int_in_range(1..=5.min(current_array.len()))?;
334+
let mut indices = HashSet::with_capacity(num_indices);
335+
336+
while indices.len() < num_indices {
337+
let idx = u.choose_index(current_array.len())?;
338+
indices.insert(idx);
339+
}
340+
341+
let indices_vec: Vec<usize> = indices.into_iter().collect();
342+
343+
// Compute expected scalars using the baseline implementation
344+
let expected_scalars: Vec<Scalar> = indices_vec
345+
.iter()
346+
.map(|&idx| {
347+
scalar_at_canonical_array(current_array.to_canonical(), idx)
348+
.vortex_unwrap()
349+
})
350+
.collect();
351+
352+
(
353+
Action::ScalarAt(indices_vec),
354+
ExpectedValue::ScalarVec(expected_scalars),
355+
)
356+
}
316357
})
317358
}
318359

@@ -325,23 +366,23 @@ fn actions_for_dtype(dtype: &DType) -> HashSet<ActionType> {
325366

326367
match dtype {
327368
DType::Struct(sdt, _) => {
328-
// Struct supports: Compress, Slice, Take, Filter, MinMax, Mask
369+
// Struct supports: Compress, Slice, Take, Filter, MinMax, Mask, ScalarAt
329370
// Does NOT support: SearchSorted (requires scalar comparison), Compare, Cast, Sum, FillNull
330-
let struct_actions = [Compress, Slice, Take, Filter, MinMax, Mask];
371+
let struct_actions = [Compress, Slice, Take, Filter, MinMax, Mask, ScalarAt];
331372
sdt.fields()
332373
.map(|child| actions_for_dtype(&child))
333374
.fold(struct_actions.into(), |acc, actions| {
334375
acc.intersection(&actions).copied().collect()
335376
})
336377
}
337378
DType::List(..) | DType::FixedSizeList(..) => {
338-
// List supports: Compress, Slice, Take, Filter, MinMax, Mask
379+
// List supports: Compress, Slice, Take, Filter, MinMax, Mask, ScalarAt
339380
// Does NOT support: SearchSorted, Compare, Cast, Sum, FillNull
340-
[Compress, Slice, Take, Filter, MinMax, Mask].into()
381+
[Compress, Slice, Take, Filter, MinMax, Mask, ScalarAt].into()
341382
}
342383
DType::Utf8(_) | DType::Binary(_) => {
343384
// Utf8/Binary supports everything except Sum
344-
// Actions: Compress, Slice, Take, SearchSorted, Filter, Compare, Cast, MinMax, FillNull, Mask
385+
// Actions: Compress, Slice, Take, SearchSorted, Filter, Compare, Cast, MinMax, FillNull, Mask, ScalarAt
345386
[
346387
Compress,
347388
Slice,
@@ -353,6 +394,7 @@ fn actions_for_dtype(dtype: &DType) -> HashSet<ActionType> {
353394
MinMax,
354395
FillNull,
355396
Mask,
397+
ScalarAt,
356398
]
357399
.into()
358400
}
@@ -372,6 +414,7 @@ fn actions_for_dtype(dtype: &DType) -> HashSet<ActionType> {
372414
Cast,
373415
FillNull,
374416
Mask,
417+
ScalarAt,
375418
]
376419
.into()
377420
}

fuzz/src/array/scalar_at.rs

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use std::sync::Arc;
5+
6+
use vortex_array::arrays::varbin_scalar;
7+
use vortex_array::{Array, Canonical};
8+
use vortex_dtype::{DType, match_each_native_ptype};
9+
use vortex_error::{VortexResult, VortexUnwrap};
10+
use vortex_scalar::{DecimalValue, Scalar, match_each_decimal_value_type};
11+
12+
/// Baseline implementation of scalar_at that works on canonical arrays.
13+
/// This implementation manually extracts the scalar value from each canonical type
14+
/// without using the scalar_at method, to serve as an independent baseline for testing.
15+
pub fn scalar_at_canonical_array(canonical: Canonical, index: usize) -> VortexResult<Scalar> {
16+
Ok(match canonical {
17+
Canonical::Null(_array) => Scalar::null(DType::Null),
18+
Canonical::Bool(array) => {
19+
Scalar::bool(array.bit_buffer().value(index), array.dtype().nullability())
20+
}
21+
Canonical::Primitive(array) => {
22+
match_each_native_ptype!(array.ptype(), |T| {
23+
Scalar::primitive(array.as_slice::<T>()[index], array.dtype().nullability())
24+
})
25+
}
26+
Canonical::Decimal(array) => {
27+
match_each_decimal_value_type!(array.values_type(), |D| {
28+
Scalar::decimal(
29+
DecimalValue::from(array.buffer::<D>()[index]),
30+
array.decimal_dtype(),
31+
array.dtype().nullability(),
32+
)
33+
})
34+
}
35+
Canonical::VarBinView(array) => varbin_scalar(array.bytes_at(index), array.dtype()),
36+
Canonical::List(array) => {
37+
let list = array.list_elements_at(index);
38+
let children: Vec<Scalar> = (0..list.len())
39+
.map(|i| scalar_at_canonical_array(list.to_canonical(), i).vortex_unwrap())
40+
.collect();
41+
Scalar::list(
42+
Arc::new(list.dtype().clone()),
43+
children,
44+
array.dtype().nullability(),
45+
)
46+
}
47+
Canonical::FixedSizeList(array) => {
48+
let list = array.fixed_size_list_elements_at(index);
49+
let children: Vec<Scalar> = (0..list.len())
50+
.map(|i| scalar_at_canonical_array(list.to_canonical(), i).vortex_unwrap())
51+
.collect();
52+
Scalar::fixed_size_list(list.dtype().clone(), children, array.dtype().nullability())
53+
}
54+
Canonical::Struct(array) => {
55+
let field_scalars: Vec<Scalar> = array
56+
.fields()
57+
.iter()
58+
.map(|field| scalar_at_canonical_array(field.to_canonical(), index).vortex_unwrap())
59+
.collect();
60+
Scalar::struct_(array.dtype().clone(), field_scalars)
61+
}
62+
Canonical::Extension(array) => {
63+
let storage_scalar = scalar_at_canonical_array(array.storage().to_canonical(), index)?;
64+
Scalar::extension(array.ext_dtype().clone(), storage_scalar)
65+
}
66+
})
67+
}

0 commit comments

Comments
 (0)