diff --git a/vortex-array/src/arrays/chunked/compute/take.rs b/vortex-array/src/arrays/chunked/compute/take.rs index ef56a31f52a..e1df97cec75 100644 --- a/vortex-array/src/arrays/chunked/compute/take.rs +++ b/vortex-array/src/arrays/chunked/compute/take.rs @@ -1,43 +1,70 @@ use vortex_buffer::BufferMut; -use vortex_dtype::PType; +use vortex_dtype::{DType, PType}; use vortex_error::VortexResult; -use crate::arrays::ChunkedVTable; use crate::arrays::chunked::ChunkedArray; +use crate::arrays::{ChunkedVTable, PrimitiveArray}; use crate::compute::{TakeKernel, TakeKernelAdapter, cast, take}; +use crate::validity::Validity; use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel}; impl TakeKernel for ChunkedVTable { fn take(&self, array: &ChunkedArray, indices: &dyn Array) -> VortexResult { - let indices = cast(indices, PType::U64.into())?.to_primitive()?; + let indices = cast( + indices, + &DType::Primitive(PType::U64, indices.dtype().nullability()), + )? + .to_primitive()?; + + // TODO(joe): Should we split this implementation based on indices nullability? + let nullability = indices.dtype().nullability(); + let indices_mask = indices.validity_mask()?; + let indices = indices.as_slice::(); - // While the chunk idx remains the same, accumulate a list of chunk indices. let mut chunks = Vec::new(); let mut indices_in_chunk = BufferMut::::empty(); - let mut prev_chunk_idx = array - .find_chunk_idx(indices.as_slice::()[0].try_into()?) - .0; - for idx in indices.as_slice::() { + let mut start = 0; + let mut stop = 0; + let mut prev_chunk_idx = array.find_chunk_idx(indices[0].try_into()?).0; + for idx in indices { let idx = usize::try_from(*idx)?; let (chunk_idx, idx_in_chunk) = array.find_chunk_idx(idx); if chunk_idx != prev_chunk_idx { // Start a new chunk - let indices_in_chunk_array = indices_in_chunk.clone().into_array(); - chunks.push(take(array.chunk(prev_chunk_idx)?, &indices_in_chunk_array)?); + let indices_in_chunk_array = PrimitiveArray::new( + indices_in_chunk.clone().freeze(), + Validity::from_mask(indices_mask.slice(start, stop - start), nullability), + ); + chunks.push(take( + array.chunk(prev_chunk_idx)?, + indices_in_chunk_array.as_ref(), + )?); indices_in_chunk.clear(); + start = stop; } indices_in_chunk.push(idx_in_chunk as u64); + stop += 1; prev_chunk_idx = chunk_idx; } if !indices_in_chunk.is_empty() { - let indices_in_chunk_array = indices_in_chunk.into_array(); - chunks.push(take(array.chunk(prev_chunk_idx)?, &indices_in_chunk_array)?); + let indices_in_chunk_array = PrimitiveArray::new( + indices_in_chunk.freeze(), + Validity::from_mask(indices_mask.slice(start, stop - start), nullability), + ); + chunks.push(take( + array.chunk(prev_chunk_idx)?, + indices_in_chunk_array.as_ref(), + )?); } - Ok(ChunkedArray::new_unchecked(chunks, array.dtype().clone()).into_array()) + Ok(ChunkedArray::new_unchecked( + chunks, + array.dtype().clone().union_nullability(nullability), + ) + .into_array()) } } @@ -50,8 +77,10 @@ mod test { use crate::IntoArray; use crate::array::Array; use crate::arrays::chunked::ChunkedArray; + use crate::arrays::{BoolArray, PrimitiveArray, StructArray}; use crate::canonical::ToCanonical; use crate::compute::take; + use crate::validity::Validity; #[test] fn test_take() { @@ -68,4 +97,30 @@ mod test { .unwrap(); assert_eq!(result.as_slice::(), &[1, 1, 1, 2]); } + + #[test] + fn test_take_nullability() { + let struct_array = + StructArray::try_new([].into(), vec![], 100, Validity::NonNullable).unwrap(); + + let arr = ChunkedArray::from_iter(vec![struct_array.to_array(), struct_array.to_array()]); + + let result = take( + arr.as_ref(), + PrimitiveArray::from_option_iter(vec![Some(0), None, Some(101)]).as_ref(), + ) + .unwrap(); + + let expect = StructArray::try_new( + [].into(), + vec![], + 3, + Validity::Array(BoolArray::from_iter(vec![true, false, true]).to_array()), + ) + .unwrap(); + assert_eq!(result.dtype(), expect.dtype()); + assert_eq!(result.scalar_at(0).unwrap(), expect.scalar_at(0).unwrap()); + assert_eq!(result.scalar_at(1).unwrap(), expect.scalar_at(1).unwrap()); + assert_eq!(result.scalar_at(2).unwrap(), expect.scalar_at(2).unwrap()); + } } diff --git a/vortex-array/src/arrays/struct_/mod.rs b/vortex-array/src/arrays/struct_/mod.rs index 4db30819bdb..17e78d3c27e 100644 --- a/vortex-array/src/arrays/struct_/mod.rs +++ b/vortex-array/src/arrays/struct_/mod.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use itertools::Itertools; use vortex_dtype::{DType, FieldName, FieldNames, StructFields}; -use vortex_error::{VortexResult, vortex_bail, vortex_err}; +use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err}; use vortex_scalar::Scalar; use crate::stats::{ArrayStats, StatsSetRef}; @@ -116,6 +116,16 @@ impl StructArray { nullability, ); + if length != validity.maybe_len().unwrap_or(length) { + vortex_bail!( + "array length {} and validity length must match {}", + length, + validity + .maybe_len() + .vortex_expect("can only fail if maybe is some") + ) + } + Ok(Self { len: length, dtype,