Skip to content

Commit 4a85b18

Browse files
committed
chore: Fix fuzzers baseline nullability for compare and take
Signed-off-by: Robert Kruszewski <[email protected]>
1 parent ca770ed commit 4a85b18

File tree

2 files changed: +58 -48 lines changed

fuzz/src/array/compare.rs

Lines changed: 42 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,6 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
use std::fmt::Debug;
5-
use std::ops::Deref;
65

76
use vortex_array::accessor::ArrayAccessor;
87
use vortex_array::arrays::BoolArray;
@@ -11,7 +10,7 @@ use vortex_array::validity::Validity;
1110
use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
1211
use vortex_buffer::BitBuffer;
1312
use vortex_dtype::{
14-
DType, NativeDecimalType, NativePType, match_each_decimal_value_type, match_each_native_ptype,
13+
DType, NativePType, Nullability, match_each_decimal_value_type, match_each_native_ptype,
1514
};
1615
use vortex_error::{VortexExpect, VortexResult, vortex_err};
1716
use vortex_scalar::Scalar;
@@ -29,6 +28,8 @@ pub fn compare_canonical_array(
2928
.into_array());
3029
}
3130

31+
let result_nullability = array.dtype().nullability() | value.dtype().nullability();
32+
3233
match array.dtype() {
3334
DType::Bool(_) => {
3435
let bool = value
@@ -44,6 +45,7 @@ pub fn compare_canonical_array(
4445
.map(|(b, v)| v.then_some(b)),
4546
bool,
4647
operator,
48+
result_nullability,
4749
))
4850
}
4951
DType::Primitive(p, _) => {
@@ -62,6 +64,7 @@ pub fn compare_canonical_array(
6264
.map(|(b, v)| v.then_some(b)),
6365
pval,
6466
operator,
67+
result_nullability,
6568
))
6669
})
6770
}
@@ -75,14 +78,15 @@ pub fn compare_canonical_array(
7578
.cast::<D>()
7679
.ok_or_else(|| vortex_err!("todo: handle upcast of decimal array"))?;
7780
let buf = decimal_array.buffer::<D>();
78-
Ok(compare_native_decimal_type(
81+
Ok(compare_to(
7982
buf.as_slice()
8083
.iter()
8184
.copied()
8285
.zip(array.validity_mask().to_bit_buffer().iter())
8386
.map(|(b, v)| v.then_some(b)),
8487
dval,
8588
operator,
89+
result_nullability,
8690
))
8791
})
8892
}
@@ -93,8 +97,9 @@ pub fn compare_canonical_array(
9397
.vortex_expect("nulls handled before");
9498
compare_to(
9599
iter.map(|v| v.map(|b| unsafe { str::from_utf8_unchecked(b) })),
96-
utf8_value.deref(),
100+
&utf8_value,
97101
operator,
102+
result_nullability,
98103
)
99104
}),
100105
DType::Binary(_) => array.to_varbinview().with_iterator(|iter| {
@@ -106,8 +111,9 @@ pub fn compare_canonical_array(
106111
// Don't understand the lifetime problem here but identity map makes it go away
107112
#[allow(clippy::map_identity)]
108113
iter.map(|v| v),
109-
binary_value.deref(),
114+
&binary_value,
110115
operator,
116+
result_nullability,
111117
)
112118
}),
113119
DType::Struct(..) | DType::List(..) | DType::FixedSizeList(..) => {
@@ -125,56 +131,48 @@ pub fn compare_canonical_array(
125131
}
126132
}
127133

