Skip to content

Commit 4864e30

Browse files
authored
fix: Casting primitive nullable values doesn't cast invalid values (#4735)
Signed-off-by: Robert Kruszewski <[email protected]>
1 parent 060b2d6 commit 4864e30

File tree

1 file changed

+61
-16
lines changed
  • vortex-array/src/arrays/primitive/compute

1 file changed

+61
-16
lines changed

vortex-array/src/arrays/primitive/compute/cast.rs

Lines changed: 61 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
use vortex_buffer::{Buffer, BufferMut};
55
use vortex_dtype::{DType, NativePType, match_each_native_ptype};
66
use vortex_error::{VortexResult, vortex_err};
7+
use vortex_mask::{AllOr, Mask};
78

89
use crate::arrays::PrimitiveVTable;
910
use crate::arrays::primitive::PrimitiveArray;
@@ -36,37 +37,62 @@ impl CastKernel for PrimitiveVTable {
3637
));
3738
}
3839

40+
let mask = array.validity_mask();
41+
3942
// Otherwise, we need to cast the values one-by-one
40-
match_each_native_ptype!(new_ptype, |T| {
41-
Ok(Some(
42-
PrimitiveArray::new(cast::<T>(array)?, new_validity).into_array(),
43-
))
44-
})
43+
Ok(Some(match_each_native_ptype!(new_ptype, |T| {
44+
match_each_native_ptype!(array.ptype(), |F| {
45+
PrimitiveArray::new(cast::<F, T>(array.as_slice(), mask)?, new_validity)
46+
.into_array()
47+
})
48+
})))
4549
}
4650
}
4751

4852
register_kernel!(CastKernelAdapter(PrimitiveVTable).lift());
4953

50-
fn cast<T: NativePType>(array: &PrimitiveArray) -> VortexResult<Buffer<T>> {
51-
let mut buffer = BufferMut::with_capacity(array.len());
52-
match_each_native_ptype!(array.ptype(), |P| {
53-
for item in array.as_slice::<P>() {
54-
let item = T::from(*item).ok_or_else(
55-
|| vortex_err!(ComputeError: "Failed to cast {} to {:?}", item, T::PTYPE),
56-
)?;
57-
// SAFETY: we've pre-allocated the required capacity
58-
unsafe { buffer.push_unchecked(item) }
54+
fn cast<F: NativePType, T: NativePType>(array: &[F], mask: Mask) -> VortexResult<Buffer<T>> {
55+
match mask.boolean_buffer() {
56+
AllOr::All => {
57+
let mut buffer = BufferMut::with_capacity(array.len());
58+
for item in array {
59+
let item = T::from(*item).ok_or_else(
60+
|| vortex_err!(ComputeError: "Failed to cast {} to {:?}", item, T::PTYPE),
61+
)?;
62+
// SAFETY: we've pre-allocated the required capacity
63+
unsafe { buffer.push_unchecked(item) }
64+
}
65+
Ok(buffer.freeze())
66+
}
67+
AllOr::None => Ok(Buffer::zeroed(array.len())),
68+
AllOr::Some(b) => {
69+
// TODO(robert): Depending on density of the buffer might be better to prefill Buffer and only write valid values
70+
let mut buffer = BufferMut::with_capacity(array.len());
71+
for (item, valid) in array.iter().zip(b.iter()) {
72+
if valid {
73+
let item = T::from(*item).ok_or_else(
74+
|| vortex_err!(ComputeError: "Failed to cast {} to {:?}", item, T::PTYPE),
75+
)?;
76+
// SAFETY: we've pre-allocated the required capacity
77+
unsafe { buffer.push_unchecked(item) }
78+
} else {
79+
// SAFETY: we've pre-allocated the required capacity
80+
unsafe { buffer.push_unchecked(T::default()) }
81+
}
82+
}
83+
Ok(buffer.freeze())
5984
}
60-
});
61-
Ok(buffer.freeze())
85+
}
6286
}
6387

6488
#[cfg(test)]
6589
mod test {
90+
use arrow_buffer::BooleanBuffer;
6691
use rstest::rstest;
6792
use vortex_buffer::buffer;
6893
use vortex_dtype::{DType, Nullability, PType};
6994
use vortex_error::VortexError;
95+
use vortex_mask::Mask;
7096

7197
use crate::IntoArray;
7298
use crate::arrays::PrimitiveArray;
@@ -156,6 +182,25 @@ mod test {
156182
);
157183
}
158184

185+
#[test]
186+
fn cast_with_invalid_nulls() {
187+
let arr = PrimitiveArray::new(
188+
buffer![-1i32, 0, 10],
189+
Validity::from_iter([false, true, true]),
190+
);
191+
let p = cast(
192+
arr.as_ref(),
193+
&DType::Primitive(PType::U32, Nullability::Nullable),
194+
)
195+
.unwrap()
196+
.to_primitive();
197+
assert_eq!(p.as_slice::<u32>(), vec![0, 0, 10]);
198+
assert_eq!(
199+
p.validity_mask(),
200+
Mask::from(BooleanBuffer::from(vec![false, true, true]))
201+
);
202+
}
203+
159204
#[rstest]
160205
#[case(buffer![0u8, 1, 2, 3, 255].into_array())]
161206
#[case(buffer![0u16, 100, 1000, 65535].into_array())]

0 commit comments

Comments
 (0)