Skip to content

Commit dae960a

Browse files
fix[encodings]: decimal-byte-parts-array compare when scalar is larger than any value in the array (#3650)
Signed-off-by: Joe Isaacs <[email protected]>
1 parent 41cb1b6 commit dae960a

File tree

1 file changed

+81
-23
lines changed
  • encodings/decimal-byte-parts/src/decimal_byte_parts/compute

1 file changed

+81
-23
lines changed

encodings/decimal-byte-parts/src/decimal_byte_parts/compute/compare.rs

Lines changed: 81 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1-
use num_traits::NumCast;
1+
use Sign::Negative;
2+
use num_traits::{Bounded, NumCast};
23
use vortex_array::arrays::ConstantArray;
34
use vortex_array::compute::{CompareKernel, CompareKernelAdapter, Operator, compare};
45
use vortex_array::{Array, ArrayRef, register_kernel};
56
use vortex_dtype::{NativePType, Nullability, PType, match_each_integer_ptype};
67
use vortex_error::{VortexExpect, VortexResult};
7-
use vortex_scalar::{DecimalValue, Scalar, ScalarValue, match_each_decimal_value};
8+
use vortex_scalar::{DecimalValue, Scalar, ScalarValue, ToPrimitive, match_each_decimal_value};
89

910
use crate::DecimalBytePartsVTable;
11+
use crate::decimal_byte_parts::compute::compare::Sign::Positive;
1012

1113
impl CompareKernel for DecimalBytePartsVTable {
1214
fn compare(
@@ -30,49 +32,77 @@ impl CompareKernel for DecimalBytePartsVTable {
3032
.as_decimal()
3133
.decimal_value()
3234
.vortex_expect("checked for null in entry func");
33-
let Some(encoded_scalar) =
34-
decimal_value_wrapper_to_primitive(rhs_decimal, lhs.msp.as_primitive_typed().ptype())
35-
.map(|value| Scalar::new(scalar_type.clone(), value))
36-
else {
35+
36+
match decimal_value_wrapper_to_primitive(rhs_decimal, lhs.msp.as_primitive_typed().ptype())
37+
.map(|value| Scalar::new(scalar_type.clone(), value))
38+
{
39+
Ok(encoded_scalar) => {
40+
let encoded_const = ConstantArray::new(encoded_scalar, rhs.len());
41+
compare(&lhs.msp, &encoded_const.to_array(), operator).map(Some)
42+
}
3743
// here the scalar value is bigger than the msp type.
3844
// TODO(joe): fixme, when allowing lsp values.
39-
return Ok(Some(
40-
ConstantArray::new(unconvertible_value(operator, nullability), lhs.len())
45+
Err(sign) => Ok(Some(
46+
ConstantArray::new(unconvertible_value(sign, operator, nullability), lhs.len())
4147
.to_array(),
42-
));
43-
};
44-
let encoded_const = ConstantArray::new(encoded_scalar, rhs.len());
45-
compare(&lhs.msp, &encoded_const.to_array(), operator).map(Some)
48+
)),
49+
}
4650
}
4751
}
4852

49-
fn unconvertible_value(operator: Operator, nullability: Nullability) -> Scalar {
50-
// v op unconvertible where unconvertible > v_max
53+
// Used to represent the overflow direction when trying to
54+
// convert into the scalar type.
55+
enum Sign {
56+
Positive,
57+
Negative,
58+
}
59+
60+
fn unconvertible_value(sign: Sign, operator: Operator, nullability: Nullability) -> Scalar {
5161
match operator {
52-
// v is never eq or gt/gte
53-
Operator::Eq | Operator::Gt | Operator::Gte => Scalar::bool(false, nullability),
54-
// v is always eq or gt/gte
55-
Operator::NotEq | Operator::Lt | Operator::Lte => Scalar::bool(true, nullability),
62+
Operator::Eq => Scalar::bool(false, nullability),
63+
Operator::NotEq => Scalar::bool(true, nullability),
64+
Operator::Gt | Operator::Gte => Scalar::bool(matches!(sign, Positive), nullability),
65+
Operator::Lt | Operator::Lte => Scalar::bool(matches!(sign, Negative), nullability),
5666
}
5767
}
5868

5969
// this value return None is the decimal scalar cannot be cast the ptype.
6070
fn decimal_value_wrapper_to_primitive(
6171
decimal_value: DecimalValue,
6272
ptype: PType,
63-
) -> Option<ScalarValue> {
73+
) -> Result<ScalarValue, Sign> {
6474
match_each_integer_ptype!(ptype, |P| {
6575
decimal_value_to_primitive::<P>(decimal_value)
6676
})
6777
}
6878

69-
fn decimal_value_to_primitive<P>(decimal_value: DecimalValue) -> Option<ScalarValue>
79+
fn decimal_value_to_primitive<P>(decimal_value: DecimalValue) -> Result<ScalarValue, Sign>
7080
where
71-
P: NativePType + NumCast,
81+
P: NativePType + NumCast + Bounded + ToPrimitive,
7282
ScalarValue: From<P>,
7383
{
7484
match_each_decimal_value!(decimal_value, |decimal_v| {
75-
Some(ScalarValue::from(<P as NumCast>::from(decimal_v)?))
85+
let Some(encoded) = <P as NumCast>::from(decimal_v) else {
86+
let decimal_i256 = decimal_v
87+
.to_i256()
88+
.vortex_expect("i256 is big enough for any DecimalValue");
89+
return if decimal_i256
90+
> P::max_value()
91+
.to_i256()
92+
.vortex_expect("i256 is big enough for any PType")
93+
{
94+
Err(Positive)
95+
} else {
96+
assert!(
97+
decimal_i256
98+
< P::min_value()
99+
.to_i256()
100+
.vortex_expect("i256 is big enough for any PType")
101+
);
102+
Err(Negative)
103+
};
104+
};
105+
Ok(ScalarValue::from(encoded))
76106
})
77107
}
78108

@@ -128,7 +158,10 @@ mod tests {
128158
.to_array();
129159
// This cannot be converted to a i32.
130160
let rhs = ConstantArray::new(
131-
Scalar::new(dtype, DecimalValue::I128(-9999999999999965304).into()),
161+
Scalar::new(
162+
dtype.clone(),
163+
DecimalValue::I128(-9999999999999965304).into(),
164+
),
132165
lhs.len(),
133166
);
134167

@@ -150,5 +183,30 @@ mod tests {
150183
res.to_bool().unwrap().bool_vec().unwrap(),
151184
vec![true, true, true]
152185
);
186+
187+
// This cannot be converted to a i32.
188+
let rhs = ConstantArray::new(
189+
Scalar::new(dtype, DecimalValue::I128(9999999999999965304).into()),
190+
lhs.len(),
191+
);
192+
193+
let res = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();
194+
195+
assert_eq!(
196+
res.to_bool().unwrap().bool_vec().unwrap(),
197+
vec![false, false, false]
198+
);
199+
200+
let res = compare(lhs.as_ref(), rhs.as_ref(), Operator::Gt).unwrap();
201+
assert_eq!(
202+
res.to_bool().unwrap().bool_vec().unwrap(),
203+
vec![true, true, true]
204+
);
205+
206+
let res = compare(lhs.as_ref(), rhs.as_ref(), Operator::Lt).unwrap();
207+
assert_eq!(
208+
res.to_bool().unwrap().bool_vec().unwrap(),
209+
vec![false, false, false]
210+
);
153211
}
154212
}

0 commit comments

Comments
 (0)