Skip to content

Commit 28f5b3d

Browse files
fix[pco]: cast array with validity (#5193)
Signed-off-by: Joe Isaacs <[email protected]>
1 parent 0962924 commit 28f5b3d

File tree

3 files changed

+78
-8
lines changed

3 files changed

+78
-8
lines changed

encodings/pco/src/array.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,3 +427,43 @@ impl OperationsVTable<PcoVTable> for PcoVTable {
427427
array._slice(index, index + 1).decompress().scalar_at(0)
428428
}
429429
}
430+
431+
#[cfg(test)]
432+
mod tests {
433+
use vortex_array::arrays::PrimitiveArray;
434+
use vortex_array::validity::Validity;
435+
use vortex_array::{IntoArray, ToCanonical, assert_arrays_eq};
436+
use vortex_buffer::Buffer;
437+
438+
use crate::PcoArray;
439+
440+
#[test]
441+
fn test_slice_nullable() {
442+
// Create a nullable array with some nulls
443+
let values = PrimitiveArray::new(
444+
Buffer::copy_from(vec![10u32, 20, 30, 40, 50, 60]),
445+
Validity::from_iter([false, true, true, true, true, false]),
446+
);
447+
let pco = PcoArray::from_primitive(&values, 0, 128).unwrap();
448+
let decoded = pco.to_primitive();
449+
assert_arrays_eq!(
450+
decoded,
451+
PrimitiveArray::from_option_iter([
452+
None,
453+
Some(20u32),
454+
Some(30),
455+
Some(40),
456+
Some(50),
457+
None
458+
])
459+
);
460+
461+
// Slice to get only the non-null values in the middle
462+
let sliced = pco.slice(1..5);
463+
let expected =
464+
PrimitiveArray::from_option_iter([Some(20u32), Some(30), Some(40), Some(50)])
465+
.into_array();
466+
assert_arrays_eq!(sliced, expected);
467+
assert_arrays_eq!(sliced.to_canonical().into_array(), expected);
468+
}
469+
}

encodings/pco/src/compute/cast.rs

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@ use crate::{PcoArray, PcoVTable};
1010

1111
impl CastKernel for PcoVTable {
1212
fn cast(&self, array: &PcoArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
13+
if !dtype.is_nullable() && !array.all_valid() {
14+
// TODO(joe): fixme
15+
// We cannot cast to non-nullable since the validity containing nulls is used to decode
16+
// the PCO array, this would require rewriting tables.
17+
return Ok(None);
18+
}
1319
// PCO (Pcodec) is a compression encoding that stores data in a compressed format.
1420
// It can efficiently handle nullability changes without decompression, but type changes
1521
// require decompression since the compression algorithm is type-specific.
@@ -49,6 +55,7 @@ mod tests {
4955
use vortex_array::arrays::PrimitiveArray;
5056
use vortex_array::compute::cast;
5157
use vortex_array::compute::conformance::cast::test_cast_conformance;
58+
use vortex_array::validity::Validity;
5259
use vortex_buffer::Buffer;
5360
use vortex_dtype::{DType, Nullability, PType};
5461

@@ -58,7 +65,7 @@ mod tests {
5865
fn test_cast_pco_f32_to_f64() {
5966
let values = PrimitiveArray::new(
6067
Buffer::copy_from(vec![1.0f32, 2.0, 3.0, 4.0, 5.0]),
61-
vortex_array::validity::Validity::NonNullable,
68+
Validity::NonNullable,
6269
);
6370
let pco = PcoArray::from_primitive(&values, 0, 128).unwrap();
6471

@@ -83,7 +90,7 @@ mod tests {
8390
// Test casting from NonNullable to Nullable
8491
let values = PrimitiveArray::new(
8592
Buffer::copy_from(vec![10u32, 20, 30, 40]),
86-
vortex_array::validity::Validity::NonNullable,
93+
Validity::NonNullable,
8794
);
8895
let pco = PcoArray::from_primitive(&values, 0, 128).unwrap();
8996

@@ -98,26 +105,49 @@ mod tests {
98105
);
99106
}
100107

108+
#[test]
109+
fn test_cast_sliced_pco_nullable_to_nonnullable() {
110+
let values = PrimitiveArray::new(
111+
Buffer::copy_from(vec![10u32, 20, 30, 40, 50, 60]),
112+
Validity::from_iter([true, true, true, true, true, true]),
113+
);
114+
let pco = PcoArray::from_primitive(&values, 0, 128).unwrap();
115+
let sliced = pco.slice(1..5);
116+
let casted = cast(
117+
sliced.as_ref(),
118+
&DType::Primitive(PType::U32, Nullability::NonNullable),
119+
)
120+
.unwrap();
121+
assert_eq!(
122+
casted.dtype(),
123+
&DType::Primitive(PType::U32, Nullability::NonNullable)
124+
);
125+
// Verify the values are correct
126+
let decoded = casted.to_primitive();
127+
let u32_values = decoded.as_slice::<u32>();
128+
assert_eq!(u32_values, &[20, 30, 40, 50]);
129+
}
130+
101131
#[rstest]
102132
#[case::f32(PrimitiveArray::new(
103133
Buffer::copy_from(vec![1.23f32, 4.56, 7.89, 10.11, 12.13]),
104-
vortex_array::validity::Validity::NonNullable,
134+
Validity::NonNullable,
105135
))]
106136
#[case::f64(PrimitiveArray::new(
107137
Buffer::copy_from(vec![100.1f64, 200.2, 300.3, 400.4, 500.5]),
108-
vortex_array::validity::Validity::NonNullable,
138+
Validity::NonNullable,
109139
))]
110140
#[case::i32(PrimitiveArray::new(
111141
Buffer::copy_from(vec![100i32, 200, 300, 400, 500]),
112-
vortex_array::validity::Validity::NonNullable,
142+
Validity::NonNullable,
113143
))]
114144
#[case::u64(PrimitiveArray::new(
115145
Buffer::copy_from(vec![1000u64, 2000, 3000, 4000]),
116-
vortex_array::validity::Validity::NonNullable,
146+
Validity::NonNullable,
117147
))]
118148
#[case::single(PrimitiveArray::new(
119149
Buffer::copy_from(vec![42.42f64]),
120-
vortex_array::validity::Validity::NonNullable,
150+
Validity::NonNullable,
121151
))]
122152
fn test_cast_pco_conformance(#[case] values: PrimitiveArray) {
123153
let pco = PcoArray::from_primitive(&values, 0, 128).unwrap();

fuzz/src/array/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ pub enum Action {
8787
ScalarAt(Vec<usize>),
8888
}
8989

90-
#[derive(Debug)]
90+
#[derive(Debug, Clone)]
9191
pub enum ExpectedValue {
9292
Array(ArrayRef),
9393
Search(SearchResult),

0 commit comments

Comments
 (0)