Skip to content

Commit 776742a

Browse files
chore[fuzz]: added decimal array ops (#3571)
Would be nice to be able to be generic over a `NativePType` and a `NativeDecimalType` --------- Signed-off-by: Joe Isaacs <[email protected]>
1 parent 07ea496 commit 776742a

File tree

10 files changed

+200
-19
lines changed

10 files changed

+200
-19
lines changed

fuzz/src/compare.rs

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ use vortex_array::compute::{Operator, scalar_cmp};
88
use vortex_array::validity::Validity;
99
use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
1010
use vortex_dtype::{DType, NativePType, match_each_native_ptype};
11-
use vortex_error::{VortexExpect, VortexResult};
12-
use vortex_scalar::Scalar;
11+
use vortex_error::{VortexExpect, VortexResult, vortex_err};
12+
use vortex_scalar::{NativeDecimalType, Scalar, match_each_decimal_value_type};
1313

1414
pub fn compare_canonical_array(
1515
array: &dyn Array,
@@ -59,6 +59,27 @@ pub fn compare_canonical_array(
5959
))
6060
})
6161
}
62+
DType::Decimal(..) => {
63+
let decimal = value.as_decimal();
64+
let decimal_array = array.to_decimal()?;
65+
match_each_decimal_value_type!(decimal_array.values_type(), |D| {
66+
let dval = decimal
67+
.decimal_value()
68+
.vortex_expect("nulls handled before")
69+
.cast::<D>()
70+
.ok_or_else(|| vortex_err!("todo: handle upcast of decimal array"))?;
71+
let buf = decimal_array.buffer::<D>();
72+
Ok(compare_native_decimal_type(
73+
buf.as_slice()
74+
.iter()
75+
.copied()
76+
.zip(array.validity_mask()?.to_boolean_buffer().iter())
77+
.map(|(b, v)| v.then_some(b)),
78+
dval,
79+
operator,
80+
))
81+
})
82+
}
6283
DType::Utf8(_) => array.to_varbinview()?.with_iterator(|iter| {
6384
let utf8_value = value
6485
.as_utf8()
@@ -94,7 +115,9 @@ pub fn compare_canonical_array(
94115
)
95116
.into_array())
96117
}
97-
d => unreachable!("DType {d} not supported for fuzzing"),
118+
d @ (DType::Null | DType::Extension(_)) => {
119+
unreachable!("DType {d} not supported for fuzzing")
120+
}
98121
}
99122
}
100123

@@ -133,3 +156,21 @@ fn compare_native_ptype<T: NativePType>(
133156
}))
134157
.into_array()
135158
}
159+
160+
fn compare_native_decimal_type<D: NativeDecimalType>(
161+
values: impl Iterator<Item = Option<D>>,
162+
cmp_value: D,
163+
operator: Operator,
164+
) -> ArrayRef {
165+
BoolArray::from_iter(values.map(|val| {
166+
val.map(|v| match operator {
167+
Operator::Eq => v == cmp_value,
168+
Operator::NotEq => v != cmp_value,
169+
Operator::Gt => v > cmp_value,
170+
Operator::Gte => v >= cmp_value,
171+
Operator::Lt => v < cmp_value,
172+
Operator::Lte => v <= cmp_value,
173+
})
174+
}))
175+
.into_array()
176+
}

fuzz/src/filter.rs

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
use vortex_array::accessor::ArrayAccessor;
22
use vortex_array::arrays::{
3-
BoolArray, BooleanBuffer, PrimitiveArray, StructArray, VarBinViewArray,
3+
BoolArray, BooleanBuffer, DecimalArray, PrimitiveArray, StructArray, VarBinViewArray,
44
};
55
use vortex_array::validity::Validity;
66
use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
77
use vortex_buffer::Buffer;
88
use vortex_dtype::{DType, match_each_native_ptype};
99
use vortex_error::VortexResult;
10+
use vortex_scalar::match_each_decimal_value_type;
1011

1112
use crate::take::take_canonical_array;
1213

