@@ -11,7 +11,7 @@ use vortex_array::{
1111 encoding_ids, impl_encoding, Array , Canonical , IntoArray , IntoArrayVariant , IntoCanonical ,
1212 SerdeMetadata ,
1313} ;
14- use vortex_dtype:: { match_each_integer_ptype, DType , Nullability , PType } ;
14+ use vortex_dtype:: { match_each_integer_ptype, DType , PType } ;
1515use vortex_error:: { vortex_bail, VortexExpect as _, VortexResult } ;
1616use vortex_mask:: { AllOr , Mask } ;
1717
@@ -22,85 +22,41 @@ impl_encoding!(
2222 SerdeMetadata <DictMetadata >
2323) ;
2424
25- #[ derive(
26- Copy ,
27- Clone ,
28- Debug ,
29- Serialize ,
30- Deserialize ,
31- rkyv:: Archive ,
32- rkyv:: Portable ,
33- rkyv:: Serialize ,
34- rkyv:: Deserialize ,
35- rkyv:: bytecheck:: CheckBytes ,
36- ) ]
37- #[ rkyv( as = DictNullability ) ]
38- #[ bytecheck( crate = rkyv:: bytecheck) ]
39- #[ repr( u8 ) ]
40- enum DictNullability {
41- NonNullable ,
42- NullableCodes ,
43- NullableValues ,
44- BothNullable ,
45- }
46-
47- impl DictNullability {
48- fn from_dtypes ( codes_dtype : & DType , values_dtype : & DType ) -> Self {
49- match ( codes_dtype. is_nullable ( ) , values_dtype. is_nullable ( ) ) {
50- ( true , true ) => Self :: BothNullable ,
51- ( true , false ) => Self :: NullableCodes ,
52- ( false , true ) => Self :: NullableValues ,
53- ( false , false ) => Self :: NonNullable ,
54- }
55- }
56-
57- fn codes_nullability ( & self ) -> Nullability {
58- match self {
59- DictNullability :: NonNullable => Nullability :: NonNullable ,
60- DictNullability :: NullableCodes => Nullability :: Nullable ,
61- DictNullability :: NullableValues => Nullability :: NonNullable ,
62- DictNullability :: BothNullable => Nullability :: Nullable ,
63- }
64- }
65-
66- fn values_nullability ( & self ) -> Nullability {
67- match self {
68- DictNullability :: NonNullable => Nullability :: NonNullable ,
69- DictNullability :: NullableCodes => Nullability :: NonNullable ,
70- DictNullability :: NullableValues => Nullability :: Nullable ,
71- DictNullability :: BothNullable => Nullability :: Nullable ,
72- }
73- }
74- }
75-
7625#[ derive( Debug , Clone , Serialize , Deserialize ) ]
7726pub struct DictMetadata {
7827 codes_ptype : PType ,
7928 values_len : usize , // TODO(ngates): make this a u32
80- dict_nullability : DictNullability ,
8129}
8230
8331impl DictArray {
84- pub fn try_new ( codes : Array , values : Array ) -> VortexResult < Self > {
32+ pub fn try_new ( mut codes : Array , values : Array ) -> VortexResult < Self > {
8533 if !codes. dtype ( ) . is_unsigned_int ( ) {
8634 vortex_bail ! ( MismatchedTypes : "unsigned int" , codes. dtype( ) ) ;
8735 }
8836
89- let dtype = if codes. dtype ( ) . is_nullable ( ) {
90- values. dtype ( ) . as_nullable ( )
37+ let dtype = values. dtype ( ) ;
38+ if dtype. is_nullable ( ) {
39+ // If the values are nullable, we force codes to be nullable as well.
40+ codes = try_cast ( & codes, & codes. dtype ( ) . as_nullable ( ) ) ?;
9141 } else {
92- values. dtype ( ) . clone ( )
93- } ;
94- let dict_nullability = DictNullability :: from_dtypes ( codes. dtype ( ) , values. dtype ( ) ) ;
42+ // If the values are non-nullable, we assert the codes are non-nullable as well.
43+ if codes. dtype ( ) . is_nullable ( ) {
44+ vortex_bail ! ( "Cannot have nullable codes for non-nullable dict array" ) ;
45+ }
46+ }
47+ assert_eq ! (
48+ codes. dtype( ) . nullability( ) ,
49+ values. dtype( ) . nullability( ) ,
50+ "Mismatched nullability between codes and values"
51+ ) ;
9552
9653 Self :: try_from_parts (
97- dtype,
54+ dtype. clone ( ) ,
9855 codes. len ( ) ,
9956 SerdeMetadata ( DictMetadata {
10057 codes_ptype : PType :: try_from ( codes. dtype ( ) )
10158 . vortex_expect ( "codes dtype must be uint" ) ,
10259 values_len : values. len ( ) ,
103- dict_nullability,
10460 } ) ,
10561 None ,
10662 Some ( [ codes, values] . into ( ) ) ,
@@ -113,10 +69,7 @@ impl DictArray {
11369 self . as_ref ( )
11470 . child (
11571 0 ,
116- & DType :: Primitive (
117- self . metadata ( ) . codes_ptype ,
118- self . metadata ( ) . dict_nullability . codes_nullability ( ) ,
119- ) ,
72+ & DType :: Primitive ( self . metadata ( ) . codes_ptype , self . dtype ( ) . nullability ( ) ) ,
12073 self . len ( ) ,
12174 )
12275 . vortex_expect ( "DictArray is missing its codes child array" )
@@ -125,13 +78,7 @@ impl DictArray {
12578 #[ inline]
12679 pub fn values ( & self ) -> Array {
12780 self . as_ref ( )
128- . child (
129- 1 ,
130- & self
131- . dtype ( )
132- . with_nullability ( self . metadata ( ) . dict_nullability . values_nullability ( ) ) ,
133- self . metadata ( ) . values_len ,
134- )
81+ . child ( 1 , self . dtype ( ) , self . metadata ( ) . values_len )
13582 . vortex_expect ( "DictArray is missing its values child array" )
13683 }
13784}
@@ -147,10 +94,10 @@ impl CanonicalVTable<DictArray> for DictEncoding {
14794 // copies of the view pointers.
14895 DType :: Utf8 ( _) | DType :: Binary ( _) => {
14996 let canonical_values: Array = array. values ( ) . into_canonical ( ) ?. into_array ( ) ;
150- try_cast ( take ( canonical_values, array. codes ( ) ) ? , array . dtype ( ) ) ?. into_canonical ( )
97+ take ( canonical_values, array. codes ( ) ) ?. into_canonical ( )
15198 }
15299 // Non-string case: take and then canonicalize
153- _ => try_cast ( take ( array. values ( ) , array. codes ( ) ) ? , array . dtype ( ) ) ?. into_canonical ( ) ,
100+ _ => take ( array. values ( ) , array. codes ( ) ) ?. into_canonical ( ) ,
154101 }
155102 }
156103}
@@ -182,6 +129,14 @@ impl ValidityVTable<DictArray> for DictEncoding {
182129 Ok ( array. codes ( ) . all_valid ( ) ? && array. values ( ) . all_valid ( ) ?)
183130 }
184131
132+ fn all_invalid ( & self , array : & DictArray ) -> VortexResult < bool > {
133+ if !array. dtype ( ) . is_nullable ( ) {
134+ return Ok ( false ) ;
135+ }
136+
137+ Ok ( array. codes ( ) . all_invalid ( ) ? || array. values ( ) . all_invalid ( ) ?)
138+ }
139+
185140 fn validity_mask ( & self , array : & DictArray ) -> VortexResult < Mask > {
186141 let codes_validity = array. codes ( ) . validity_mask ( ) ?;
187142 match codes_validity. boolean_buffer ( ) {
@@ -231,7 +186,6 @@ mod test {
231186 use vortex_error:: vortex_panic;
232187 use vortex_mask:: AllOr ;
233188
234- use crate :: array:: DictNullability :: BothNullable ;
235189 use crate :: { DictArray , DictMetadata } ;
236190
237191 #[ cfg_attr( miri, ignore) ]
@@ -242,7 +196,6 @@ mod test {
242196 SerdeMetadata ( DictMetadata {
243197 codes_ptype : PType :: U64 ,
244198 values_len : usize:: MAX ,
245- dict_nullability : BothNullable ,
246199 } ) ,
247200 ) ;
248201 }
@@ -255,7 +208,7 @@ mod test {
255208 Validity :: from ( BooleanBuffer :: from ( vec ! [ true , false , true , false , true ] ) ) ,
256209 )
257210 . into_array ( ) ,
258- buffer ! [ 3 , 6 , 9 ] . into_array ( ) ,
211+ PrimitiveArray :: new ( buffer ! [ 3 , 6 , 9 ] , Validity :: AllValid ) . into_array ( ) ,
259212 )
260213 . unwrap ( ) ;
261214 let mask = dict. validity_mask ( ) . unwrap ( ) ;
0 commit comments