1+ use vortex_array:: array:: ConstantArray ;
12use vortex_array:: compute:: unary:: { scalar_at, scalar_at_unchecked, ScalarAtFn } ;
23use vortex_array:: compute:: {
34 compare, filter, slice, take, ArrayCompute , FilterFn , FilterMask , MaybeCompareFn , Operator ,
45 SliceFn , TakeFn , TakeOptions ,
56} ;
6- use vortex_array:: stats:: { ArrayStatistics , Stat } ;
77use vortex_array:: { ArrayData , IntoArrayData } ;
88use vortex_error:: { VortexExpect , VortexResult } ;
99use vortex_scalar:: Scalar ;
@@ -39,17 +39,16 @@ impl MaybeCompareFn for DictArray {
3939 operator : Operator ,
4040 ) -> Option < VortexResult < ArrayData > > {
4141 // If the RHS is constant, then we just need to compare against our encoded values.
42- if other
43- . statistics ( )
44- . get_as :: < bool > ( Stat :: IsConstant )
45- . unwrap_or_default ( )
46- {
42+ if let Some ( const_scalar) = other. as_constant ( ) {
4743 return Some (
4844 // Ensure the other is the same length as the dictionary
49- slice ( other, 0 , self . values ( ) . len ( ) )
50- . and_then ( |other| compare ( self . values ( ) , other, operator) )
51- . and_then ( |values| Self :: try_new ( self . codes ( ) , values) )
52- . map ( |a| a. into_array ( ) ) ,
45+ compare (
46+ self . values ( ) ,
47+ ConstantArray :: new ( const_scalar, self . values ( ) . len ( ) ) ,
48+ operator,
49+ )
50+ . and_then ( |values| Self :: try_new ( self . codes ( ) , values) )
51+ . map ( |a| a. into_array ( ) ) ,
5352 ) ;
5453 }
5554
@@ -102,14 +101,19 @@ impl SliceFn for DictArray {
102101#[ cfg( test) ]
103102mod test {
104103 use vortex_array:: accessor:: ArrayAccessor ;
105- use vortex_array:: array:: { PrimitiveArray , VarBinViewArray } ;
104+ use vortex_array:: array:: { ConstantArray , PrimitiveArray , VarBinViewArray } ;
105+ use vortex_array:: compute:: unary:: scalar_at;
106+ use vortex_array:: compute:: { compare, slice, Operator } ;
106107 use vortex_array:: { IntoArrayData , IntoArrayVariant , ToArrayData } ;
107108 use vortex_dtype:: { DType , Nullability } ;
109+ use vortex_scalar:: Scalar ;
108110
109- use crate :: { dict_encode_typed_primitive, dict_encode_varbinview, DictArray } ;
111+ use crate :: {
112+ dict_encode_primitive, dict_encode_typed_primitive, dict_encode_varbinview, DictArray ,
113+ } ;
110114
111115 #[ test]
112- fn flatten_nullable_primitive ( ) {
116+ fn canonicalise_nullable_primitive ( ) {
113117 let reference = PrimitiveArray :: from_nullable_vec ( vec ! [
114118 Some ( 42 ) ,
115119 Some ( -9 ) ,
@@ -125,7 +129,7 @@ mod test {
125129 }
126130
127131 #[ test]
128- fn flatten_nullable_varbin ( ) {
132+ fn canonicalise_nullable_varbin ( ) {
129133 let reference = VarBinViewArray :: from_iter (
130134 vec ! [ Some ( "a" ) , Some ( "b" ) , None , Some ( "a" ) , None , Some ( "b" ) ] ,
131135 DType :: Utf8 ( Nullability :: Nullable ) ,
@@ -147,4 +151,32 @@ mod test {
147151 . unwrap( ) ,
148152 ) ;
149153 }
154+
155+ #[ test]
156+ fn compare_sliced_dict ( ) {
157+ let reference = PrimitiveArray :: from_nullable_vec ( vec ! [
158+ Some ( 42 ) ,
159+ Some ( -9 ) ,
160+ None ,
161+ Some ( 42 ) ,
162+ Some ( 1 ) ,
163+ Some ( 5 ) ,
164+ ] ) ;
165+ let ( codes, values) = dict_encode_primitive ( & reference) ;
166+ let dict = DictArray :: try_new ( codes. into_array ( ) , values. into_array ( ) ) . unwrap ( ) ;
167+ let sliced = slice ( dict, 1 , 4 ) . unwrap ( ) ;
168+ let compared = compare ( sliced, ConstantArray :: new ( 42 , 3 ) , Operator :: Eq ) . unwrap ( ) ;
169+ assert_eq ! (
170+ scalar_at( & compared, 0 ) . unwrap( ) ,
171+ Scalar :: bool ( false , Nullability :: Nullable )
172+ ) ;
173+ assert_eq ! (
174+ scalar_at( & compared, 1 ) . unwrap( ) ,
175+ Scalar :: null( DType :: Bool ( Nullability :: Nullable ) )
176+ ) ;
177+ assert_eq ! (
178+ scalar_at( compared, 2 ) . unwrap( ) ,
179+ Scalar :: bool ( true , Nullability :: Nullable )
180+ ) ;
181+ }
150182}
0 commit comments