Skip to content

Commit 3aa75c0

Browse files
committed
fix
Signed-off-by: Joe Isaacs <joe.isaacs@live.co.uk>
1 parent 46b29c2 commit 3aa75c0

File tree

2 files changed

+101
-34
lines changed

2 files changed

+101
-34
lines changed

vortex-array/src/arrays/chunked/compute/take.rs

Lines changed: 90 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
use vortex_buffer::BufferMut;
2-
use vortex_dtype::{DType, PType};
2+
use vortex_dtype::{DType, Nullability, PType};
33
use vortex_error::VortexResult;
4+
use vortex_mask::Mask;
45

5-
use crate::arrays::ChunkedVTable;
66
use crate::arrays::chunked::ChunkedArray;
7+
use crate::arrays::{ChunkedVTable, PrimitiveArray};
78
use crate::compute::{TakeKernel, TakeKernelAdapter, cast, take};
9+
use crate::validity::Validity;
810
use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
911

1012
impl TakeKernel for ChunkedVTable {
@@ -15,41 +17,95 @@ impl TakeKernel for ChunkedVTable {
1517
)?
1618
.to_primitive()?;
1719

18-
// While the chunk idx remains the same, accumulate a list of chunk indices.
19-
let mut chunks = Vec::new();
20-
let mut indices_in_chunk = BufferMut::<u64>::empty();
21-
let mut prev_chunk_idx = array
22-
.find_chunk_idx(indices.as_slice::<u64>()[0].try_into()?)
23-
.0;
24-
for idx in indices.as_slice::<u64>() {
25-
let idx = usize::try_from(*idx)?;
26-
let (chunk_idx, idx_in_chunk) = array.find_chunk_idx(idx);
27-
28-
if chunk_idx != prev_chunk_idx {
29-
// Start a new chunk
30-
let indices_in_chunk_array = indices_in_chunk.clone().into_array();
31-
chunks.push(take(array.chunk(prev_chunk_idx)?, &indices_in_chunk_array)?);
32-
indices_in_chunk.clear();
33-
}
34-
35-
indices_in_chunk.push(idx_in_chunk as u64);
36-
prev_chunk_idx = chunk_idx;
20+
if indices.dtype().is_nullable() {
21+
take_nullable(array, indices.as_slice::<u64>(), indices.validity_mask()?)
22+
} else {
23+
take_non_nullable(array, indices.as_slice::<u64>())
3724
}
25+
}
26+
}
27+
28+
fn take_nullable(
29+
array: &ChunkedArray,
30+
indices: &[u64],
31+
indices_validity: Mask,
32+
) -> VortexResult<ArrayRef> {
33+
// While the chunk idx remains the same, accumulate a list of chunk indices.
34+
let mut chunks = Vec::new();
35+
let mut indices_in_chunk = BufferMut::<u64>::empty();
36+
let mut start = 0;
37+
let mut stop = 0;
38+
let mut prev_chunk_idx = array.find_chunk_idx(indices[0].try_into()?).0;
39+
for idx in indices {
40+
let idx = usize::try_from(*idx)?;
41+
let (chunk_idx, idx_in_chunk) = array.find_chunk_idx(idx);
42+
43+
if chunk_idx != prev_chunk_idx {
44+
// Start a new chunk
45+
let indices_in_chunk_array = PrimitiveArray::new(
46+
indices_in_chunk.clone().freeze(),
47+
Validity::Array(indices_validity.slice(start, stop - start).into_array()),
48+
);
49+
chunks.push(take(
50+
array.chunk(prev_chunk_idx)?,
51+
indices_in_chunk_array.as_ref(),
52+
)?);
53+
indices_in_chunk.clear();
54+
start = stop;
55+
}
56+
57+
indices_in_chunk.push(idx_in_chunk as u64);
58+
stop += 1;
59+
prev_chunk_idx = chunk_idx;
60+
}
61+
62+
if !indices_in_chunk.is_empty() {
63+
let indices_in_chunk_array = PrimitiveArray::new(
64+
indices_in_chunk.freeze(),
65+
Validity::Array(indices_validity.slice(start, stop - start).into_array()),
66+
);
67+
chunks.push(take(
68+
array.chunk(prev_chunk_idx)?,
69+
indices_in_chunk_array.as_ref(),
70+
)?);
71+
}
72+
73+
Ok(ChunkedArray::new_unchecked(
74+
chunks,
75+
array
76+
.dtype()
77+
.clone()
78+
.union_nullability(Nullability::Nullable),
79+
)
80+
.into_array())
81+
}
3882

