Skip to content

Commit 47c0dea

Browse files
fix[pco]: cast condition check (#5239)
Signed-off-by: Joe Isaacs <[email protected]>
1 parent a5a6864 commit 47c0dea

File tree

1 file changed

+28
-2
lines changed

1 file changed

+28
-2
lines changed

encodings/pco/src/compute/cast.rs

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use crate::{PcoArray, PcoVTable};
1010

1111
impl CastKernel for PcoVTable {
1212
fn cast(&self, array: &PcoArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
13-
if !dtype.is_nullable() && !array.all_valid() {
13+
if !dtype.is_nullable() || !array.all_valid() {
1414
// TODO(joe): fixme
1515
// We cannot cast to non-nullable since the validity containing nulls is used to decode
1616
// the PCO array, this would require rewriting tables.
@@ -51,11 +51,11 @@ register_kernel!(CastKernelAdapter(PcoVTable).lift());
5151
#[cfg(test)]
5252
mod tests {
5353
use rstest::rstest;
54-
use vortex_array::ToCanonical;
5554
use vortex_array::arrays::PrimitiveArray;
5655
use vortex_array::compute::cast;
5756
use vortex_array::compute::conformance::cast::test_cast_conformance;
5857
use vortex_array::validity::Validity;
58+
use vortex_array::{ToCanonical, assert_arrays_eq};
5959
use vortex_buffer::Buffer;
6060
use vortex_dtype::{DType, Nullability, PType};
6161

@@ -128,6 +128,32 @@ mod tests {
128128
assert_eq!(u32_values, &[20, 30, 40, 50]);
129129
}
130130

131+
#[test]
132+
fn test_cast_sliced_pco_part_valid_to_nonnullable() {
133+
let values = PrimitiveArray::from_option_iter([
134+
None,
135+
Some(20u32),
136+
Some(30),
137+
Some(40),
138+
Some(50),
139+
Some(60),
140+
]);
141+
let pco = PcoArray::from_primitive(&values, 0, 128).unwrap();
142+
let sliced = pco.slice(1..5);
143+
let casted = cast(
144+
sliced.as_ref(),
145+
&DType::Primitive(PType::U32, Nullability::NonNullable),
146+
)
147+
.unwrap();
148+
assert_eq!(
149+
casted.dtype(),
150+
&DType::Primitive(PType::U32, Nullability::NonNullable)
151+
);
152+
let decoded = casted.to_primitive();
153+
let expected = PrimitiveArray::from_iter([20u32, 30, 40, 50]);
154+
assert_arrays_eq!(decoded, expected);
155+
}
156+
131157
#[rstest]
132158
#[case::f32(PrimitiveArray::new(
133159
Buffer::copy_from(vec![1.23f32, 4.56, 7.89, 10.11, 12.13]),

0 commit comments

Comments
 (0)