@@ -52,6 +53,23 @@ pub fn filter_canonical_array(array: &dyn Array, filter: &[bool]) -> VortexResul
5253
)
5354
.into_array())
5455
}),
56+
DType::Decimal(d, _) => {
57+
let decimal_array = array.to_decimal()?;
58+
match_each_decimal_value_type!(decimal_array.values_type(), |D| {
59+
let buf = decimal_array.buffer::<D>();
60+
Ok(DecimalArray::new(
61+
filter
62+
.iter()
63+
.zip(buf.as_slice().iter().copied())
64+
.filter(|(f, _)| **f)
65+
.map(|(_, v)| v)
66+
.collect::<Buffer<_>>(),
67+
*d,
68+
validity,
69+
)
70+
.into_array())
71+
})
72+
}
5573
DType::Utf8(_) | DType::Binary(_) => {
5674
let utf8 = array.to_varbinview()?;
5775
let values = utf8.with_iterator(|iter| {
@@ -87,6 +105,8 @@ pub fn filter_canonical_array(array: &dyn Array, filter: &[bool]) -> VortexResul
87105
}
88106
take_canonical_array(array, &indices)
89107
}
90-
d => unreachable!("DType {d} not supported for fuzzing"),
108+
d @ (DType::Null | DType::Extension(_)) => {
109+
unreachable!("DType {d} not supported for fuzzing")
110+
}
91111
}
92112
}

fuzz/src/search_sorted.rs

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ use vortex_array::search_sorted::{IndexOrd, SearchResult, SearchSorted, SearchSo
66
use vortex_array::{Array, ToCanonical};
77
use vortex_buffer::{BufferString, ByteBuffer};
88
use vortex_dtype::{DType, NativePType, match_each_native_ptype};
9-
use vortex_error::VortexResult;
10-
use vortex_scalar::Scalar;
9+
use vortex_error::{VortexResult, vortex_err};
10+
use vortex_scalar::{Scalar, match_each_decimal_value_type};
1111

1212
struct SearchNullableSlice<T>(Vec<Option<T>>);
1313

@@ -76,6 +76,31 @@ pub fn search_sorted_canonical_array(
7676
Ok(SearchPrimitiveSlice(opt_values).search_sorted(&Some(to_find), side))
7777
})
7878
}
79+
DType::Decimal(d, _) => {
80+
let decimal_array = array.to_decimal()?;
81+
let validity = decimal_array.validity_mask()?.to_boolean_buffer();
82+
match_each_decimal_value_type!(decimal_array.values_type(), |D| {
83+
let buf = decimal_array.buffer::<D>();
84+
let opt_values = buf
85+
.as_slice()
86+
.iter()
87+
.copied()
88+
.zip(validity.iter())
89+
.map(|(b, v)| v.then_some(b))
90+
.collect::<Vec<_>>();
91+
let to_find: D = scalar
92+
.as_decimal()
93+
.decimal_value()
94+
.map(|v| {
95+
v.cast::<D>().ok_or_else(|| {
96+
vortex_err!("cannot cast value {v} to decimal value type {d}")
97+
})
98+
})
99+
.transpose()?
100+
.ok_or_else(|| vortex_err!("unexpected null scalar"))?;
101+
Ok(SearchNullableSlice(opt_values).search_sorted(&Some(to_find), side))
102+
})
103+
}
79104
DType::Utf8(_) | DType::Binary(_) => {
80105
let utf8 = array.to_varbinview()?;
81106
let opt_values =
@@ -99,6 +124,8 @@ pub fn search_sorted_canonical_array(
99124
.collect::<VortexResult<Vec<_>>>()?;
100125
Ok(scalar_vals.search_sorted(&scalar.cast(array.dtype())?, side))
101126
}
102-
d => unreachable!("DType {d} not supported for fuzzing"),
127+
d @ (DType::Null | DType::Extension(_)) => {
128+
unreachable!("DType {d} not supported for fuzzing")
129+
}
103130
}
104131
}

fuzz/src/slice.rs

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
use vortex_array::accessor::ArrayAccessor;
2-
use vortex_array::arrays::{BoolArray, ListArray, PrimitiveArray, StructArray, VarBinViewArray};
2+
use vortex_array::arrays::{
3+
BoolArray, DecimalArray, ListArray, PrimitiveArray, StructArray, VarBinViewArray,
4+
};
35
use vortex_array::validity::Validity;
46
use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
57
use vortex_dtype::{
68
DType, NativePType, Nullability, match_each_integer_ptype, match_each_native_ptype,
79
};
810
use vortex_error::VortexResult;
11+
use vortex_scalar::match_each_decimal_value_type;
912