39-
if !indices_in_chunk.is_empty() {
40-
let indices_in_chunk_array = indices_in_chunk.into_array();
83+
fn take_non_nullable(array: &ChunkedArray, indices: &[u64]) -> VortexResult<ArrayRef> {
84+
// While the chunk idx remains the same, accumulate a list of chunk indices.
85+
let mut chunks = Vec::new();
86+
let mut indices_in_chunk = BufferMut::<u64>::empty();
87+
let mut prev_chunk_idx = array.find_chunk_idx(indices[0].try_into()?).0;
88+
for idx in indices {
89+
let idx = usize::try_from(*idx)?;
90+
let (chunk_idx, idx_in_chunk) = array.find_chunk_idx(idx);
91+
92+
if chunk_idx != prev_chunk_idx {
93+
// Start a new chunk
94+
let indices_in_chunk_array = indices_in_chunk.clone().into_array();
4195
chunks.push(take(array.chunk(prev_chunk_idx)?, &indices_in_chunk_array)?);
96+
indices_in_chunk.clear();
4297
}
4398

44-
Ok(ChunkedArray::new_unchecked(
45-
chunks,
46-
array
47-
.dtype()
48-
.clone()
49-
.union_nullability(indices.dtype().nullability()),
50-
)
51-
.into_array())
99+
indices_in_chunk.push(idx_in_chunk as u64);
100+
prev_chunk_idx = chunk_idx;
52101
}
102+
103+
if !indices_in_chunk.is_empty() {
104+
let indices_in_chunk_array = indices_in_chunk.into_array();
105+
chunks.push(take(array.chunk(prev_chunk_idx)?, &indices_in_chunk_array)?);
106+
}
107+
108+
Ok(ChunkedArray::new_unchecked(chunks, array.dtype().clone()).into_array())
53109
}
54110

55111
register_kernel!(TakeKernelAdapter(ChunkedVTable).lift());
@@ -98,12 +154,13 @@ mod test {
98154
let expect = StructArray::try_new(
99155
[].into(),
100156
vec![],
101-
2,
102-
Validity::Array(BoolArray::from_iter(vec![true, false]).to_array()),
157+
3,
158+
Validity::Array(BoolArray::from_iter(vec![true, false, true]).to_array()),
103159
)
104160
.unwrap();
105161
assert_eq!(result.dtype(), expect.dtype());
106162
assert_eq!(result.scalar_at(0).unwrap(), expect.scalar_at(0).unwrap());
107163
assert_eq!(result.scalar_at(1).unwrap(), expect.scalar_at(1).unwrap());
164+
assert_eq!(result.scalar_at(2).unwrap(), expect.scalar_at(2).unwrap());
108165
}
109166
}

vortex-array/src/arrays/struct_/mod.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::sync::Arc;
33

44
use itertools::Itertools;
55
use vortex_dtype::{DType, FieldName, FieldNames, StructFields};
6-
use vortex_error::{VortexResult, vortex_bail, vortex_err};
6+
use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
77
use vortex_scalar::Scalar;
88

99
use crate::stats::{ArrayStats, StatsSetRef};
@@ -116,6 +116,16 @@ impl StructArray {
116116
nullability,
117117
);
118118

119+
if length != validity.maybe_len().unwrap_or(length) {
120+
vortex_bail!(
121+
"array length {} and validity length must match {}",
122+
length,
123+
validity
124+
.maybe_len()
125+
.vortex_expect("can only fail if maybe is some")
126+
)
127+
}
128+
119129
Ok(Self {
120130
len: length,
121131
dtype,

0 commit comments

Comments
 (0)