Skip to content

Commit 5125ce4

Browse files
authored
Fix: casting Zstd array to non nullable requires decompression (#5258)
1 parent 1c7c759 commit 5125ce4

File tree

1 file changed

+61
-6
lines changed

1 file changed

+61
-6
lines changed

encodings/zstd/src/compute/cast.rs

Lines changed: 61 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@ use crate::{ZstdArray, ZstdVTable};
1010

1111
impl CastKernel for ZstdVTable {
1212
fn cast(&self, array: &ZstdArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
13+
if !dtype.is_nullable() || !array.all_valid() {
14+
// We cannot cast to non-nullable here: the validity, which may contain nulls, is used to decode
15+
// the ZSTD array, so this would require rewriting tables.
16+
return Ok(None);
17+
}
1318
// ZstdArray is a general-purpose compression encoding using Zstandard compression.
1419
// It can handle nullability changes without decompression by updating the validity
1520
// bitmap, but type changes require decompression since the compressed data is
@@ -48,6 +53,7 @@ mod tests {
4853
use vortex_array::arrays::PrimitiveArray;
4954
use vortex_array::compute::cast;
5055
use vortex_array::compute::conformance::cast::test_cast_conformance;
56+
use vortex_array::validity::Validity;
5157
use vortex_array::{ToCanonical, assert_arrays_eq};
5258
use vortex_buffer::Buffer;
5359
use vortex_dtype::{DType, Nullability, PType};
@@ -58,7 +64,7 @@ mod tests {
5864
fn test_cast_zstd_i32_to_i64() {
5965
let values = PrimitiveArray::new(
6066
Buffer::copy_from(vec![1i32, 2, 3, 4, 5]),
61-
vortex_array::validity::Validity::NonNullable,
67+
Validity::NonNullable,
6268
);
6369
let zstd = ZstdArray::from_primitive(&values, 0, 0).unwrap();
6470

@@ -80,7 +86,7 @@ mod tests {
8086
fn test_cast_zstd_nullability_change() {
8187
let values = PrimitiveArray::new(
8288
Buffer::copy_from(vec![10u32, 20, 30, 40]),
83-
vortex_array::validity::Validity::NonNullable,
89+
Validity::NonNullable,
8490
);
8591
let zstd = ZstdArray::from_primitive(&values, 0, 0).unwrap();
8692

@@ -95,22 +101,71 @@ mod tests {
95101
);
96102
}
97103

104+
#[test]
105+
fn test_cast_sliced_zstd_nullable_to_nonnullable() {
106+
let values = PrimitiveArray::new(
107+
Buffer::copy_from(vec![10u32, 20, 30, 40, 50, 60]),
108+
Validity::from_iter([true, true, true, true, true, true]),
109+
);
110+
let zstd = ZstdArray::from_primitive(&values, 0, 128).unwrap();
111+
let sliced = zstd.slice(1..5);
112+
let casted = cast(
113+
sliced.as_ref(),
114+
&DType::Primitive(PType::U32, Nullability::NonNullable),
115+
)
116+
.unwrap();
117+
assert_eq!(
118+
casted.dtype(),
119+
&DType::Primitive(PType::U32, Nullability::NonNullable)
120+
);
121+
// Verify the values are correct
122+
let decoded = casted.to_primitive();
123+
let u32_values = decoded.as_slice::<u32>();
124+
assert_eq!(u32_values, &[20, 30, 40, 50]);
125+
}
126+
127+
#[test]
128+
fn test_cast_sliced_zstd_part_valid_to_nonnullable() {
129+
let values = PrimitiveArray::from_option_iter([
130+
None,
131+
Some(20u32),
132+
Some(30),
133+
Some(40),
134+
Some(50),
135+
Some(60),
136+
]);
137+
let zstd = ZstdArray::from_primitive(&values, 0, 128).unwrap();
138+
let sliced = zstd.slice(1..5);
139+
let casted = cast(
140+
sliced.as_ref(),
141+
&DType::Primitive(PType::U32, Nullability::NonNullable),
142+
)
143+
.unwrap();
144+
assert_eq!(
145+
casted.dtype(),
146+
&DType::Primitive(PType::U32, Nullability::NonNullable)
147+
);
148+
let decoded = casted.to_primitive();
149+
let expected = PrimitiveArray::from_iter([20u32, 30, 40, 50]);
150+
assert_arrays_eq!(decoded, expected);
151+
}
152+
98153
#[rstest]
99154
#[case::i32(PrimitiveArray::new(
100155
Buffer::copy_from(vec![100i32, 200, 300, 400, 500]),
101-
vortex_array::validity::Validity::NonNullable,
156+
Validity::NonNullable,
102157
))]
103158
#[case::f64(PrimitiveArray::new(
104159
Buffer::copy_from(vec![1.1f64, 2.2, 3.3, 4.4, 5.5]),
105-
vortex_array::validity::Validity::NonNullable,
160+
Validity::NonNullable,
106161
))]
107162
#[case::single(PrimitiveArray::new(
108163
Buffer::copy_from(vec![42i64]),
109-
vortex_array::validity::Validity::NonNullable,
164+
Validity::NonNullable,
110165
))]
111166
#[case::large(PrimitiveArray::new(
112167
Buffer::copy_from((0..1000).map(|i| i as u32).collect::<Vec<_>>()),
113-
vortex_array::validity::Validity::NonNullable,
168+
Validity::NonNullable,
114169
))]
115170
fn test_cast_zstd_conformance(#[case] values: PrimitiveArray) {
116171
let zstd = ZstdArray::from_primitive(&values, 0, 0).unwrap();

0 commit comments

Comments
 (0)