134+
#[allow(clippy::unwrap_used)]
128135
fn compare_to<T: PartialOrd + PartialEq + Debug>(
129136
values: impl Iterator<Item = Option<T>>,
130137
cmp_value: T,
131138
operator: Operator,
139+
nullability: Nullability,
132140
) -> ArrayRef {
133-
BoolArray::from_iter(values.map(|val| {
134-
val.map(|v| match operator {
135-
Operator::Eq => v == cmp_value,
136-
Operator::NotEq => v != cmp_value,
137-
Operator::Gt => v > cmp_value,
138-
Operator::Gte => v >= cmp_value,
139-
Operator::Lt => v < cmp_value,
140-
Operator::Lte => v <= cmp_value,
141-
})
142-
}))
143-
.into_array()
141+
let eval_fn = |v| match operator {
142+
Operator::Eq => v == cmp_value,
143+
Operator::NotEq => v != cmp_value,
144+
Operator::Gt => v > cmp_value,
145+
Operator::Gte => v >= cmp_value,
146+
Operator::Lt => v < cmp_value,
147+
Operator::Lte => v <= cmp_value,
148+
};
149+
150+
if !nullability.is_nullable() {
151+
BoolArray::from_iter(values.map(|val| val.unwrap()).map(eval_fn)).into_array()
152+
} else {
153+
BoolArray::from_iter(values.map(|val| val.map(eval_fn))).into_array()
154+
}
144155
}
145156

157+
#[allow(clippy::unwrap_used)]
146158
fn compare_native_ptype<T: NativePType>(
147159
values: impl Iterator<Item = Option<T>>,
148160
cmp_value: T,
149161
operator: Operator,
162+
nullability: Nullability,
150163
) -> ArrayRef {
151-
BoolArray::from_iter(values.map(|val| {
152-
val.map(|v| match operator {
153-
Operator::Eq => v.is_eq(cmp_value),
154-
Operator::NotEq => !v.is_eq(cmp_value),
155-
Operator::Gt => v.is_gt(cmp_value),
156-
Operator::Gte => v.is_ge(cmp_value),
157-
Operator::Lt => v.is_lt(cmp_value),
158-
Operator::Lte => v.is_le(cmp_value),
159-
})
160-
}))
161-
.into_array()
162-
}
164+
let eval_fn = |v: T| match operator {
165+
Operator::Eq => v.is_eq(cmp_value),
166+
Operator::NotEq => !v.is_eq(cmp_value),
167+
Operator::Gt => v.is_gt(cmp_value),
168+
Operator::Gte => v.is_ge(cmp_value),
169+
Operator::Lt => v.is_lt(cmp_value),
170+
Operator::Lte => v.is_le(cmp_value),
171+
};
163172

164-
fn compare_native_decimal_type<D: NativeDecimalType>(
165-
values: impl Iterator<Item = Option<D>>,
166-
cmp_value: D,
167-
operator: Operator,
168-
) -> ArrayRef {
169-
BoolArray::from_iter(values.map(|val| {
170-
val.map(|v| match operator {
171-
Operator::Eq => v == cmp_value,
172-
Operator::NotEq => v != cmp_value,
173-
Operator::Gt => v > cmp_value,
174-
Operator::Gte => v >= cmp_value,
175-
Operator::Lt => v < cmp_value,
176-
Operator::Lte => v <= cmp_value,
177-
})
178-
}))
179-
.into_array()
173+
if !nullability.is_nullable() {
174+
BoolArray::from_iter(values.map(|val| val.unwrap()).map(eval_fn)).into_array()
175+
} else {
176+
BoolArray::from_iter(values.map(|val| val.map(eval_fn))).into_array()
177+
}
180178
}

fuzz/src/array/mod.rs

Lines changed: 16 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -180,11 +180,23 @@ impl<'a> Arbitrary<'a> for FuzzArrayAction {
180180
}
181181

182182
let indices = random_vec_in_range(u, 0, current_array.len() - 1)?;
183+
let nullable = indices.contains(&None);
184+
183185
current_array = take_canonical_array(&current_array, &indices).vortex_unwrap();
184-
let indices_array = PrimitiveArray::from_option_iter(
185-
indices.iter().map(|i| i.map(|i| i as u64)),
186-
)
187-
.into_array();
186+
let indices_array = if nullable {
187+
PrimitiveArray::from_option_iter(
188+
indices.iter().map(|i| i.map(|i| i as u64)),
189+
)
190+
.into_array()
191+
} else {
192+
PrimitiveArray::from_iter(
193+
indices
194+
.iter()
195+
.map(|i| i.vortex_expect("must be present"))
196+
.map(|i| i as u64),
197+
)
198+
.into_array()
199+
};
188200

189201
let compressed = BtrBlocksCompressor::default()
190202
.compress(&indices_array)

0 commit comments

Comments
 (0)