|
2 | 2 | // SPDX-FileCopyrightText: Copyright the Vortex contributors |
3 | 3 |
|
4 | 4 | use vortex_array::compute::{CastKernel, CastKernelAdapter}; |
5 | | -use vortex_array::{ArrayRef, IntoArray, register_kernel}; |
6 | | -use vortex_dtype::DType; |
| 5 | +use vortex_array::{ArrayRef, register_kernel}; |
| 6 | +use vortex_dtype::{DType, Nullability}; |
7 | 7 | use vortex_error::VortexResult; |
8 | 8 |
|
9 | 9 | use crate::{ZstdArray, ZstdVTable}; |
10 | 10 |
|
11 | 11 | impl CastKernel for ZstdVTable { |
12 | 12 | fn cast(&self, array: &ZstdArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> { |
13 | | - if !dtype.is_nullable() || !array.all_valid() { |
14 | | - // We cannot cast to non-nullable since the validity containing nulls is used to decode |
15 | | - // the ZSTD array, this would require rewriting tables. |
| 13 | + if !dtype.eq_ignore_nullability(array.dtype()) { |
| 14 | + // Type changes can't be handled in ZSTD, need to decode and tweak. |
| 15 | + // TODO(aduffy): handle trivial conversions like Binary -> UTF8, integer widening, etc. |
16 | 16 | return Ok(None); |
17 | 17 | } |
18 | | - // ZstdArray is a general-purpose compression encoding using Zstandard compression. |
19 | | - // It can handle nullability changes without decompression by updating the validity |
20 | | - // bitmap, but type changes require decompression since the compressed data is |
21 | | - // type-specific and Zstd operates on raw bytes. |
22 | | - if array.dtype().eq_ignore_nullability(dtype) { |
23 | | - // Create a new validity with the target nullability |
24 | | - let new_validity = array |
25 | | - .unsliced_validity |
26 | | - .clone() |
27 | | - .cast_nullability(dtype.nullability(), array.len())?; |
28 | | - |
29 | | - return Ok(Some( |
| 18 | + |
| 19 | + let src_nullability = array.dtype().nullability(); |
| 20 | + let target_nullability = dtype.nullability(); |
| 21 | + |
| 22 | + match (src_nullability, target_nullability) { |
| 23 | + // Same type case. This should be handled in the layer above but for |
| 24 | + // completeness of the match arms we also handle it here. |
| 25 | + (Nullability::Nullable, Nullability::Nullable) |
| 26 | + | (Nullability::NonNullable, Nullability::NonNullable) => Ok(Some(array.to_array())), |
| 27 | + (Nullability::NonNullable, Nullability::Nullable) => Ok(Some( |
| 28 | + // nonnull => null, trivial cast by altering the validity |
30 | 29 | ZstdArray::new( |
31 | 30 | array.dictionary.clone(), |
32 | 31 | array.frames.clone(), |
33 | 32 | dtype.clone(), |
34 | 33 | array.metadata.clone(), |
35 | 34 | array.unsliced_n_rows(), |
36 | | - new_validity, |
| 35 | + array.unsliced_validity.clone(), |
37 | 36 | ) |
38 | | - ._slice(array.slice_start(), array.slice_stop()) |
39 | | - .into_array(), |
40 | | - )); |
| 37 | + .slice(array.slice_start()..array.slice_stop()), |
| 38 | + )), |
| 39 | + (Nullability::Nullable, Nullability::NonNullable) => { |
| 40 | + // null => non-null works if there are no nulls in the sliced range |
| 41 | + let sliced_len = array.slice_stop() - array.slice_start(); |
| 42 | + let has_nulls = !array |
| 43 | + .unsliced_validity |
| 44 | + .slice(array.slice_start()..array.slice_stop()) |
| 45 | + .all_valid(sliced_len); |
| 46 | + |
| 47 | + // We don't attempt to handle casting when there are nulls. |
| 48 | + if has_nulls { |
| 49 | + return Ok(None); |
| 50 | + } |
| 51 | + |
| 52 | + // If there are no nulls, the cast is trivial |
| 53 | + Ok(Some( |
| 54 | + ZstdArray::new( |
| 55 | + array.dictionary.clone(), |
| 56 | + array.frames.clone(), |
| 57 | + dtype.clone(), |
| 58 | + array.metadata.clone(), |
| 59 | + array.unsliced_n_rows(), |
| 60 | + array.unsliced_validity.clone(), |
| 61 | + ) |
| 62 | + .slice(array.slice_start()..array.slice_stop()), |
| 63 | + )) |
| 64 | + } |
41 | 65 | } |
42 | | - |
43 | | - // For other casts (e.g., type changes), decode to canonical and let the underlying array handle it |
44 | | - Ok(None) |
45 | 66 | } |
46 | 67 | } |
47 | 68 |
|
|
0 commit comments