1013
#[allow(clippy::unnecessary_fallible_conversions)]
1114
pub fn slice_canonical_array(
@@ -80,7 +83,22 @@ pub fn slice_canonical_array(
8083
.into_array();
8184
ListArray::try_new(elements, offsets, validity).map(|a| a.into_array())
8285
}
83-
d => unreachable!("DType {d} not supported for fuzzing"),
86+
DType::Decimal(decimal_dtype, _) => {
87+
let decimal_array = array.to_decimal()?;
88+
Ok(
89+
match_each_decimal_value_type!(decimal_array.values_type(), |D| {
90+
DecimalArray::new(
91+
decimal_array.buffer::<D>().slice(start..stop),
92+
*decimal_dtype,
93+
validity,
94+
)
95+
})
96+
.to_array(),
97+
)
98+
}
99+
d @ (DType::Null | DType::Extension(_)) => {
100+
unreachable!("DType {d} not supported for fuzzing")
101+
}
84102
}
85103
}
86104

fuzz/src/sort.rs

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
use std::cmp::Ordering;
22

33
use vortex_array::accessor::ArrayAccessor;
4-
use vortex_array::arrays::{BoolArray, PrimitiveArray, VarBinViewArray};
4+
use vortex_array::arrays::{BoolArray, DecimalArray, PrimitiveArray, VarBinViewArray};
55
use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
66
use vortex_dtype::{DType, NativePType, match_each_native_ptype};
77
use vortex_error::{VortexExpect, VortexResult, VortexUnwrap};
8+
use vortex_scalar::match_each_decimal_value_type;
89

910
use crate::take::take_canonical_array;
1011

@@ -35,6 +36,21 @@ pub fn sort_canonical_array(array: &dyn Array) -> VortexResult<ArrayRef> {
3536
Ok(PrimitiveArray::from_option_iter(opt_values).into_array())
3637
})
3738
}
39+
DType::Decimal(d, _) => {
40+
let decimal_array = array.to_decimal()?;
41+
match_each_decimal_value_type!(decimal_array.values_type(), |D| {
42+
let buf = decimal_array.buffer::<D>();
43+
let mut opt_values = buf
44+
.as_slice()
45+
.iter()
46+
.copied()
47+
.zip(decimal_array.validity_mask()?.to_boolean_buffer().iter())
48+
.map(|(p, v)| v.then_some(p))
49+
.collect::<Vec<_>>();
50+
opt_values.sort();
51+
Ok(DecimalArray::from_option_iter(opt_values, *d).into_array())
52+
})
53+
}
3854
DType::Utf8(_) | DType::Binary(_) => {
3955
let utf8 = array.to_varbinview()?;
4056
let mut opt_values =
@@ -64,7 +80,9 @@ pub fn sort_canonical_array(array: &dyn Array) -> VortexResult<ArrayRef> {
6480
});
6581
take_canonical_array(array, &sort_indices)
6682
}
67-
d => unreachable!("DType {d} not supported for fuzzing"),
83+
d @ (DType::Null | DType::Extension(_)) => {
84+
unreachable!("DType {d} not supported for fuzzing")
85+
}
6886
}
6987
}
7088

fuzz/src/take.rs

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
use arrow_buffer::ArrowNativeType;
22
use vortex_array::accessor::ArrayAccessor;
3-
use vortex_array::arrays::{BoolArray, PrimitiveArray, StructArray, VarBinViewArray};
3+
use vortex_array::arrays::{BoolArray, DecimalArray, PrimitiveArray, StructArray, VarBinViewArray};
44
use vortex_array::builders::{ArrayBuilderExt, builder_with_capacity};
55
use vortex_array::validity::Validity;
66
use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
77
use vortex_buffer::Buffer;
8-
use vortex_dtype::{DType, NativePType, match_each_native_ptype};
8+
use vortex_dtype::{DType, DecimalDType, NativePType, match_each_native_ptype};
99
use vortex_error::VortexResult;
10+
use vortex_scalar::{NativeDecimalType, match_each_decimal_value_type};
1011

1112
pub fn take_canonical_array(array: &dyn Array, indices: &[usize]) -> VortexResult<ArrayRef> {
1213
let validity = if array.dtype().is_nullable() {
@@ -32,6 +33,13 @@ pub fn take_canonical_array(array: &dyn Array, indices: &[usize]) -> VortexResul
3233
Ok(take_primitive::<P>(primitive_array, validity, indices))
3334
})
3435
}
36+
DType::Decimal(d, _) => {
37+
let decimal_array = array.to_decimal()?;
38+
39+
match_each_decimal_value_type!(decimal_array.values_type(), |D| {
40+
Ok(take_decimal::<D>(decimal_array, d, validity, indices))
41+
})
42+
}
3543
DType::Utf8(_) | DType::Binary(_) => {
3644
let utf8 = array.to_varbinview()?;
3745
let values =
@@ -65,7 +73,9 @@ pub fn take_canonical_array(array: &dyn Array, indices: &[usize]) -> VortexResul
6573
}
6674
Ok(builder.finish())
6775
}
68-
d => unreachable!("DType {d} not supported for fuzzing"),
76+
d @ (DType::Null | DType::Extension(_)) => {
77+
unreachable!("DType {d} not supported for fuzzing")
78+
}
6979
}
7080
}
7181

