22// SPDX-FileCopyrightText: Copyright the Vortex contributors
33
44use std:: fmt:: Debug ;
5- use std:: ops:: Deref ;
65
76use vortex_array:: accessor:: ArrayAccessor ;
87use vortex_array:: arrays:: BoolArray ;
@@ -11,7 +10,7 @@ use vortex_array::validity::Validity;
1110use vortex_array:: { Array , ArrayRef , IntoArray , ToCanonical } ;
1211use vortex_buffer:: BitBuffer ;
1312use vortex_dtype:: {
14- DType , NativeDecimalType , NativePType , match_each_decimal_value_type, match_each_native_ptype,
13+ DType , NativePType , Nullability , match_each_decimal_value_type, match_each_native_ptype,
1514} ;
1615use vortex_error:: { VortexExpect , VortexResult , vortex_err} ;
1716use vortex_scalar:: Scalar ;
@@ -29,6 +28,8 @@ pub fn compare_canonical_array(
2928 . into_array ( ) ) ;
3029 }
3130
31+ let result_nullability = array. dtype ( ) . nullability ( ) | value. dtype ( ) . nullability ( ) ;
32+
3233 match array. dtype ( ) {
3334 DType :: Bool ( _) => {
3435 let bool = value
@@ -44,6 +45,7 @@ pub fn compare_canonical_array(
4445 . map ( |( b, v) | v. then_some ( b) ) ,
4546 bool,
4647 operator,
48+ result_nullability,
4749 ) )
4850 }
4951 DType :: Primitive ( p, _) => {
@@ -62,6 +64,7 @@ pub fn compare_canonical_array(
6264 . map( |( b, v) | v. then_some( b) ) ,
6365 pval,
6466 operator,
67+ result_nullability,
6568 ) )
6669 } )
6770 }
@@ -75,14 +78,15 @@ pub fn compare_canonical_array(
7578 . cast:: <D >( )
7679 . ok_or_else( || vortex_err!( "todo: handle upcast of decimal array" ) ) ?;
7780 let buf = decimal_array. buffer:: <D >( ) ;
78- Ok ( compare_native_decimal_type (
81+ Ok ( compare_to (
7982 buf. as_slice( )
8083 . iter( )
8184 . copied( )
8285 . zip( array. validity_mask( ) . to_bit_buffer( ) . iter( ) )
8386 . map( |( b, v) | v. then_some( b) ) ,
8487 dval,
8588 operator,
89+ result_nullability,
8690 ) )
8791 } )
8892 }
@@ -93,8 +97,9 @@ pub fn compare_canonical_array(
9397 . vortex_expect ( "nulls handled before" ) ;
9498 compare_to (
9599 iter. map ( |v| v. map ( |b| unsafe { str:: from_utf8_unchecked ( b) } ) ) ,
96- utf8_value. deref ( ) ,
100+ & utf8_value,
97101 operator,
102+ result_nullability,
98103 )
99104 } ) ,
100105 DType :: Binary ( _) => array. to_varbinview ( ) . with_iterator ( |iter| {
@@ -106,8 +111,9 @@ pub fn compare_canonical_array(
106111 // Don't understand the lifetime problem here but identity map makes it go away
107112 #[ allow( clippy:: map_identity) ]
108113 iter. map ( |v| v) ,
109- binary_value. deref ( ) ,
114+ & binary_value,
110115 operator,
116+ result_nullability,
111117 )
112118 } ) ,
113119 DType :: Struct ( ..) | DType :: List ( ..) | DType :: FixedSizeList ( ..) => {
@@ -125,56 +131,48 @@ pub fn compare_canonical_array(
125131 }
126132}
127133
134+ #[ allow( clippy:: unwrap_used) ]
128135fn compare_to < T : PartialOrd + PartialEq + Debug > (
129136 values : impl Iterator < Item = Option < T > > ,
130137 cmp_value : T ,
131138 operator : Operator ,
139+ nullability : Nullability ,
132140) -> ArrayRef {
133- BoolArray :: from_iter ( values. map ( |val| {
134- val. map ( |v| match operator {
135- Operator :: Eq => v == cmp_value,
136- Operator :: NotEq => v != cmp_value,
137- Operator :: Gt => v > cmp_value,
138- Operator :: Gte => v >= cmp_value,
139- Operator :: Lt => v < cmp_value,
140- Operator :: Lte => v <= cmp_value,
141- } )
142- } ) )
143- . into_array ( )
141+ let eval_fn = |v| match operator {
142+ Operator :: Eq => v == cmp_value,
143+ Operator :: NotEq => v != cmp_value,
144+ Operator :: Gt => v > cmp_value,
145+ Operator :: Gte => v >= cmp_value,
146+ Operator :: Lt => v < cmp_value,
147+ Operator :: Lte => v <= cmp_value,
148+ } ;
149+
150+ if !nullability. is_nullable ( ) {
151+ BoolArray :: from_iter ( values. map ( |val| val. unwrap ( ) ) . map ( eval_fn) ) . into_array ( )
152+ } else {
153+ BoolArray :: from_iter ( values. map ( |val| val. map ( eval_fn) ) ) . into_array ( )
154+ }
144155}
145156
157+ #[ allow( clippy:: unwrap_used) ]
146158fn compare_native_ptype < T : NativePType > (
147159 values : impl Iterator < Item = Option < T > > ,
148160 cmp_value : T ,
149161 operator : Operator ,
162+ nullability : Nullability ,
150163) -> ArrayRef {
151- BoolArray :: from_iter ( values. map ( |val| {
152- val. map ( |v| match operator {
153- Operator :: Eq => v. is_eq ( cmp_value) ,
154- Operator :: NotEq => !v. is_eq ( cmp_value) ,
155- Operator :: Gt => v. is_gt ( cmp_value) ,
156- Operator :: Gte => v. is_ge ( cmp_value) ,
157- Operator :: Lt => v. is_lt ( cmp_value) ,
158- Operator :: Lte => v. is_le ( cmp_value) ,
159- } )
160- } ) )
161- . into_array ( )
162- }
164+ let eval_fn = |v : T | match operator {
165+ Operator :: Eq => v. is_eq ( cmp_value) ,
166+ Operator :: NotEq => !v. is_eq ( cmp_value) ,
167+ Operator :: Gt => v. is_gt ( cmp_value) ,
168+ Operator :: Gte => v. is_ge ( cmp_value) ,
169+ Operator :: Lt => v. is_lt ( cmp_value) ,
170+ Operator :: Lte => v. is_le ( cmp_value) ,
171+ } ;
163172
164- fn compare_native_decimal_type < D : NativeDecimalType > (
165- values : impl Iterator < Item = Option < D > > ,
166- cmp_value : D ,
167- operator : Operator ,
168- ) -> ArrayRef {
169- BoolArray :: from_iter ( values. map ( |val| {
170- val. map ( |v| match operator {
171- Operator :: Eq => v == cmp_value,
172- Operator :: NotEq => v != cmp_value,
173- Operator :: Gt => v > cmp_value,
174- Operator :: Gte => v >= cmp_value,
175- Operator :: Lt => v < cmp_value,
176- Operator :: Lte => v <= cmp_value,
177- } )
178- } ) )
179- . into_array ( )
173+ if !nullability. is_nullable ( ) {
174+ BoolArray :: from_iter ( values. map ( |val| val. unwrap ( ) ) . map ( eval_fn) ) . into_array ( )
175+ } else {
176+ BoolArray :: from_iter ( values. map ( |val| val. map ( eval_fn) ) ) . into_array ( )
177+ }
180178}
0 commit comments