Skip to content

Commit 1f7a888

Browse files
authored
chunked array zip (#5530)
chunked array to push zip down to chunks --------- Signed-off-by: Onur Satici <[email protected]>
1 parent 2f80934 commit 1f7a888

File tree

2 files changed

+129
-0
lines changed

2 files changed

+129
-0
lines changed

vortex-array/src/arrays/chunked/compute/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ mod mask;
1212
mod min_max;
1313
mod sum;
1414
mod take;
15+
mod zip;
1516

1617
#[cfg(test)]
1718
mod tests {
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_error::VortexResult;
5+
use vortex_mask::Mask;
6+
7+
use crate::Array;
8+
use crate::ArrayRef;
9+
use crate::arrays::ChunkedArray;
10+
use crate::arrays::ChunkedVTable;
11+
use crate::compute::ZipKernel;
12+
use crate::compute::ZipKernelAdapter;
13+
use crate::compute::zip;
14+
use crate::register_kernel;
15+
16+
// Push down the zip call to the chunks. Without this kernel
17+
// the default implementation canonicalises the chunked array
18+
// then zips once.
19+
impl ZipKernel for ChunkedVTable {
20+
fn zip(
21+
&self,
22+
if_true: &ChunkedArray,
23+
if_false: &dyn Array,
24+
mask: &Mask,
25+
) -> VortexResult<Option<ArrayRef>> {
26+
let Some(if_false) = if_false.as_opt::<ChunkedVTable>() else {
27+
return Ok(None);
28+
};
29+
let dtype = if_true
30+
.dtype()
31+
.union_nullability(if_false.dtype().nullability());
32+
let mut out_chunks = Vec::with_capacity(if_true.nchunks() + if_false.nchunks());
33+
34+
let mut lhs_idx = 0;
35+
let mut rhs_idx = 0;
36+
let mut lhs_offset = 0;
37+
let mut rhs_offset = 0;
38+
let mut pos = 0;
39+
let total_len = if_true.len();
40+
41+
while pos < total_len {
42+
let lhs_chunk = if_true.chunk(lhs_idx);
43+
let rhs_chunk = if_false.chunk(rhs_idx);
44+
45+
let lhs_rem = lhs_chunk.len() - lhs_offset;
46+
let rhs_rem = rhs_chunk.len() - rhs_offset;
47+
let take_until = lhs_rem.min(rhs_rem);
48+
49+
let mask_slice = mask.slice(pos..pos + take_until);
50+
let lhs_slice = lhs_chunk.slice(lhs_offset..lhs_offset + take_until);
51+
let rhs_slice = rhs_chunk.slice(rhs_offset..rhs_offset + take_until);
52+
53+
out_chunks.push(zip(lhs_slice.as_ref(), rhs_slice.as_ref(), &mask_slice)?);
54+
55+
pos += take_until;
56+
lhs_offset += take_until;
57+
rhs_offset += take_until;
58+
59+
if lhs_offset == lhs_chunk.len() {
60+
lhs_idx += 1;
61+
lhs_offset = 0;
62+
}
63+
if rhs_offset == rhs_chunk.len() {
64+
rhs_idx += 1;
65+
rhs_offset = 0;
66+
}
67+
}
68+
69+
// SAFETY: chunks originate from zipping slices of inputs that share dtype/nullability.
70+
let chunked = unsafe { ChunkedArray::new_unchecked(out_chunks, dtype) };
71+
Ok(Some(chunked.to_array()))
72+
}
73+
}
74+
75+
register_kernel!(ZipKernelAdapter(ChunkedVTable).lift());
76+
77+
#[cfg(test)]
78+
mod tests {
79+
use vortex_buffer::buffer;
80+
use vortex_dtype::DType;
81+
use vortex_dtype::Nullability;
82+
use vortex_dtype::PType;
83+
use vortex_mask::Mask;
84+
85+
use crate::IntoArray;
86+
use crate::ToCanonical;
87+
use crate::arrays::ChunkedArray;
88+
use crate::arrays::ChunkedVTable;
89+
use crate::compute::zip;
90+
91+
#[test]
92+
fn test_chunked_zip_aligns_across_boundaries() {
93+
let if_true = ChunkedArray::try_new(
94+
vec![
95+
buffer![1i32, 2].into_array(),
96+
buffer![3i32].into_array(),
97+
buffer![4i32, 5].into_array(),
98+
],
99+
DType::Primitive(PType::I32, Nullability::NonNullable),
100+
)
101+
.unwrap();
102+
103+
let if_false = ChunkedArray::try_new(
104+
vec![
105+
buffer![10i32].into_array(),
106+
buffer![11i32, 12].into_array(),
107+
buffer![13i32, 14].into_array(),
108+
],
109+
DType::Primitive(PType::I32, Nullability::NonNullable),
110+
)
111+
.unwrap();
112+
113+
let mask = Mask::from_iter([true, false, true, false, true]);
114+
115+
let zipped = zip(if_true.as_ref(), if_false.as_ref(), &mask).unwrap();
116+
let zipped = zipped
117+
.as_opt::<ChunkedVTable>()
118+
.expect("zip should keep chunked encoding");
119+
120+
assert_eq!(zipped.nchunks(), 4);
121+
let mut values: Vec<i32> = Vec::new();
122+
for chunk in zipped.chunks() {
123+
let primitive = chunk.to_primitive();
124+
values.extend_from_slice(primitive.as_slice::<i32>());
125+
}
126+
assert_eq!(values, vec![1, 11, 3, 13, 5]);
127+
}
128+
}

0 commit comments

Comments
 (0)