@@ -2,8 +2,8 @@ use num_traits::NumCast;
22use vortex_array:: arrays:: ConstantArray ;
33use vortex_array:: compute:: { CompareKernel , CompareKernelAdapter , Operator , compare} ;
44use vortex_array:: { Array , ArrayRef , register_kernel} ;
5- use vortex_dtype:: { NativePType , PType , match_each_integer_ptype} ;
6- use vortex_error:: VortexResult ;
5+ use vortex_dtype:: { NativePType , Nullability , PType , match_each_integer_ptype} ;
6+ use vortex_error:: { VortexExpect , VortexResult } ;
77use vortex_scalar:: { DecimalValue , Scalar , ScalarValue , match_each_decimal_value} ;
88
99use crate :: DecimalBytePartsVTable ;
@@ -23,25 +23,40 @@ impl CompareKernel for DecimalBytePartsVTable {
2323 return Ok ( None ) ;
2424 } ;
2525
26- let scalar_type = lhs
27- . msp
28- . dtype ( )
29- . with_nullability ( lhs. dtype . nullability ( ) | rhs. dtype ( ) . nullability ( ) ) ;
26+ let nullability = lhs. dtype . nullability ( ) | rhs. dtype ( ) . nullability ( ) ;
27+ let scalar_type = lhs. msp . dtype ( ) . with_nullability ( nullability) ;
3028
31- let encoded_scalar = rhs_const
29+ let rhs_decimal = rhs_const
3230 . as_decimal ( )
3331 . decimal_value ( )
34- . and_then ( |value| {
35- decimal_value_wrapper_to_primitive ( value, lhs. msp . as_primitive_typed ( ) . ptype ( ) )
36- } )
37- . map ( |value| Scalar :: new ( scalar_type. clone ( ) , value) )
38- . unwrap_or_else ( || Scalar :: null ( scalar_type) ) ;
32+ . vortex_expect ( "checked for null in entry func" ) ;
33+ let Some ( encoded_scalar) =
34+ decimal_value_wrapper_to_primitive ( rhs_decimal, lhs. msp . as_primitive_typed ( ) . ptype ( ) )
35+ . map ( |value| Scalar :: new ( scalar_type. clone ( ) , value) )
36+ else {
37+ // here the scalar value is bigger than the msp type.
38+ // TODO(joe): fixme, when allowing lsp values.
39+ return Ok ( Some (
40+ ConstantArray :: new ( unconvertible_value ( operator, nullability) , lhs. len ( ) )
41+ . to_array ( ) ,
42+ ) ) ;
43+ } ;
3944 let encoded_const = ConstantArray :: new ( encoded_scalar, rhs. len ( ) ) ;
4045 compare ( & lhs. msp , & encoded_const. to_array ( ) , operator) . map ( Some )
4146 }
4247}
4348
44- // clippy prefers smaller functions
49+ fn unconvertible_value ( operator : Operator , nullability : Nullability ) -> Scalar {
50+ // v op unconvertible where unconvertible > v_max
51+ match operator {
52+ // v is never eq or gt/gte
53+ Operator :: Eq | Operator :: Gt | Operator :: Gte => Scalar :: bool ( false , nullability) ,
54+ // v is always eq or gt/gte
55+ Operator :: NotEq | Operator :: Lt | Operator :: Lte => Scalar :: bool ( true , nullability) ,
56+ }
57+ }
58+
59+ // this value return None is the decimal scalar cannot be cast the ptype.
4560fn decimal_value_wrapper_to_primitive (
4661 decimal_value : DecimalValue ,
4762 ptype : PType ,
@@ -99,4 +114,41 @@ mod tests {
99114 vec![ false , false , true ]
100115 ) ;
101116 }
117+
118+ #[ test]
119+ fn compare_decimal_const_unconvertible_comparison ( ) {
120+ let decimal_dtype = DecimalDType :: new ( 40 , 2 ) ;
121+ let dtype = DType :: Decimal ( decimal_dtype, Nullability :: Nullable ) ;
122+ let lhs = DecimalBytePartsArray :: try_new (
123+ PrimitiveArray :: new ( buffer ! [ 100i32 , 200i32 , 400i32 ] , Validity :: AllValid ) . to_array ( ) ,
124+ vec ! [ ] ,
125+ decimal_dtype,
126+ )
127+ . unwrap ( )
128+ . to_array ( ) ;
129+ // This cannot be converted to a i32.
130+ let rhs = ConstantArray :: new (
131+ Scalar :: new ( dtype, DecimalValue :: I128 ( -9999999999999965304 ) . into ( ) ) ,
132+ lhs. len ( ) ,
133+ ) ;
134+
135+ let res = compare ( lhs. as_ref ( ) , rhs. as_ref ( ) , Operator :: Eq ) . unwrap ( ) ;
136+
137+ assert_eq ! (
138+ res. to_bool( ) . unwrap( ) . bool_vec( ) . unwrap( ) ,
139+ vec![ false , false , false ]
140+ ) ;
141+
142+ let res = compare ( lhs. as_ref ( ) , rhs. as_ref ( ) , Operator :: Gt ) . unwrap ( ) ;
143+ assert_eq ! (
144+ res. to_bool( ) . unwrap( ) . bool_vec( ) . unwrap( ) ,
145+ vec![ false , false , false ]
146+ ) ;
147+
148+ let res = compare ( lhs. as_ref ( ) , rhs. as_ref ( ) , Operator :: Lt ) . unwrap ( ) ;
149+ assert_eq ! (
150+ res. to_bool( ) . unwrap( ) . bool_vec( ) . unwrap( ) ,
151+ vec![ true , true , true ]
152+ ) ;
153+ }
102154}
0 commit comments