Skip to content

Commit 736cee6

Browse files
authored
feat: Implement all of the compute fns for DecimalArray (#3116)
Follow on to #3058 In the previous PR, several compute functions were left unimplemented. This closes out and implements the remainder needed for full support of DecimalArray and Decimal types. All compute functions have some associated unit tests
1 parent 768e27c commit 736cee6

File tree

11 files changed

+840
-27
lines changed

11 files changed

+840
-27
lines changed
Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
use arrow_buffer::BooleanBuffer;
2+
use vortex_dtype::Nullability;
3+
use vortex_error::{VortexResult, vortex_bail};
4+
use vortex_scalar::{DecimalValue, i256};
5+
6+
use crate::arrays::decimal::serde::DecimalValueType;
7+
use crate::arrays::{BoolArray, DecimalArray, DecimalEncoding, NativeDecimalType};
8+
use crate::compute::{BetweenKernel, BetweenKernelAdapter, BetweenOptions, StrictComparison};
9+
use crate::{Array, ArrayRef, register_kernel};
10+
11+
impl BetweenKernel for DecimalEncoding {
12+
// Determine if the values are between the lower and upper bounds
13+
fn between(
14+
&self,
15+
arr: &DecimalArray,
16+
lower: &dyn Array,
17+
upper: &dyn Array,
18+
options: &BetweenOptions,
19+
) -> VortexResult<Option<ArrayRef>> {
20+
// NOTE: We know that the precision and scale were already checked to be equal by the main
21+
// `between` entrypoint function.
22+
23+
let (Some(lower), Some(upper)) = (lower.as_constant(), upper.as_constant()) else {
24+
return Ok(None);
25+
};
26+
27+
// NOTE: we know that have checked before that the lower and upper bounds are not all null.
28+
let nullability =
29+
arr.dtype.nullability() | lower.dtype().nullability() | upper.dtype().nullability();
30+
31+
match arr.values_type {
32+
DecimalValueType::I128 => {
33+
let Some(DecimalValue::I128(lower_128)) = *lower.as_decimal().decimal_value()
34+
else {
35+
vortex_bail!("invalid lower bound Scalar: {lower}");
36+
};
37+
let Some(DecimalValue::I128(upper_128)) = *upper.as_decimal().decimal_value()
38+
else {
39+
vortex_bail!("invalid upper bound Scalar: {upper}");
40+
};
41+
42+
let lower_op = match options.lower_strict {
43+
StrictComparison::Strict => i128_lt_i128,
44+
StrictComparison::NonStrict => i128_lte_i128,
45+
};
46+
47+
let upper_op = match options.upper_strict {
48+
StrictComparison::Strict => i128_lt_i128,
49+
StrictComparison::NonStrict => i128_lte_i128,
50+
};
51+
52+
Ok(Some(between_impl::<i128>(
53+
arr,
54+
lower_128,
55+
upper_128,
56+
nullability,
57+
lower_op,
58+
upper_op,
59+
)))
60+
}
61+
DecimalValueType::I256 => {
62+
let Some(DecimalValue::I256(lower_256)) = *lower.as_decimal().decimal_value()
63+
else {
64+
vortex_bail!("invalid lower bound Scalar: {lower}");
65+
};
66+
let Some(DecimalValue::I256(upper_256)) = *upper.as_decimal().decimal_value()
67+
else {
68+
vortex_bail!("invalid upper bound Scalar: {upper}");
69+
};
70+
71+
let lower_op = match options.lower_strict {
72+
StrictComparison::Strict => i256_lt_i256,
73+
StrictComparison::NonStrict => i256_lte_i256,
74+
};
75+
76+
let upper_op = match options.upper_strict {
77+
StrictComparison::Strict => i256_lt_i256,
78+
StrictComparison::NonStrict => i256_lte_i256,
79+
};
80+
81+
Ok(Some(between_impl::<i256>(
82+
arr,
83+
lower_256,
84+
upper_256,
85+
nullability,
86+
lower_op,
87+
upper_op,
88+
)))
89+
}
90+
}
91+
}
92+
}
93+
94+
register_kernel!(BetweenKernelAdapter(DecimalEncoding).lift());
95+
96+
fn between_impl<T: NativeDecimalType>(
97+
arr: &DecimalArray,
98+
lower: T,
99+
upper: T,
100+
nullability: Nullability,
101+
lower_op: fn(T, T) -> bool,
102+
upper_op: fn(T, T) -> bool,
103+
) -> ArrayRef {
104+
let buffer = arr.buffer::<T>();
105+
BoolArray::new(
106+
BooleanBuffer::collect_bool(buffer.len(), |idx| {
107+
let value = buffer[idx];
108+
lower_op(lower, value) & upper_op(value, upper)
109+
}),
110+
arr.validity().clone().union_nullability(nullability),
111+
)
112+
.into_array()
113+
}
114+
115+
#[inline]
116+
const fn i128_lt_i128(a: i128, b: i128) -> bool {
117+
a < b
118+
}
119+
120+
#[inline]
121+
const fn i128_lte_i128(a: i128, b: i128) -> bool {
122+
a <= b
123+
}
124+
125+
#[inline]
126+
fn i256_lt_i256(a: i256, b: i256) -> bool {
127+
a < b
128+
}
129+
130+
#[inline]
131+
fn i256_lte_i256(a: i256, b: i256) -> bool {
132+
a <= b
133+
}
134+
135+
#[cfg(test)]
136+
mod tests {
137+
use vortex_buffer::buffer;
138+
use vortex_dtype::{DecimalDType, Nullability};
139+
use vortex_scalar::{DecimalValue, Scalar};
140+
141+
use crate::Array;
142+
use crate::arrays::{ConstantArray, DecimalArray};
143+
use crate::compute::{BetweenOptions, StrictComparison, between};
144+
use crate::validity::Validity;
145+
146+
#[test]
147+
fn test_between() {
148+
let values = buffer![100i128, 200i128, 300i128, 400i128];
149+
let decimal_type = DecimalDType::new(3, 2);
150+
let array = DecimalArray::new(values, decimal_type, Validity::NonNullable);
151+
152+
let lower = ConstantArray::new(
153+
Scalar::decimal(
154+
DecimalValue::I128(100i128),
155+
decimal_type,
156+
Nullability::NonNullable,
157+
),
158+
array.len(),
159+
);
160+
let upper = ConstantArray::new(
161+
Scalar::decimal(
162+
DecimalValue::I128(400i128),
163+
decimal_type,
164+
Nullability::NonNullable,
165+
),
166+
array.len(),
167+
);
168+
169+
// Strict lower bound, non-strict upper bound
170+
let between_strict = between(
171+
&array,
172+
&lower,
173+
&upper,
174+
&BetweenOptions {
175+
lower_strict: StrictComparison::Strict,
176+
upper_strict: StrictComparison::NonStrict,
177+
},
178+
)
179+
.unwrap();
180+
assert_eq!(bool_to_vec(&between_strict), vec![false, true, true, true]);
181+
182+
// Non-strict lower bound, strict upper bound
183+
let between_strict = between(
184+
&array,
185+
&lower,
186+
&upper,
187+
&BetweenOptions {
188+
lower_strict: StrictComparison::NonStrict,
189+
upper_strict: StrictComparison::Strict,
190+
},
191+
)
192+
.unwrap();
193+
assert_eq!(bool_to_vec(&between_strict), vec![true, true, true, false]);
194+
}
195+
196+
fn bool_to_vec(array: &dyn Array) -> Vec<bool> {
197+
array
198+
.to_canonical()
199+
.unwrap()
200+
.into_bool()
201+
.unwrap()
202+
.boolean_buffer()
203+
.iter()
204+
.collect()
205+
}
206+
}
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
use vortex_error::VortexResult;
2+
use vortex_scalar::i256;
3+
4+
use crate::arrays::decimal::serde::DecimalValueType;
5+
use crate::arrays::{DecimalArray, DecimalEncoding, NativeDecimalType};
6+
use crate::compute::{IsConstantFn, IsConstantOpts};
7+
8+
impl IsConstantFn<&DecimalArray> for DecimalEncoding {
9+
fn is_constant(
10+
&self,
11+
array: &DecimalArray,
12+
_opts: &IsConstantOpts,
13+
) -> VortexResult<Option<bool>> {
14+
match array.values_type {
15+
DecimalValueType::I128 => Ok(Some(compute_is_constant(&array.buffer::<i128>()))),
16+
DecimalValueType::I256 => Ok(Some(compute_is_constant(&array.buffer::<i256>()))),
17+
}
18+
}
19+
}
20+
21+
fn compute_is_constant<T: NativeDecimalType>(values: &[T]) -> bool {
22+
// We know that the top-level `is_constant` ensures that the array is all_valid or non-null.
23+
let first_value = values[0];
24+
25+
for &value in &values[1..] {
26+
if value != first_value {
27+
return false;
28+
}
29+
}
30+
31+
true
32+
}
33+
34+
#[cfg(test)]
35+
mod tests {
36+
use vortex_buffer::buffer;
37+
use vortex_dtype::DecimalDType;
38+
39+
use crate::arrays::DecimalArray;
40+
use crate::compute::is_constant;
41+
use crate::validity::Validity;
42+
43+
#[test]
44+
fn test_is_constant() {
45+
let array = DecimalArray::new(
46+
buffer![0i128, 1i128, 2i128],
47+
DecimalDType::new(19, 0),
48+
Validity::NonNullable,
49+
);
50+
51+
assert!(!is_constant(&array).unwrap());
52+
53+
let array = DecimalArray::new(
54+
buffer![100i128, 100i128, 100i128],
55+
DecimalDType::new(19, 0),
56+
Validity::NonNullable,
57+
);
58+
59+
assert!(is_constant(&array).unwrap());
60+
}
61+
}
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
use vortex_error::VortexResult;
2+
use vortex_mask::Mask;
3+
use vortex_scalar::i256;
4+
5+
use crate::Array;
6+
use crate::arrays::decimal::serde::DecimalValueType;
7+
use crate::arrays::{DecimalArray, DecimalEncoding, NativeDecimalType};
8+
use crate::compute::{IsSortedFn, IsSortedIteratorExt};
9+
10+
impl IsSortedFn<&DecimalArray> for DecimalEncoding {
11+
fn is_sorted(&self, array: &DecimalArray) -> VortexResult<bool> {
12+
match array.values_type {
13+
DecimalValueType::I128 => compute_is_sorted::<i128>(array, false),
14+
DecimalValueType::I256 => compute_is_sorted::<i256>(array, false),
15+
}
16+
}
17+
18+
fn is_strict_sorted(&self, array: &DecimalArray) -> VortexResult<bool> {
19+
match array.values_type {
20+
DecimalValueType::I128 => compute_is_sorted::<i128>(array, true),
21+
DecimalValueType::I256 => compute_is_sorted::<i256>(array, true),
22+
}
23+
}
24+
}
25+
26+
fn compute_is_sorted<T: NativeDecimalType>(array: &DecimalArray, strict: bool) -> VortexResult<bool>
27+
where
28+
dyn Iterator<Item = T>: IsSortedIteratorExt,
29+
{
30+
match array.validity_mask()? {
31+
Mask::AllFalse(_) => Ok(!strict),
32+
Mask::AllTrue(_) => {
33+
let buf = array.buffer::<T>();
34+
let iter = buf.iter().copied();
35+
36+
Ok(if strict {
37+
IsSortedIteratorExt::is_strict_sorted(iter)
38+
} else {
39+
iter.is_sorted()
40+
})
41+
}
42+
Mask::Values(mask_values) => {
43+
let buf = array.buffer::<T>();
44+
45+
let iter = mask_values
46+
.boolean_buffer()
47+
.set_indices()
48+
.map(|idx| buf[idx]);
49+
50+
Ok(if strict {
51+
IsSortedIteratorExt::is_strict_sorted(iter)
52+
} else {
53+
iter.is_sorted()
54+
})
55+
}
56+
}
57+
}
58+
59+
#[cfg(test)]
60+
mod tests {
61+
use vortex_buffer::buffer;
62+
use vortex_dtype::DecimalDType;
63+
64+
use crate::arrays::DecimalArray;
65+
use crate::compute::{is_sorted, is_strict_sorted};
66+
use crate::validity::Validity;
67+
68+
#[test]
69+
fn test_is_sorted() {
70+
let sorted = buffer![100i128, 200i128, 200i128];
71+
let unsorted = buffer![200i128, 100i128, 200i128];
72+
73+
let dtype = DecimalDType::new(19, 2);
74+
75+
let sorted_array = DecimalArray::new(sorted, dtype, Validity::NonNullable);
76+
let unsorted_array = DecimalArray::new(unsorted, dtype, Validity::NonNullable);
77+
78+
assert!(is_sorted(&sorted_array).unwrap());
79+
assert!(!is_sorted(&unsorted_array).unwrap());
80+
}
81+
82+
#[test]
83+
fn test_is_strict_sorted() {
84+
let strict_sorted = buffer![100i128, 200i128, 300i128];
85+
let sorted = buffer![100i128, 200i128, 200i128];
86+
87+
let dtype = DecimalDType::new(19, 2);
88+
89+
let strict_sorted_array = DecimalArray::new(strict_sorted, dtype, Validity::NonNullable);
90+
let sorted_array = DecimalArray::new(sorted, dtype, Validity::NonNullable);
91+
92+
assert!(is_strict_sorted(&strict_sorted_array).unwrap());
93+
assert!(!is_strict_sorted(&sorted_array).unwrap());
94+
}
95+
}

0 commit comments

Comments
 (0)