Skip to content

Commit 7cfc405

Browse files
authored
chore: Fix fuzzers baseline nullability for compare and take (#5255)
Signed-off-by: Robert Kruszewski <[email protected]>
1 parent f0d6048 commit 7cfc405

File tree

3 files changed

+60
-65
lines changed

3 files changed

+60
-65
lines changed

fuzz/src/array/compare.rs

Lines changed: 35 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,13 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4-
use std::fmt::Debug;
5-
use std::ops::Deref;
6-
74
use vortex_array::accessor::ArrayAccessor;
8-
use vortex_array::arrays::BoolArray;
5+
use vortex_array::arrays::{BoolArray, NativeValue};
96
use vortex_array::compute::{Operator, scalar_cmp};
107
use vortex_array::validity::Validity;
118
use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
129
use vortex_buffer::BitBuffer;
13-
use vortex_dtype::{
14-
DType, NativeDecimalType, NativePType, match_each_decimal_value_type, match_each_native_ptype,
15-
};
10+
use vortex_dtype::{DType, Nullability, match_each_decimal_value_type, match_each_native_ptype};
1611
use vortex_error::{VortexExpect, VortexResult, vortex_err};
1712
use vortex_scalar::Scalar;
1813

@@ -29,6 +24,8 @@ pub fn compare_canonical_array(
2924
.into_array());
3025
}
3126

27+
let result_nullability = array.dtype().nullability() | value.dtype().nullability();
28+
3229
match array.dtype() {
3330
DType::Bool(_) => {
3431
let bool = value
@@ -44,6 +41,7 @@ pub fn compare_canonical_array(
4441
.map(|(b, v)| v.then_some(b)),
4542
bool,
4643
operator,
44+
result_nullability,
4745
))
4846
}
4947
DType::Primitive(p, _) => {
@@ -53,15 +51,16 @@ pub fn compare_canonical_array(
5351
let pval = primitive
5452
.typed_value::<P>()
5553
.vortex_expect("nulls handled before");
56-
Ok(compare_native_ptype(
54+
Ok(compare_to(
5755
primitive_array
5856
.as_slice::<P>()
5957
.iter()
6058
.copied()
6159
.zip(array.validity_mask().to_bit_buffer().iter())
62-
.map(|(b, v)| v.then_some(b)),
63-
pval,
60+
.map(|(b, v)| v.then_some(NativeValue(b))),
61+
NativeValue(pval),
6462
operator,
63+
result_nullability,
6564
))
6665
})
6766
}
@@ -75,14 +74,15 @@ pub fn compare_canonical_array(
7574
.cast::<D>()
7675
.ok_or_else(|| vortex_err!("todo: handle upcast of decimal array"))?;
7776
let buf = decimal_array.buffer::<D>();
78-
Ok(compare_native_decimal_type(
77+
Ok(compare_to(
7978
buf.as_slice()
8079
.iter()
8180
.copied()
8281
.zip(array.validity_mask().to_bit_buffer().iter())
8382
.map(|(b, v)| v.then_some(b)),
8483
dval,
8584
operator,
85+
result_nullability,
8686
))
8787
})
8888
}
@@ -93,8 +93,9 @@ pub fn compare_canonical_array(
9393
.vortex_expect("nulls handled before");
9494
compare_to(
9595
iter.map(|v| v.map(|b| unsafe { str::from_utf8_unchecked(b) })),
96-
utf8_value.deref(),
96+
&utf8_value,
9797
operator,
98+
result_nullability,
9899
)
99100
}),
100101
DType::Binary(_) => array.to_varbinview().with_iterator(|iter| {
@@ -106,8 +107,9 @@ pub fn compare_canonical_array(
106107
// Don't understand the lifetime problem here but identity map makes it go away
107108
#[allow(clippy::map_identity)]
108109
iter.map(|v| v),
109-
binary_value.deref(),
110+
&binary_value,
110111
operator,
112+
result_nullability,
111113
)
112114
}),
113115
DType::Struct(..) | DType::List(..) | DType::FixedSizeList(..) => {
@@ -125,56 +127,29 @@ pub fn compare_canonical_array(
125127
}
126128
}
127129

