11// SPDX-License-Identifier: Apache-2.0
22// SPDX-FileCopyrightText: Copyright the Vortex contributors
33
4- use std:: fmt:: Debug ;
5- use std:: ops:: Deref ;
6-
74use vortex_array:: accessor:: ArrayAccessor ;
8- use vortex_array:: arrays:: BoolArray ;
5+ use vortex_array:: arrays:: { BoolArray , NativeValue } ;
96use vortex_array:: compute:: { Operator , scalar_cmp} ;
107use vortex_array:: validity:: Validity ;
118use vortex_array:: { Array , ArrayRef , IntoArray , ToCanonical } ;
129use vortex_buffer:: BitBuffer ;
13- use vortex_dtype:: {
14- DType , NativeDecimalType , NativePType , match_each_decimal_value_type, match_each_native_ptype,
15- } ;
10+ use vortex_dtype:: { DType , Nullability , match_each_decimal_value_type, match_each_native_ptype} ;
1611use vortex_error:: { VortexExpect , VortexResult , vortex_err} ;
1712use vortex_scalar:: Scalar ;
1813
@@ -29,6 +24,8 @@ pub fn compare_canonical_array(
2924 . into_array ( ) ) ;
3025 }
3126
27+ let result_nullability = array. dtype ( ) . nullability ( ) | value. dtype ( ) . nullability ( ) ;
28+
3229 match array. dtype ( ) {
3330 DType :: Bool ( _) => {
3431 let bool = value
@@ -44,6 +41,7 @@ pub fn compare_canonical_array(
4441 . map ( |( b, v) | v. then_some ( b) ) ,
4542 bool,
4643 operator,
44+ result_nullability,
4745 ) )
4846 }
4947 DType :: Primitive ( p, _) => {
@@ -53,15 +51,16 @@ pub fn compare_canonical_array(
5351 let pval = primitive
5452 . typed_value:: <P >( )
5553 . vortex_expect( "nulls handled before" ) ;
56- Ok ( compare_native_ptype (
54+ Ok ( compare_to (
5755 primitive_array
5856 . as_slice:: <P >( )
5957 . iter( )
6058 . copied( )
6159 . zip( array. validity_mask( ) . to_bit_buffer( ) . iter( ) )
62- . map( |( b, v) | v. then_some( b ) ) ,
63- pval,
60+ . map( |( b, v) | v. then_some( NativeValue ( b ) ) ) ,
61+ NativeValue ( pval) ,
6462 operator,
63+ result_nullability,
6564 ) )
6665 } )
6766 }
@@ -75,14 +74,15 @@ pub fn compare_canonical_array(
7574 . cast:: <D >( )
7675 . ok_or_else( || vortex_err!( "todo: handle upcast of decimal array" ) ) ?;
7776 let buf = decimal_array. buffer:: <D >( ) ;
78- Ok ( compare_native_decimal_type (
77+ Ok ( compare_to (
7978 buf. as_slice( )
8079 . iter( )
8180 . copied( )
8281 . zip( array. validity_mask( ) . to_bit_buffer( ) . iter( ) )
8382 . map( |( b, v) | v. then_some( b) ) ,
8483 dval,
8584 operator,
85+ result_nullability,
8686 ) )
8787 } )
8888 }
@@ -93,8 +93,9 @@ pub fn compare_canonical_array(
9393 . vortex_expect ( "nulls handled before" ) ;
9494 compare_to (
9595 iter. map ( |v| v. map ( |b| unsafe { str:: from_utf8_unchecked ( b) } ) ) ,
96- utf8_value. deref ( ) ,
96+ & utf8_value,
9797 operator,
98+ result_nullability,
9899 )
99100 } ) ,
100101 DType :: Binary ( _) => array. to_varbinview ( ) . with_iterator ( |iter| {
@@ -106,8 +107,9 @@ pub fn compare_canonical_array(
106107 // Don't understand the lifetime problem here but identity map makes it go away
107108 #[ allow( clippy:: map_identity) ]
108109 iter. map ( |v| v) ,
109- binary_value. deref ( ) ,
110+ & binary_value,
110111 operator,
112+ result_nullability,
111113 )
112114 } ) ,
113115 DType :: Struct ( ..) | DType :: List ( ..) | DType :: FixedSizeList ( ..) => {
@@ -125,56 +127,29 @@ pub fn compare_canonical_array(
125127 }
126128}
127129
128- fn compare_to < T : PartialOrd + PartialEq + Debug > (
129- values : impl Iterator < Item = Option < T > > ,
130- cmp_value : T ,
131- operator : Operator ,
132- ) -> ArrayRef {
133- BoolArray :: from_iter ( values. map ( |val| {
134- val. map ( |v| match operator {
135- Operator :: Eq => v == cmp_value,
136- Operator :: NotEq => v != cmp_value,
137- Operator :: Gt => v > cmp_value,
138- Operator :: Gte => v >= cmp_value,
139- Operator :: Lt => v < cmp_value,
140- Operator :: Lte => v <= cmp_value,
141- } )
142- } ) )
143- . into_array ( )
144- }
145-
146- fn compare_native_ptype < T : NativePType > (
130+ fn compare_to < T : PartialOrd > (
147131 values : impl Iterator < Item = Option < T > > ,
148132 cmp_value : T ,
149133 operator : Operator ,
134+ nullability : Nullability ,
150135) -> ArrayRef {
151- BoolArray :: from_iter ( values. map ( |val| {
152- val. map ( |v| match operator {
153- Operator :: Eq => v. is_eq ( cmp_value) ,
154- Operator :: NotEq => !v. is_eq ( cmp_value) ,
155- Operator :: Gt => v. is_gt ( cmp_value) ,
156- Operator :: Gte => v. is_ge ( cmp_value) ,
157- Operator :: Lt => v. is_lt ( cmp_value) ,
158- Operator :: Lte => v. is_le ( cmp_value) ,
159- } )
160- } ) )
161- . into_array ( )
162- }
136+ let eval_fn = |v| match operator {
137+ Operator :: Eq => v == cmp_value,
138+ Operator :: NotEq => v != cmp_value,
139+ Operator :: Gt => v > cmp_value,
140+ Operator :: Gte => v >= cmp_value,
141+ Operator :: Lt => v < cmp_value,
142+ Operator :: Lte => v <= cmp_value,
143+ } ;
163144
164- fn compare_native_decimal_type < D : NativeDecimalType > (
165- values : impl Iterator < Item = Option < D > > ,
166- cmp_value : D ,
167- operator : Operator ,
168- ) -> ArrayRef {
169- BoolArray :: from_iter ( values. map ( |val| {
170- val. map ( |v| match operator {
171- Operator :: Eq => v == cmp_value,
172- Operator :: NotEq => v != cmp_value,
173- Operator :: Gt => v > cmp_value,
174- Operator :: Gte => v >= cmp_value,
175- Operator :: Lt => v < cmp_value,
176- Operator :: Lte => v <= cmp_value,
177- } )
178- } ) )
179- . into_array ( )
145+ if !nullability. is_nullable ( ) {
146+ BoolArray :: from_iter (
147+ values
148+ . map ( |val| val. vortex_expect ( "non nullable" ) )
149+ . map ( eval_fn) ,
150+ )
151+ . into_array ( )
152+ } else {
153+ BoolArray :: from_iter ( values. map ( |val| val. map ( eval_fn) ) ) . into_array ( )
154+ }
180155}
0 commit comments