Skip to content

Commit fdccca6

Browse files
bug[decimal-byte-parts]: unconvertible scalar in compare (#3578)
Signed-off-by: Joe Isaacs <[email protected]>
1 parent 655d318 commit fdccca6

File tree

2 files changed

+68
-13
lines changed

2 files changed

+68
-13
lines changed

encodings/decimal-byte-parts/Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,6 @@ vortex-dtype = { workspace = true }
2626
vortex-error = { workspace = true }
2727
vortex-mask = { workspace = true }
2828
vortex-scalar = { workspace = true }
29+
30+
[dev-dependencies]
31+
vortex-array = { path = "../../vortex-array", features = ["test-harness"] }

encodings/decimal-byte-parts/src/decimal_byte_parts/compute/compare.rs

Lines changed: 65 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ use num_traits::NumCast;
22
use vortex_array::arrays::ConstantArray;
33
use vortex_array::compute::{CompareKernel, CompareKernelAdapter, Operator, compare};
44
use vortex_array::{Array, ArrayRef, register_kernel};
5-
use vortex_dtype::{NativePType, PType, match_each_integer_ptype};
6-
use vortex_error::VortexResult;
5+
use vortex_dtype::{NativePType, Nullability, PType, match_each_integer_ptype};
6+
use vortex_error::{VortexExpect, VortexResult};
77
use vortex_scalar::{DecimalValue, Scalar, ScalarValue, match_each_decimal_value};
88

99
use crate::DecimalBytePartsVTable;
@@ -23,25 +23,40 @@ impl CompareKernel for DecimalBytePartsVTable {
2323
return Ok(None);
2424
};
2525

26-
let scalar_type = lhs
27-
.msp
28-
.dtype()
29-
.with_nullability(lhs.dtype.nullability() | rhs.dtype().nullability());
26+
let nullability = lhs.dtype.nullability() | rhs.dtype().nullability();
27+
let scalar_type = lhs.msp.dtype().with_nullability(nullability);
3028

31-
let encoded_scalar = rhs_const
29+
let rhs_decimal = rhs_const
3230
.as_decimal()
3331
.decimal_value()
34-
.and_then(|value| {
35-
decimal_value_wrapper_to_primitive(value, lhs.msp.as_primitive_typed().ptype())
36-
})
37-
.map(|value| Scalar::new(scalar_type.clone(), value))
38-
.unwrap_or_else(|| Scalar::null(scalar_type));
32+
.vortex_expect("checked for null in entry func");
33+
let Some(encoded_scalar) =
34+
decimal_value_wrapper_to_primitive(rhs_decimal, lhs.msp.as_primitive_typed().ptype())
35+
.map(|value| Scalar::new(scalar_type.clone(), value))
36+
else {
37+
// here the scalar value is bigger than the msp type.
38+
// TODO(joe): fixme, when allowing lsp values.
39+
return Ok(Some(
40+
ConstantArray::new(unconvertible_value(operator, nullability), lhs.len())
41+
.to_array(),
42+
));
43+
};
3944
let encoded_const = ConstantArray::new(encoded_scalar, rhs.len());
4045
compare(&lhs.msp, &encoded_const.to_array(), operator).map(Some)
4146
}
4247
}
4348

44-
// clippy prefers smaller functions
49+
fn unconvertible_value(operator: Operator, nullability: Nullability) -> Scalar {
50+
// v op unconvertible where unconvertible > v_max
51+
match operator {
52+
// v is never eq or gt/gte
53+
Operator::Eq | Operator::Gt | Operator::Gte => Scalar::bool(false, nullability),
54+
// v is always eq or gt/gte
55+
Operator::NotEq | Operator::Lt | Operator::Lte => Scalar::bool(true, nullability),
56+
}
57+
}
58+
59+
// this value return None is the decimal scalar cannot be cast the ptype.
4560
fn decimal_value_wrapper_to_primitive(
4661
decimal_value: DecimalValue,
4762
ptype: PType,
@@ -99,4 +114,41 @@ mod tests {
99114
vec![false, false, true]
100115
);
101116
}
117+
118+
#[test]
119+
fn compare_decimal_const_unconvertible_comparison() {
120+
let decimal_dtype = DecimalDType::new(40, 2);
121+
let dtype = DType::Decimal(decimal_dtype, Nullability::Nullable);
122+
let lhs = DecimalBytePartsArray::try_new(
123+
PrimitiveArray::new(buffer![100i32, 200i32, 400i32], Validity::AllValid).to_array(),
124+
vec![],
125+
decimal_dtype,
126+
)
127+
.unwrap()
128+
.to_array();
129+
// This cannot be converted to a i32.
130+
let rhs = ConstantArray::new(
131+
Scalar::new(dtype, DecimalValue::I128(-9999999999999965304).into()),
132+
lhs.len(),
133+
);
134+
135+
let res = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();
136+
137+
assert_eq!(
138+
res.to_bool().unwrap().bool_vec().unwrap(),
139+
vec![false, false, false]
140+
);
141+
142+
let res = compare(lhs.as_ref(), rhs.as_ref(), Operator::Gt).unwrap();
143+
assert_eq!(
144+
res.to_bool().unwrap().bool_vec().unwrap(),
145+
vec![false, false, false]
146+
);
147+
148+
let res = compare(lhs.as_ref(), rhs.as_ref(), Operator::Lt).unwrap();
149+
assert_eq!(
150+
res.to_bool().unwrap().bool_vec().unwrap(),
151+
vec![true, true, true]
152+
);
153+
}
102154
}

0 commit comments

Comments
 (0)