1- use num_traits:: NumCast ;
1+ use Sign :: Negative ;
2+ use num_traits:: { Bounded , NumCast } ;
23use vortex_array:: arrays:: ConstantArray ;
34use vortex_array:: compute:: { CompareKernel , CompareKernelAdapter , Operator , compare} ;
45use vortex_array:: { Array , ArrayRef , register_kernel} ;
56use vortex_dtype:: { NativePType , Nullability , PType , match_each_integer_ptype} ;
67use vortex_error:: { VortexExpect , VortexResult } ;
7- use vortex_scalar:: { DecimalValue , Scalar , ScalarValue , match_each_decimal_value} ;
8+ use vortex_scalar:: { DecimalValue , Scalar , ScalarValue , ToPrimitive , match_each_decimal_value} ;
89
910use crate :: DecimalBytePartsVTable ;
11+ use crate :: decimal_byte_parts:: compute:: compare:: Sign :: Positive ;
1012
1113impl CompareKernel for DecimalBytePartsVTable {
1214 fn compare (
@@ -30,49 +32,77 @@ impl CompareKernel for DecimalBytePartsVTable {
3032 . as_decimal ( )
3133 . decimal_value ( )
3234 . vortex_expect ( "checked for null in entry func" ) ;
33- let Some ( encoded_scalar) =
34- decimal_value_wrapper_to_primitive ( rhs_decimal, lhs. msp . as_primitive_typed ( ) . ptype ( ) )
35- . map ( |value| Scalar :: new ( scalar_type. clone ( ) , value) )
36- else {
35+
36+ match decimal_value_wrapper_to_primitive ( rhs_decimal, lhs. msp . as_primitive_typed ( ) . ptype ( ) )
37+ . map ( |value| Scalar :: new ( scalar_type. clone ( ) , value) )
38+ {
39+ Ok ( encoded_scalar) => {
40+ let encoded_const = ConstantArray :: new ( encoded_scalar, rhs. len ( ) ) ;
41+ compare ( & lhs. msp , & encoded_const. to_array ( ) , operator) . map ( Some )
42+ }
3743 // here the scalar value is bigger than the msp type.
3844 // TODO(joe): fixme, when allowing lsp values.
39- return Ok ( Some (
40- ConstantArray :: new ( unconvertible_value ( operator, nullability) , lhs. len ( ) )
45+ Err ( sign ) => Ok ( Some (
46+ ConstantArray :: new ( unconvertible_value ( sign , operator, nullability) , lhs. len ( ) )
4147 . to_array ( ) ,
42- ) ) ;
43- } ;
44- let encoded_const = ConstantArray :: new ( encoded_scalar, rhs. len ( ) ) ;
45- compare ( & lhs. msp , & encoded_const. to_array ( ) , operator) . map ( Some )
48+ ) ) ,
49+ }
4650 }
4751}
4852
49- fn unconvertible_value ( operator : Operator , nullability : Nullability ) -> Scalar {
50- // v op unconvertible where unconvertible > v_max
53+ // Used to represent the overflow direction when trying to
54+ // convert into the scalar type.
55+ enum Sign {
56+ Positive ,
57+ Negative ,
58+ }
59+
60+ fn unconvertible_value ( sign : Sign , operator : Operator , nullability : Nullability ) -> Scalar {
5161 match operator {
52- // v is never eq or gt/gte
53- Operator :: Eq | Operator :: Gt | Operator :: Gte => Scalar :: bool ( false , nullability) ,
54- // v is always eq or gt/gte
55- Operator :: NotEq | Operator :: Lt | Operator :: Lte => Scalar :: bool ( true , nullability) ,
62+ Operator :: Eq => Scalar :: bool ( false , nullability ) ,
63+ Operator :: NotEq => Scalar :: bool ( true , nullability) ,
64+ Operator :: Gt | Operator :: Gte => Scalar :: bool ( matches ! ( sign , Positive ) , nullability ) ,
65+ Operator :: Lt | Operator :: Lte => Scalar :: bool ( matches ! ( sign , Negative ) , nullability) ,
5666 }
5767}
5868
5969// this value return None is the decimal scalar cannot be cast the ptype.
6070fn decimal_value_wrapper_to_primitive (
6171 decimal_value : DecimalValue ,
6272 ptype : PType ,
63- ) -> Option < ScalarValue > {
73+ ) -> Result < ScalarValue , Sign > {
6474 match_each_integer_ptype ! ( ptype, |P | {
6575 decimal_value_to_primitive:: <P >( decimal_value)
6676 } )
6777}
6878
69- fn decimal_value_to_primitive < P > ( decimal_value : DecimalValue ) -> Option < ScalarValue >
79+ fn decimal_value_to_primitive < P > ( decimal_value : DecimalValue ) -> Result < ScalarValue , Sign >
7080where
71- P : NativePType + NumCast ,
81+ P : NativePType + NumCast + Bounded + ToPrimitive ,
7282 ScalarValue : From < P > ,
7383{
7484 match_each_decimal_value ! ( decimal_value, |decimal_v| {
75- Some ( ScalarValue :: from( <P as NumCast >:: from( decimal_v) ?) )
85+ let Some ( encoded) = <P as NumCast >:: from( decimal_v) else {
86+ let decimal_i256 = decimal_v
87+ . to_i256( )
88+ . vortex_expect( "i256 is big enough for any DecimalValue" ) ;
89+ return if decimal_i256
90+ > P :: max_value( )
91+ . to_i256( )
92+ . vortex_expect( "i256 is big enough for any PType" )
93+ {
94+ Err ( Positive )
95+ } else {
96+ assert!(
97+ decimal_i256
98+ < P :: min_value( )
99+ . to_i256( )
100+ . vortex_expect( "i256 is big enough for any PType" )
101+ ) ;
102+ Err ( Negative )
103+ } ;
104+ } ;
105+ Ok ( ScalarValue :: from( encoded) )
76106 } )
77107}
78108
@@ -128,7 +158,10 @@ mod tests {
128158 . to_array ( ) ;
129159 // This cannot be converted to a i32.
130160 let rhs = ConstantArray :: new (
131- Scalar :: new ( dtype, DecimalValue :: I128 ( -9999999999999965304 ) . into ( ) ) ,
161+ Scalar :: new (
162+ dtype. clone ( ) ,
163+ DecimalValue :: I128 ( -9999999999999965304 ) . into ( ) ,
164+ ) ,
132165 lhs. len ( ) ,
133166 ) ;
134167
@@ -150,5 +183,30 @@ mod tests {
150183 res. to_bool( ) . unwrap( ) . bool_vec( ) . unwrap( ) ,
151184 vec![ true , true , true ]
152185 ) ;
186+
187+ // This cannot be converted to a i32.
188+ let rhs = ConstantArray :: new (
189+ Scalar :: new ( dtype, DecimalValue :: I128 ( 9999999999999965304 ) . into ( ) ) ,
190+ lhs. len ( ) ,
191+ ) ;
192+
193+ let res = compare ( lhs. as_ref ( ) , rhs. as_ref ( ) , Operator :: Eq ) . unwrap ( ) ;
194+
195+ assert_eq ! (
196+ res. to_bool( ) . unwrap( ) . bool_vec( ) . unwrap( ) ,
197+ vec![ false , false , false ]
198+ ) ;
199+
200+ let res = compare ( lhs. as_ref ( ) , rhs. as_ref ( ) , Operator :: Gt ) . unwrap ( ) ;
201+ assert_eq ! (
202+ res. to_bool( ) . unwrap( ) . bool_vec( ) . unwrap( ) ,
203+ vec![ true , true , true ]
204+ ) ;
205+
206+ let res = compare ( lhs. as_ref ( ) , rhs. as_ref ( ) , Operator :: Lt ) . unwrap ( ) ;
207+ assert_eq ! (
208+ res. to_bool( ) . unwrap( ) . bool_vec( ) . unwrap( ) ,
209+ vec![ false , false , false ]
210+ ) ;
153211 }
154212}
0 commit comments