|
4 | 4 | use vortex_buffer::{Buffer, BufferMut}; |
5 | 5 | use vortex_dtype::{DType, NativePType, match_each_native_ptype}; |
6 | 6 | use vortex_error::{VortexResult, vortex_err}; |
| 7 | +use vortex_mask::{AllOr, Mask}; |
7 | 8 |
|
8 | 9 | use crate::arrays::PrimitiveVTable; |
9 | 10 | use crate::arrays::primitive::PrimitiveArray; |
@@ -36,37 +37,62 @@ impl CastKernel for PrimitiveVTable { |
36 | 37 | )); |
37 | 38 | } |
38 | 39 |
|
| 40 | + let mask = array.validity_mask(); |
| 41 | + |
39 | 42 | // Otherwise, we need to cast the values one-by-one |
40 | | - match_each_native_ptype!(new_ptype, |T| { |
41 | | - Ok(Some( |
42 | | - PrimitiveArray::new(cast::<T>(array)?, new_validity).into_array(), |
43 | | - )) |
44 | | - }) |
| 43 | + Ok(Some(match_each_native_ptype!(new_ptype, |T| { |
| 44 | + match_each_native_ptype!(array.ptype(), |F| { |
| 45 | + PrimitiveArray::new(cast::<F, T>(array.as_slice(), mask)?, new_validity) |
| 46 | + .into_array() |
| 47 | + }) |
| 48 | + }))) |
45 | 49 | } |
46 | 50 | } |
47 | 51 |
|
48 | 52 | register_kernel!(CastKernelAdapter(PrimitiveVTable).lift()); |
49 | 53 |
|
50 | | -fn cast<T: NativePType>(array: &PrimitiveArray) -> VortexResult<Buffer<T>> { |
51 | | - let mut buffer = BufferMut::with_capacity(array.len()); |
52 | | - match_each_native_ptype!(array.ptype(), |P| { |
53 | | - for item in array.as_slice::<P>() { |
54 | | - let item = T::from(*item).ok_or_else( |
55 | | - || vortex_err!(ComputeError: "Failed to cast {} to {:?}", item, T::PTYPE), |
56 | | - )?; |
57 | | - // SAFETY: we've pre-allocated the required capacity |
58 | | - unsafe { buffer.push_unchecked(item) } |
| 54 | +fn cast<F: NativePType, T: NativePType>(array: &[F], mask: Mask) -> VortexResult<Buffer<T>> { |
| 55 | + match mask.boolean_buffer() { |
| 56 | + AllOr::All => { |
| 57 | + let mut buffer = BufferMut::with_capacity(array.len()); |
| 58 | + for item in array { |
| 59 | + let item = T::from(*item).ok_or_else( |
| 60 | + || vortex_err!(ComputeError: "Failed to cast {} to {:?}", item, T::PTYPE), |
| 61 | + )?; |
| 62 | + // SAFETY: we've pre-allocated the required capacity |
| 63 | + unsafe { buffer.push_unchecked(item) } |
| 64 | + } |
| 65 | + Ok(buffer.freeze()) |
| 66 | + } |
| 67 | + AllOr::None => Ok(Buffer::zeroed(array.len())), |
| 68 | + AllOr::Some(b) => { |
| 69 | + // TODO(robert): Depending on density of the buffer might be better to prefill Buffer and only write valid values |
| 70 | + let mut buffer = BufferMut::with_capacity(array.len()); |
| 71 | + for (item, valid) in array.iter().zip(b.iter()) { |
| 72 | + if valid { |
| 73 | + let item = T::from(*item).ok_or_else( |
| 74 | + || vortex_err!(ComputeError: "Failed to cast {} to {:?}", item, T::PTYPE), |
| 75 | + )?; |
| 76 | + // SAFETY: we've pre-allocated the required capacity |
| 77 | + unsafe { buffer.push_unchecked(item) } |
| 78 | + } else { |
| 79 | + // SAFETY: we've pre-allocated the required capacity |
| 80 | + unsafe { buffer.push_unchecked(T::default()) } |
| 81 | + } |
| 82 | + } |
| 83 | + Ok(buffer.freeze()) |
59 | 84 | } |
60 | | - }); |
61 | | - Ok(buffer.freeze()) |
| 85 | + } |
62 | 86 | } |
63 | 87 |
|
64 | 88 | #[cfg(test)] |
65 | 89 | mod test { |
| 90 | + use arrow_buffer::BooleanBuffer; |
66 | 91 | use rstest::rstest; |
67 | 92 | use vortex_buffer::buffer; |
68 | 93 | use vortex_dtype::{DType, Nullability, PType}; |
69 | 94 | use vortex_error::VortexError; |
| 95 | + use vortex_mask::Mask; |
70 | 96 |
|
71 | 97 | use crate::IntoArray; |
72 | 98 | use crate::arrays::PrimitiveArray; |
@@ -156,6 +182,25 @@ mod test { |
156 | 182 | ); |
157 | 183 | } |
158 | 184 |
|
| 185 | + #[test] |
| 186 | + fn cast_with_invalid_nulls() { |
| 187 | + let arr = PrimitiveArray::new( |
| 188 | + buffer![-1i32, 0, 10], |
| 189 | + Validity::from_iter([false, true, true]), |
| 190 | + ); |
| 191 | + let p = cast( |
| 192 | + arr.as_ref(), |
| 193 | + &DType::Primitive(PType::U32, Nullability::Nullable), |
| 194 | + ) |
| 195 | + .unwrap() |
| 196 | + .to_primitive(); |
| 197 | + assert_eq!(p.as_slice::<u32>(), vec![0, 0, 10]); |
| 198 | + assert_eq!( |
| 199 | + p.validity_mask(), |
| 200 | + Mask::from(BooleanBuffer::from(vec![false, true, true])) |
| 201 | + ); |
| 202 | + } |
| 203 | + |
159 | 204 | #[rstest] |
160 | 205 | #[case(buffer![0u8, 1, 2, 3, 255].into_array())] |
161 | 206 | #[case(buffer![0u16, 100, 1000, 65535].into_array())] |
|
0 commit comments