Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 68 additions & 13 deletions vortex-array/src/arrays/chunked/compute/take.rs
Original file line number Diff line number Diff line change
@@ -1,43 +1,70 @@
use vortex_buffer::BufferMut;
use vortex_dtype::PType;
use vortex_dtype::{DType, PType};
use vortex_error::VortexResult;

use crate::arrays::ChunkedVTable;
use crate::arrays::chunked::ChunkedArray;
use crate::arrays::{ChunkedVTable, PrimitiveArray};
use crate::compute::{TakeKernel, TakeKernelAdapter, cast, take};
use crate::validity::Validity;
use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};

impl TakeKernel for ChunkedVTable {
fn take(&self, array: &ChunkedArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
let indices = cast(indices, PType::U64.into())?.to_primitive()?;
let indices = cast(
indices,
&DType::Primitive(PType::U64, indices.dtype().nullability()),
)?
.to_primitive()?;

// TODO(joe): Should we split this implementation based on indices nullability?
let nullability = indices.dtype().nullability();
let indices_mask = indices.validity_mask()?;
let indices = indices.as_slice::<u64>();

// While the chunk idx remains the same, accumulate a list of chunk indices.
let mut chunks = Vec::new();
let mut indices_in_chunk = BufferMut::<u64>::empty();
let mut prev_chunk_idx = array
.find_chunk_idx(indices.as_slice::<u64>()[0].try_into()?)
.0;
for idx in indices.as_slice::<u64>() {
let mut start = 0;
let mut stop = 0;
let mut prev_chunk_idx = array.find_chunk_idx(indices[0].try_into()?).0;
for idx in indices {
let idx = usize::try_from(*idx)?;
let (chunk_idx, idx_in_chunk) = array.find_chunk_idx(idx);

if chunk_idx != prev_chunk_idx {
// Start a new chunk
let indices_in_chunk_array = indices_in_chunk.clone().into_array();
chunks.push(take(array.chunk(prev_chunk_idx)?, &indices_in_chunk_array)?);
let indices_in_chunk_array = PrimitiveArray::new(
indices_in_chunk.clone().freeze(),
Validity::from_mask(indices_mask.slice(start, stop - start), nullability),
);
chunks.push(take(
array.chunk(prev_chunk_idx)?,
indices_in_chunk_array.as_ref(),
)?);
indices_in_chunk.clear();
start = stop;
}

indices_in_chunk.push(idx_in_chunk as u64);
stop += 1;
prev_chunk_idx = chunk_idx;
}

if !indices_in_chunk.is_empty() {
let indices_in_chunk_array = indices_in_chunk.into_array();
chunks.push(take(array.chunk(prev_chunk_idx)?, &indices_in_chunk_array)?);
let indices_in_chunk_array = PrimitiveArray::new(
indices_in_chunk.freeze(),
Validity::from_mask(indices_mask.slice(start, stop - start), nullability),
);
chunks.push(take(
array.chunk(prev_chunk_idx)?,
indices_in_chunk_array.as_ref(),
)?);
}

Ok(ChunkedArray::new_unchecked(chunks, array.dtype().clone()).into_array())
Ok(ChunkedArray::new_unchecked(
chunks,
array.dtype().clone().union_nullability(nullability),
)
.into_array())
}
}

Expand All @@ -50,8 +77,10 @@ mod test {
use crate::IntoArray;
use crate::array::Array;
use crate::arrays::chunked::ChunkedArray;
use crate::arrays::{BoolArray, PrimitiveArray, StructArray};
use crate::canonical::ToCanonical;
use crate::compute::take;
use crate::validity::Validity;

#[test]
fn test_take() {
Expand All @@ -68,4 +97,30 @@ mod test {
.unwrap();
assert_eq!(result.as_slice::<i32>(), &[1, 1, 1, 2]);
}

#[test]
fn test_take_nullability() {
let struct_array =
StructArray::try_new([].into(), vec![], 100, Validity::NonNullable).unwrap();

let arr = ChunkedArray::from_iter(vec![struct_array.to_array(), struct_array.to_array()]);

let result = take(
arr.as_ref(),
PrimitiveArray::from_option_iter(vec![Some(0), None, Some(101)]).as_ref(),
)
.unwrap();

let expect = StructArray::try_new(
[].into(),
vec![],
3,
Validity::Array(BoolArray::from_iter(vec![true, false, true]).to_array()),
)
.unwrap();
assert_eq!(result.dtype(), expect.dtype());
assert_eq!(result.scalar_at(0).unwrap(), expect.scalar_at(0).unwrap());
assert_eq!(result.scalar_at(1).unwrap(), expect.scalar_at(1).unwrap());
assert_eq!(result.scalar_at(2).unwrap(), expect.scalar_at(2).unwrap());
}
}
12 changes: 11 additions & 1 deletion vortex-array/src/arrays/struct_/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::sync::Arc;

use itertools::Itertools;
use vortex_dtype::{DType, FieldName, FieldNames, StructFields};
use vortex_error::{VortexResult, vortex_bail, vortex_err};
use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
use vortex_scalar::Scalar;

use crate::stats::{ArrayStats, StatsSetRef};
Expand Down Expand Up @@ -116,6 +116,16 @@ impl StructArray {
nullability,
);

if length != validity.maybe_len().unwrap_or(length) {
vortex_bail!(
"array length {} and validity length must match {}",
length,
validity
.maybe_len()
.vortex_expect("can only fail if maybe is some")
)
}

Ok(Self {
len: length,
dtype,
Expand Down
Loading