@@ -3,6 +3,7 @@ use std::fmt::{Display, Formatter};
33
44use arrow_buffer:: BooleanBuffer ;
55use arrow_ord:: cmp;
6+ use arrow_schema:: DataType ;
67use vortex_dtype:: { DType , NativePType , Nullability } ;
78use vortex_error:: { VortexExpect , VortexResult , vortex_bail} ;
89use vortex_scalar:: Scalar ;
@@ -194,8 +195,8 @@ fn arrow_compare(
194195 operator : Operator ,
195196) -> VortexResult < ArrayRef > {
196197 let nullable = left. dtype ( ) . is_nullable ( ) || right. dtype ( ) . is_nullable ( ) ;
197- let lhs = Datum :: try_new ( left. to_array ( ) ) ?;
198- let rhs = Datum :: try_new ( right. to_array ( ) ) ?;
198+ let lhs = datum_for_cmp ( left) ?;
199+ let rhs = datum_for_cmp ( right) ?;
199200
200201 let array = match operator {
201202 Operator :: Eq => cmp:: eq ( & lhs, & rhs) ?,
@@ -250,14 +251,26 @@ pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: Operator) -> Scalar {
250251 }
251252}
252253
254+ // Make sure both of the arrays end up with the same arrow data type
255+ fn datum_for_cmp ( array : & dyn Array ) -> VortexResult < Datum > {
256+ if matches ! ( array. dtype( ) , DType :: Utf8 ( _) ) {
257+ Datum :: with_target_datatype ( array, & DataType :: Utf8View )
258+ } else if matches ! ( array. dtype( ) , DType :: Binary ( _) ) {
259+ Datum :: with_target_datatype ( array, & DataType :: BinaryView )
260+ } else {
261+ Datum :: try_new ( array)
262+ }
263+ }
264+
253265#[ cfg( test) ]
254266mod tests {
255267 use arrow_buffer:: BooleanBuffer ;
256268 use itertools:: Itertools ;
269+ use rstest:: rstest;
257270
258271 use super :: * ;
259272 use crate :: ToCanonical ;
260- use crate :: arrays:: { BoolArray , ConstantArray } ;
273+ use crate :: arrays:: { BoolArray , ConstantArray , VarBinArray , VarBinViewArray } ;
261274 use crate :: validity:: Validity ;
262275
263276 fn to_int_indices ( indices_bits : BoolArray ) -> Vec < u64 > {
@@ -337,7 +350,7 @@ mod tests {
337350 assert_eq ! ( compare. len( ) , 10 ) ;
338351 }
339352
340- #[ rstest:: rstest ]
353+ #[ rstest]
341354 #[ case( Operator :: Eq , vec![ false , false , false , true ] ) ]
342355 #[ case( Operator :: NotEq , vec![ true , true , true , false ] ) ]
343356 #[ case( Operator :: Gt , vec![ true , true , true , false ] ) ]
@@ -350,4 +363,17 @@ mod tests {
350363 let output = compare_lengths_to_empty ( lengths. iter ( ) . copied ( ) , op) ;
351364 assert_eq ! ( Vec :: from_iter( output. iter( ) ) , expected) ;
352365 }
366+
367+ #[ rstest]
368+ #[ case( VarBinArray :: from( vec![ "a" , "b" ] ) . into_array( ) , VarBinViewArray :: from_iter_str( [ "a" , "b" ] ) . into_array( ) ) ]
369+ #[ case( VarBinViewArray :: from_iter_str( [ "a" , "b" ] ) . into_array( ) , VarBinArray :: from( vec![ "a" , "b" ] ) . into_array( ) ) ]
370+ #[ case( VarBinArray :: from( vec![ "a" . as_bytes( ) , "b" . as_bytes( ) ] ) . into_array( ) , VarBinViewArray :: from_iter_bin( [ "a" . as_bytes( ) , "b" . as_bytes( ) ] ) . into_array( ) ) ]
371+ #[ case( VarBinViewArray :: from_iter_bin( [ "a" . as_bytes( ) , "b" . as_bytes( ) ] ) . into_array( ) , VarBinArray :: from( vec![ "a" . as_bytes( ) , "b" . as_bytes( ) ] ) . into_array( ) ) ]
372+ fn arrow_compare_different_encodings ( #[ case] left : ArrayRef , #[ case] right : ArrayRef ) {
373+ let res = arrow_compare ( & left, & right, Operator :: Eq ) . unwrap ( ) ;
374+ assert_eq ! (
375+ res. to_bool( ) . unwrap( ) . boolean_buffer( ) . count_set_bits( ) ,
376+ left. len( )
377+ ) ;
378+ }
353379}
0 commit comments