Skip to content

Commit 89760db

Browse files
bug[vortex-array]: take nullability chunked array not respects validity (#3594)
Chunked array current ignores indices validity --------- Signed-off-by: Joe Isaacs <[email protected]>
1 parent 4c88db1 commit 89760db

File tree

2 files changed

+79
-14
lines changed

2 files changed

+79
-14
lines changed

vortex-array/src/arrays/chunked/compute/take.rs

Lines changed: 68 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,70 @@
11
use vortex_buffer::BufferMut;
2-
use vortex_dtype::PType;
2+
use vortex_dtype::{DType, PType};
33
use vortex_error::VortexResult;
44

5-
use crate::arrays::ChunkedVTable;
65
use crate::arrays::chunked::ChunkedArray;
6+
use crate::arrays::{ChunkedVTable, PrimitiveArray};
77
use crate::compute::{TakeKernel, TakeKernelAdapter, cast, take};
8+
use crate::validity::Validity;
89
use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
910

1011
impl TakeKernel for ChunkedVTable {
1112
fn take(&self, array: &ChunkedArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
12-
let indices = cast(indices, PType::U64.into())?.to_primitive()?;
13+
let indices = cast(
14+
indices,
15+
&DType::Primitive(PType::U64, indices.dtype().nullability()),
16+
)?
17+
.to_primitive()?;
18+
19+
// TODO(joe): Should we split this implementation based on indices nullability?
20+
let nullability = indices.dtype().nullability();
21+
let indices_mask = indices.validity_mask()?;
22+
let indices = indices.as_slice::<u64>();
1323

14-
// While the chunk idx remains the same, accumulate a list of chunk indices.
1524
let mut chunks = Vec::new();
1625
let mut indices_in_chunk = BufferMut::<u64>::empty();
17-
let mut prev_chunk_idx = array
18-
.find_chunk_idx(indices.as_slice::<u64>()[0].try_into()?)
19-
.0;
20-
for idx in indices.as_slice::<u64>() {
26+
let mut start = 0;
27+
let mut stop = 0;
28+
let mut prev_chunk_idx = array.find_chunk_idx(indices[0].try_into()?).0;
29+
for idx in indices {
2130
let idx = usize::try_from(*idx)?;
2231
let (chunk_idx, idx_in_chunk) = array.find_chunk_idx(idx);
2332

2433
if chunk_idx != prev_chunk_idx {
2534
// Start a new chunk
26-
let indices_in_chunk_array = indices_in_chunk.clone().into_array();
27-
chunks.push(take(array.chunk(prev_chunk_idx)?, &indices_in_chunk_array)?);
35+
let indices_in_chunk_array = PrimitiveArray::new(
36+
indices_in_chunk.clone().freeze(),
37+
Validity::from_mask(indices_mask.slice(start, stop - start), nullability),
38+
);
39+
chunks.push(take(
40+
array.chunk(prev_chunk_idx)?,
41+
indices_in_chunk_array.as_ref(),
42+
)?);
2843
indices_in_chunk.clear();
44+
start = stop;
2945
}
3046

3147
indices_in_chunk.push(idx_in_chunk as u64);
48+
stop += 1;
3249
prev_chunk_idx = chunk_idx;
3350
}
3451

3552
if !indices_in_chunk.is_empty() {
36-
let indices_in_chunk_array = indices_in_chunk.into_array();
37-
chunks.push(take(array.chunk(prev_chunk_idx)?, &indices_in_chunk_array)?);
53+
let indices_in_chunk_array = PrimitiveArray::new(
54+
indices_in_chunk.freeze(),
55+
Validity::from_mask(indices_mask.slice(start, stop - start), nullability),
56+
);
57+
chunks.push(take(
58+
array.chunk(prev_chunk_idx)?,
59+
indices_in_chunk_array.as_ref(),
60+
)?);
3861
}
3962

40-
Ok(ChunkedArray::new_unchecked(chunks, array.dtype().clone()).into_array())
63+
Ok(ChunkedArray::new_unchecked(
64+
chunks,
65+
array.dtype().clone().union_nullability(nullability),
66+
)
67+
.into_array())
4168
}
4269
}
4370

@@ -50,8 +77,10 @@ mod test {
5077
use crate::IntoArray;
5178
use crate::array::Array;
5279
use crate::arrays::chunked::ChunkedArray;
80+
use crate::arrays::{BoolArray, PrimitiveArray, StructArray};
5381
use crate::canonical::ToCanonical;
5482
use crate::compute::take;
83+
use crate::validity::Validity;
5584

5685
#[test]
5786
fn test_take() {
@@ -68,4 +97,30 @@ mod test {
6897
.unwrap();
6998
assert_eq!(result.as_slice::<i32>(), &[1, 1, 1, 2]);
7099
}
100+
101+
#[test]
102+
fn test_take_nullability() {
103+
let struct_array =
104+
StructArray::try_new([].into(), vec![], 100, Validity::NonNullable).unwrap();
105+
106+
let arr = ChunkedArray::from_iter(vec![struct_array.to_array(), struct_array.to_array()]);
107+
108+
let result = take(
109+
arr.as_ref(),
110+
PrimitiveArray::from_option_iter(vec![Some(0), None, Some(101)]).as_ref(),
111+
)
112+
.unwrap();
113+
114+
let expect = StructArray::try_new(
115+
[].into(),
116+
vec![],
117+
3,
118+
Validity::Array(BoolArray::from_iter(vec![true, false, true]).to_array()),
119+
)
120+
.unwrap();
121+
assert_eq!(result.dtype(), expect.dtype());
122+
assert_eq!(result.scalar_at(0).unwrap(), expect.scalar_at(0).unwrap());
123+
assert_eq!(result.scalar_at(1).unwrap(), expect.scalar_at(1).unwrap());
124+
assert_eq!(result.scalar_at(2).unwrap(), expect.scalar_at(2).unwrap());
125+
}
71126
}

vortex-array/src/arrays/struct_/mod.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::sync::Arc;
33

44
use itertools::Itertools;
55
use vortex_dtype::{DType, FieldName, FieldNames, StructFields};
6-
use vortex_error::{VortexResult, vortex_bail, vortex_err};
6+
use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
77
use vortex_scalar::Scalar;
88

99
use crate::stats::{ArrayStats, StatsSetRef};
@@ -116,6 +116,16 @@ impl StructArray {
116116
nullability,
117117
);
118118

119+
if length != validity.maybe_len().unwrap_or(length) {
120+
vortex_bail!(
121+
"array length {} and validity length must match {}",
122+
length,
123+
validity
124+
.maybe_len()
125+
.vortex_expect("can only fail if maybe is some")
126+
)
127+
}
128+
119129
Ok(Self {
120130
len: length,
121131
dtype,

0 commit comments

Comments
 (0)