@@ -10,6 +10,11 @@ use crate::{ZstdArray, ZstdVTable};
1010
1111impl CastKernel for ZstdVTable {
1212 fn cast ( & self , array : & ZstdArray , dtype : & DType ) -> VortexResult < Option < ArrayRef > > {
13+ if !dtype. is_nullable ( ) || !array. all_valid ( ) {
14+ // We cannot cast to non-nullable since the validity containing nulls is used to decode
15+ // the ZSTD array, this would require rewriting tables.
16+ return Ok ( None ) ;
17+ }
1318 // ZstdArray is a general-purpose compression encoding using Zstandard compression.
1419 // It can handle nullability changes without decompression by updating the validity
1520 // bitmap, but type changes require decompression since the compressed data is
@@ -48,6 +53,7 @@ mod tests {
4853 use vortex_array:: arrays:: PrimitiveArray ;
4954 use vortex_array:: compute:: cast;
5055 use vortex_array:: compute:: conformance:: cast:: test_cast_conformance;
56+ use vortex_array:: validity:: Validity ;
5157 use vortex_array:: { ToCanonical , assert_arrays_eq} ;
5258 use vortex_buffer:: Buffer ;
5359 use vortex_dtype:: { DType , Nullability , PType } ;
@@ -58,7 +64,7 @@ mod tests {
5864 fn test_cast_zstd_i32_to_i64 ( ) {
5965 let values = PrimitiveArray :: new (
6066 Buffer :: copy_from ( vec ! [ 1i32 , 2 , 3 , 4 , 5 ] ) ,
61- vortex_array :: validity :: Validity :: NonNullable ,
67+ Validity :: NonNullable ,
6268 ) ;
6369 let zstd = ZstdArray :: from_primitive ( & values, 0 , 0 ) . unwrap ( ) ;
6470
@@ -80,7 +86,7 @@ mod tests {
8086 fn test_cast_zstd_nullability_change ( ) {
8187 let values = PrimitiveArray :: new (
8288 Buffer :: copy_from ( vec ! [ 10u32 , 20 , 30 , 40 ] ) ,
83- vortex_array :: validity :: Validity :: NonNullable ,
89+ Validity :: NonNullable ,
8490 ) ;
8591 let zstd = ZstdArray :: from_primitive ( & values, 0 , 0 ) . unwrap ( ) ;
8692
@@ -95,22 +101,71 @@ mod tests {
95101 ) ;
96102 }
97103
104+ #[ test]
105+ fn test_cast_sliced_zstd_nullable_to_nonnullable ( ) {
106+ let values = PrimitiveArray :: new (
107+ Buffer :: copy_from ( vec ! [ 10u32 , 20 , 30 , 40 , 50 , 60 ] ) ,
108+ Validity :: from_iter ( [ true , true , true , true , true , true ] ) ,
109+ ) ;
110+ let zstd = ZstdArray :: from_primitive ( & values, 0 , 128 ) . unwrap ( ) ;
111+ let sliced = zstd. slice ( 1 ..5 ) ;
112+ let casted = cast (
113+ sliced. as_ref ( ) ,
114+ & DType :: Primitive ( PType :: U32 , Nullability :: NonNullable ) ,
115+ )
116+ . unwrap ( ) ;
117+ assert_eq ! (
118+ casted. dtype( ) ,
119+ & DType :: Primitive ( PType :: U32 , Nullability :: NonNullable )
120+ ) ;
121+ // Verify the values are correct
122+ let decoded = casted. to_primitive ( ) ;
123+ let u32_values = decoded. as_slice :: < u32 > ( ) ;
124+ assert_eq ! ( u32_values, & [ 20 , 30 , 40 , 50 ] ) ;
125+ }
126+
127+ #[ test]
128+ fn test_cast_sliced_zstd_part_valid_to_nonnullable ( ) {
129+ let values = PrimitiveArray :: from_option_iter ( [
130+ None ,
131+ Some ( 20u32 ) ,
132+ Some ( 30 ) ,
133+ Some ( 40 ) ,
134+ Some ( 50 ) ,
135+ Some ( 60 ) ,
136+ ] ) ;
137+ let zstd = ZstdArray :: from_primitive ( & values, 0 , 128 ) . unwrap ( ) ;
138+ let sliced = zstd. slice ( 1 ..5 ) ;
139+ let casted = cast (
140+ sliced. as_ref ( ) ,
141+ & DType :: Primitive ( PType :: U32 , Nullability :: NonNullable ) ,
142+ )
143+ . unwrap ( ) ;
144+ assert_eq ! (
145+ casted. dtype( ) ,
146+ & DType :: Primitive ( PType :: U32 , Nullability :: NonNullable )
147+ ) ;
148+ let decoded = casted. to_primitive ( ) ;
149+ let expected = PrimitiveArray :: from_iter ( [ 20u32 , 30 , 40 , 50 ] ) ;
150+ assert_arrays_eq ! ( decoded, expected) ;
151+ }
152+
98153 #[ rstest]
99154 #[ case:: i32( PrimitiveArray :: new(
100155 Buffer :: copy_from( vec![ 100i32 , 200 , 300 , 400 , 500 ] ) ,
101- vortex_array :: validity :: Validity :: NonNullable ,
156+ Validity :: NonNullable ,
102157 ) ) ]
103158 #[ case:: f64( PrimitiveArray :: new(
104159 Buffer :: copy_from( vec![ 1.1f64 , 2.2 , 3.3 , 4.4 , 5.5 ] ) ,
105- vortex_array :: validity :: Validity :: NonNullable ,
160+ Validity :: NonNullable ,
106161 ) ) ]
107162 #[ case:: single( PrimitiveArray :: new(
108163 Buffer :: copy_from( vec![ 42i64 ] ) ,
109- vortex_array :: validity :: Validity :: NonNullable ,
164+ Validity :: NonNullable ,
110165 ) ) ]
111166 #[ case:: large( PrimitiveArray :: new(
112167 Buffer :: copy_from( ( 0 ..1000 ) . map( |i| i as u32 ) . collect:: <Vec <_>>( ) ) ,
113- vortex_array :: validity :: Validity :: NonNullable ,
168+ Validity :: NonNullable ,
114169 ) ) ]
115170 fn test_cast_zstd_conformance ( #[ case] values : PrimitiveArray ) {
116171 let zstd = ZstdArray :: from_primitive ( & values, 0 , 0 ) . unwrap ( ) ;
0 commit comments