128-
fn compare_to<T: PartialOrd + PartialEq + Debug>(
129-
values: impl Iterator<Item = Option<T>>,
130-
cmp_value: T,
131-
operator: Operator,
132-
) -> ArrayRef {
133-
BoolArray::from_iter(values.map(|val| {
134-
val.map(|v| match operator {
135-
Operator::Eq => v == cmp_value,
136-
Operator::NotEq => v != cmp_value,
137-
Operator::Gt => v > cmp_value,
138-
Operator::Gte => v >= cmp_value,
139-
Operator::Lt => v < cmp_value,
140-
Operator::Lte => v <= cmp_value,
141-
})
142-
}))
143-
.into_array()
144-
}
145-
146-
fn compare_native_ptype<T: NativePType>(
130+
fn compare_to<T: PartialOrd>(
147131
values: impl Iterator<Item = Option<T>>,
148132
cmp_value: T,
149133
operator: Operator,
134+
nullability: Nullability,
150135
) -> ArrayRef {
151-
BoolArray::from_iter(values.map(|val| {
152-
val.map(|v| match operator {
153-
Operator::Eq => v.is_eq(cmp_value),
154-
Operator::NotEq => !v.is_eq(cmp_value),
155-
Operator::Gt => v.is_gt(cmp_value),
156-
Operator::Gte => v.is_ge(cmp_value),
157-
Operator::Lt => v.is_lt(cmp_value),
158-
Operator::Lte => v.is_le(cmp_value),
159-
})
160-
}))
161-
.into_array()
162-
}
136+
let eval_fn = |v| match operator {
137+
Operator::Eq => v == cmp_value,
138+
Operator::NotEq => v != cmp_value,
139+
Operator::Gt => v > cmp_value,
140+
Operator::Gte => v >= cmp_value,
141+
Operator::Lt => v < cmp_value,
142+
Operator::Lte => v <= cmp_value,
143+
};
163144

164-
fn compare_native_decimal_type<D: NativeDecimalType>(
165-
values: impl Iterator<Item = Option<D>>,
166-
cmp_value: D,
167-
operator: Operator,
168-
) -> ArrayRef {
169-
BoolArray::from_iter(values.map(|val| {
170-
val.map(|v| match operator {
171-
Operator::Eq => v == cmp_value,
172-
Operator::NotEq => v != cmp_value,
173-
Operator::Gt => v > cmp_value,
174-
Operator::Gte => v >= cmp_value,
175-
Operator::Lt => v < cmp_value,
176-
Operator::Lte => v <= cmp_value,
177-
})
178-
}))
179-
.into_array()
145+
if !nullability.is_nullable() {
146+
BoolArray::from_iter(
147+
values
148+
.map(|val| val.vortex_expect("non nullable"))
149+
.map(eval_fn),
150+
)
151+
.into_array()
152+
} else {
153+
BoolArray::from_iter(values.map(|val| val.map(eval_fn))).into_array()
154+
}
180155
}

fuzz/src/array/mod.rs

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -180,11 +180,23 @@ impl<'a> Arbitrary<'a> for FuzzArrayAction {
180180
}
181181

182182
let indices = random_vec_in_range(u, 0, current_array.len() - 1)?;
183+
let nullable = indices.contains(&None);
184+
183185
current_array = take_canonical_array(&current_array, &indices).vortex_unwrap();
184-
let indices_array = PrimitiveArray::from_option_iter(
185-
indices.iter().map(|i| i.map(|i| i as u64)),
186-
)
187-
.into_array();
186+
let indices_array = if nullable {
187+
PrimitiveArray::from_option_iter(
188+
indices.iter().map(|i| i.map(|i| i as u64)),
189+
)
190+
.into_array()
191+
} else {
192+
PrimitiveArray::from_iter(
193+
indices
194+
.iter()
195+
.map(|i| i.vortex_expect("must be present"))
196+
.map(|i| i as u64),
197+
)
198+
.into_array()
199+
};
188200

189201
let compressed = BtrBlocksCompressor::default()
190202
.compress(&indices_array)

vortex-array/src/arrays/primitive/native_value.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
use std::cmp::Ordering;
5+
46
use vortex_dtype::{NativePType, half};
57

6-
/// NativeValue serves as a wrapper type to allow us to implement Hash and Eq on all primitive types.
8+
/// NativeValue serves as a wrapper type to allow us to implement Hash, Eq and other traits on all primitive types.
79
///
810
/// Rust does not define Hash/Eq for any of the float types due to the presence of
911
/// NaN and +/- 0. We don't care about storing multiple NaNs or zeros in our dictionaries,
@@ -30,6 +32,12 @@ macro_rules! prim_value {
3032
};
3133
}
3234

35+
impl<T: NativePType> PartialOrd<NativeValue<T>> for NativeValue<T> {
36+
fn partial_cmp(&self, other: &NativeValue<T>) -> Option<Ordering> {
37+
Some(self.0.total_compare(other.0))
38+
}
39+
}
40+
3341
macro_rules! float_value {
3442
($typ:ty) => {
3543
impl core::hash::Hash for NativeValue<$typ> {

0 commit comments

Comments
 (0)