Skip to content

Commit 469b801

Browse files
chore[fuzz]: add mask_fn baseline (#5047)
Signed-off-by: Joe Isaacs <[email protected]>
1 parent 57231d4 commit 469b801

File tree

1 file changed

+333
-6
lines changed

1 file changed

+333
-6
lines changed

fuzz/src/array/mask.rs

Lines changed: 333 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,340 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4-
use vortex_array::compute::mask as mask_fn;
5-
use vortex_array::{ArrayRef, Canonical};
6-
use vortex_error::VortexResult;
7-
use vortex_mask::Mask;
4+
use std::ops::Not;
5+
use std::sync::Arc;
6+
7+
use vortex_array::arrays::{
8+
BoolArray, DecimalArray, ExtensionArray, FixedSizeListArray, ListViewArray, PrimitiveArray,
9+
StructArray, VarBinViewArray,
10+
};
11+
use vortex_array::validity::Validity;
12+
use vortex_array::vtable::ValidityHelper;
13+
use vortex_array::{ArrayRef, Canonical, IntoArray, ToCanonical};
14+
use vortex_dtype::ExtDType;
15+
use vortex_error::{VortexResult, VortexUnwrap};
16+
use vortex_mask::{AllOr, Mask};
17+
use vortex_scalar::match_each_decimal_value_type;
818

919
/// Apply mask on the canonical form of the array to get a consistent baseline.
20+
/// This implementation manually applies the mask to each canonical type
21+
/// without using the mask_fn method, to serve as an independent baseline for testing.
1022
pub fn mask_canonical_array(canonical: Canonical, mask: &Mask) -> VortexResult<ArrayRef> {
11-
// TODO(joe): replace with baseline not using canonical
12-
mask_fn(canonical.as_ref(), mask)
23+
Ok(match canonical {
24+
Canonical::Null(array) => {
25+
// Null arrays are already all invalid, masking has no effect
26+
array.into_array()
27+
}
28+
Canonical::Bool(array) => {
29+
let new_validity = apply_mask_to_validity(array.validity(), mask);
30+
BoolArray::from_bit_buffer(array.bit_buffer().clone(), new_validity).into_array()
31+
}
32+
Canonical::Primitive(array) => {
33+
let new_validity = apply_mask_to_validity(array.validity(), mask);
34+
PrimitiveArray::from_byte_buffer(
35+
array.byte_buffer().clone(),
36+
array.ptype(),
37+
new_validity,
38+
)
39+
.into_array()
40+
}
41+
Canonical::Decimal(array) => {
42+
let new_validity = apply_mask_to_validity(array.validity(), mask);
43+
match_each_decimal_value_type!(array.values_type(), |D| {
44+
DecimalArray::new(array.buffer::<D>(), array.decimal_dtype(), new_validity)
45+
.into_array()
46+
})
47+
}
48+
Canonical::VarBinView(array) => {
49+
let new_validity = apply_mask_to_validity(array.validity(), mask);
50+
VarBinViewArray::new(
51+
array.views().clone(),
52+
array.buffers().clone(),
53+
array.dtype().as_nullable(),
54+
new_validity,
55+
)
56+
.into_array()
57+
}
58+
Canonical::List(array) => {
59+
let new_validity = apply_mask_to_validity(array.validity(), mask);
60+
ListViewArray::try_new(
61+
array.elements().clone(),
62+
array.offsets().clone(),
63+
array.sizes().clone(),
64+
new_validity,
65+
)
66+
.vortex_unwrap()
67+
.into_array()
68+
}
69+
Canonical::FixedSizeList(array) => {
70+
let new_validity = apply_mask_to_validity(array.validity(), mask);
71+
FixedSizeListArray::new(
72+
array.elements().clone(),
73+
array.list_size(),
74+
new_validity,
75+
array.len(),
76+
)
77+
.into_array()
78+
}
79+
Canonical::Struct(array) => {
80+
let new_validity = apply_mask_to_validity(array.validity(), mask);
81+
StructArray::try_new_with_dtype(
82+
array.fields().clone(),
83+
array.struct_fields().clone(),
84+
array.len(),
85+
new_validity,
86+
)
87+
.vortex_unwrap()
88+
.into_array()
89+
}
90+
Canonical::Extension(array) => {
91+
// Recursively mask the storage array
92+
let masked_storage =
93+
mask_canonical_array(array.storage().to_canonical(), mask).vortex_unwrap();
94+
95+
if masked_storage.dtype().nullability()
96+
== array.ext_dtype().storage_dtype().nullability()
97+
{
98+
ExtensionArray::new(array.ext_dtype().clone(), masked_storage).into_array()
99+
} else {
100+
// The storage dtype changed (i.e., became nullable due to masking)
101+
let ext_dtype = Arc::new(ExtDType::new(
102+
array.ext_dtype().id().clone(),
103+
Arc::new(masked_storage.dtype().clone()),
104+
array.ext_dtype().metadata().cloned(),
105+
));
106+
ExtensionArray::new(ext_dtype, masked_storage).into_array()
107+
}
108+
}
109+
})
110+
}
111+
112+
fn apply_mask_to_validity(validity: &Validity, mask: &Mask) -> Validity {
113+
match mask.bit_buffer() {
114+
AllOr::All => Validity::AllInvalid,
115+
AllOr::None => validity.clone(),
116+
AllOr::Some(make_invalid) => match validity {
117+
Validity::NonNullable | Validity::AllValid => {
118+
Validity::Array(BoolArray::from(make_invalid.not()).into_array())
119+
}
120+
Validity::AllInvalid => Validity::AllInvalid,
121+
Validity::Array(is_valid) => {
122+
let is_valid = is_valid.to_bool();
123+
let keep_valid = make_invalid.not();
124+
Validity::from(is_valid.bit_buffer() & &keep_valid)
125+
}
126+
},
127+
}
128+
}
129+
130+
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{
        BoolArray, DecimalArray, FixedSizeListArray, ListViewArray, NullArray, PrimitiveArray,
        StructArray, VarBinViewArray,
    };
    use vortex_array::{Array, ArrayRef, IntoArray};
    use vortex_dtype::{DecimalDType, FieldNames, Nullability};
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use super::mask_canonical_array;

    /// Checks the length of `result`, then each slot in order: `Some(scalar)` must equal the
    /// stored value, while `None` requires the slot to be invalid.
    fn assert_slots(result: &ArrayRef, expected: &[Option<Scalar>]) {
        assert_eq!(result.len(), expected.len());
        for (slot, want) in expected.iter().enumerate() {
            match want {
                Some(scalar) => assert_eq!(&result.scalar_at(slot), scalar),
                None => assert!(!result.is_valid(slot)),
            }
        }
    }

    /// Checks the length of `result`, then that each slot's validity matches `expected`.
    fn assert_validity(result: &ArrayRef, expected: &[bool]) {
        assert_eq!(result.len(), expected.len());
        for (slot, &want) in expected.iter().enumerate() {
            assert_eq!(result.is_valid(slot), want);
        }
    }

    #[test]
    fn test_mask_null_array() {
        let nulls = NullArray::new(5);
        let selection = Mask::from_iter([true, false, true, false, true]);

        let masked = mask_canonical_array(nulls.to_canonical(), &selection).unwrap();

        // Every slot is still null, whatever the mask says.
        assert_validity(&masked, &[false; 5]);
    }

    #[test]
    fn test_mask_bool_array() {
        let bools = BoolArray::from_iter([true, false, true, false, true]);
        let selection = Mask::from_iter([true, false, false, true, false]);

        let masked = mask_canonical_array(bools.to_canonical(), &selection).unwrap();

        assert_slots(
            &masked,
            &[
                None,
                Some(Scalar::from(Some(false))),
                Some(Scalar::from(Some(true))),
                None,
                Some(Scalar::from(Some(true))),
            ],
        );
    }

    #[test]
    fn test_mask_primitive_array() {
        let numbers = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]);
        let selection = Mask::from_iter([false, true, false, true, false]);

        let masked = mask_canonical_array(numbers.to_canonical(), &selection).unwrap();

        assert_slots(
            &masked,
            &[
                Some(Scalar::from(Some(1))),
                None,
                Some(Scalar::from(Some(3))),
                None,
                Some(Scalar::from(Some(5))),
            ],
        );
    }

    #[test]
    fn test_mask_primitive_array_with_nulls() {
        let numbers = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4), None]);
        let selection = Mask::from_iter([true, false, false, true, false]);

        let masked = mask_canonical_array(numbers.to_canonical(), &selection).unwrap();

        // Slots 1 and 4 were null before masking; slots 0 and 3 are nulled by the mask.
        assert_slots(
            &masked,
            &[None, None, Some(Scalar::from(Some(3))), None, None],
        );
    }

    #[test]
    fn test_mask_decimal_array() {
        let decimals = DecimalArray::from_option_iter(
            [Some(1i128), Some(2), Some(3), Some(4), Some(5)],
            DecimalDType::new(10, 2),
        );
        let selection = Mask::from_iter([false, false, true, false, false]);

        let masked = mask_canonical_array(decimals.to_canonical(), &selection).unwrap();

        assert_validity(&masked, &[true, true, false, true, true]);
    }

    #[test]
    fn test_mask_varbinview_array() {
        let words = VarBinViewArray::from_iter_str(["one", "two", "three", "four", "five"]);
        let selection = Mask::from_iter([true, false, true, false, true]);

        let masked = mask_canonical_array(words.to_canonical(), &selection).unwrap();

        assert_slots(
            &masked,
            &[
                None,
                Some(Scalar::utf8("two", Nullability::Nullable)),
                None,
                Some(Scalar::utf8("four", Nullability::Nullable)),
                None,
            ],
        );
    }

    #[test]
    fn test_mask_list_array() {
        // Three lists of two elements each, laid out back to back.
        let elements = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6]).into_array();
        let offsets = PrimitiveArray::from_iter([0i32, 2, 4]).into_array();
        let sizes = PrimitiveArray::from_iter([2i32, 2, 2]).into_array();
        let lists =
            ListViewArray::try_new(elements, offsets, sizes, Nullability::NonNullable.into())
                .unwrap();
        let selection = Mask::from_iter([false, true, false]);

        let masked = mask_canonical_array(lists.to_canonical(), &selection).unwrap();

        assert_validity(&masked, &[true, false, true]);
    }

    #[test]
    fn test_mask_fixed_size_list_array() {
        let elements = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6]).into_array();
        let lists =
            FixedSizeListArray::try_new(elements, 2, Nullability::NonNullable.into(), 3).unwrap();
        let selection = Mask::from_iter([true, false, true]);

        let masked = mask_canonical_array(lists.to_canonical(), &selection).unwrap();

        assert_validity(&masked, &[false, true, false]);
    }

    #[test]
    fn test_mask_struct_array() {
        let columns = vec![
            PrimitiveArray::from_iter([1i32, 2, 3]).into_array(),
            PrimitiveArray::from_iter([4i32, 5, 6]).into_array(),
        ];
        let rows = StructArray::try_new(
            FieldNames::from(["a", "b"]),
            columns,
            3,
            Nullability::NonNullable.into(),
        )
        .unwrap();
        let selection = Mask::from_iter([false, true, false]);

        let masked = mask_canonical_array(rows.to_canonical(), &selection).unwrap();

        assert_validity(&masked, &[true, false, true]);
    }

    #[test]
    fn test_mask_all_true() {
        let numbers = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]);
        let selection = Mask::AllTrue(5);

        let masked = mask_canonical_array(numbers.to_canonical(), &selection).unwrap();

        // An all-true mask nulls out every slot.
        assert_validity(&masked, &[false; 5]);
    }

    #[test]
    fn test_mask_all_false() {
        let numbers = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]);
        let selection = Mask::AllFalse(5);

        let masked = mask_canonical_array(numbers.to_canonical(), &selection).unwrap();

        assert_eq!(masked.len(), 5);
        // An all-false mask leaves every slot valid with its original value.
        for (slot, value) in (1i32..=5).enumerate() {
            assert!(masked.is_valid(slot));
            assert_eq!(masked.scalar_at(slot), Scalar::from(Some(value)));
        }
    }

    #[test]
    fn test_mask_empty_array() {
        let empty = PrimitiveArray::from_iter(Vec::<i32>::new());
        let selection = Mask::AllFalse(0);

        let masked = mask_canonical_array(empty.to_canonical(), &selection).unwrap();

        assert_eq!(masked.len(), 0);
    }
}

0 commit comments

Comments
 (0)