@@ -84,3 +94,22 @@ fn take_primitive<T: NativePType + ArrowNativeType>(
8494
)
8595
.into_array()
8696
}
97+
98+
fn take_decimal<D: NativeDecimalType>(
99+
array: DecimalArray,
100+
decimal_type: &DecimalDType,
101+
validity: Validity,
102+
indices: &[usize],
103+
) -> ArrayRef {
104+
let buf = array.buffer::<D>();
105+
let vec_values = buf.as_slice();
106+
DecimalArray::new(
107+
indices
108+
.iter()
109+
.map(|i| vec_values[*i])
110+
.collect::<Buffer<D>>(),
111+
*decimal_type,
112+
validity,
113+
)
114+
.into_array()
115+
}

vortex-array/src/arrays/arbitrary.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ fn random_array(u: &mut Unstructured, dtype: &DType, len: Option<usize>) -> Resu
7070
PType::F64 => random_primitive::<f64>(u, *n, chunk_len),
7171
},
7272
DType::Decimal(decimal, n) => {
73-
let elem_len = u.int_in_range(0..=20)?;
73+
let elem_len = chunk_len.unwrap_or(u.int_in_range(0..=20)?);
7474
match_each_decimal_value_type!(smallest_storage_type(decimal), |DVT| {
7575
let mut builder =
7676
DecimalBuilder::new::<DVT>(decimal.precision(), decimal.scale(), *n);

vortex-array/src/arrays/decimal/mod.rs

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ mod macros;
33
mod ops;
44
mod serde;
55

6-
use vortex_buffer::{Buffer, ByteBuffer};
6+
use arrow_buffer::BooleanBufferBuilder;
7+
use vortex_buffer::{Buffer, BufferMut, ByteBuffer};
78
use vortex_dtype::{DType, DecimalDType};
89
use vortex_error::{VortexResult, vortex_panic};
910
use vortex_scalar::{DecimalValueType, NativeDecimalType};
@@ -130,6 +131,33 @@ impl DecimalArray {
130131
pub fn scale(&self) -> i8 {
131132
self.decimal_dtype().scale()
132133
}
134+
135+
pub fn from_option_iter<T: NativeDecimalType, I: IntoIterator<Item = Option<T>>>(
136+
iter: I,
137+
decimal_dtype: DecimalDType,
138+
) -> Self {
139+
let iter = iter.into_iter();
140+
let mut values = BufferMut::with_capacity(iter.size_hint().0);
141+
let mut validity = BooleanBufferBuilder::new(values.capacity());
142+
143+
for i in iter {
144+
match i {
145+
None => {
146+
validity.append(false);
147+
values.push(T::default());
148+
}
149+
Some(e) => {
150+
validity.append(true);
151+
values.push(e);
152+
}
153+
}
154+
}
155+
Self::new(
156+
values.freeze(),
157+
decimal_dtype,
158+
Validity::from(validity.finish()),
159+
)
160+
}
133161
}
134162

135163
impl ArrayVTable<DecimalVTable> for DecimalVTable {

vortex-array/src/arrays/decimal/ops.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use crate::{ArrayRef, IntoArray};
1010

1111
impl OperationsVTable<DecimalVTable> for DecimalVTable {
1212
fn slice(array: &DecimalArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
13-
match_each_decimal_value_type!(array.values_type, |D| {
13+
match_each_decimal_value_type!(array.values_type(), |D| {
1414
slice_typed(
1515
array.buffer::<D>(),
1616
start,

vortex-array/src/arrays/struct_/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ impl StructArray {
105105
if field.len() != length {
106106
vortex_bail!(
107107
"Expected all struct fields to have length {length}, found {}",
108-
field.len()
108+
fields.iter().map(|f| f.len()).format(","),
109109
);
110110
}
111111
}

0 commit comments

Comments
 (0)