Skip to content

Commit aabd2aa

Browse files
committed
chunked array zip
Signed-off-by: Onur Satici <[email protected]>
1 parent 6b514f2 commit aabd2aa

File tree

2 files changed

+126
-0
lines changed

2 files changed

+126
-0
lines changed

vortex-array/src/arrays/chunked/compute/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ mod mask;
1212
mod min_max;
1313
mod sum;
1414
mod take;
15+
mod zip;
1516

1617
#[cfg(test)]
1718
mod tests {
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_error::VortexResult;
5+
use vortex_mask::Mask;
6+
7+
use crate::Array;
8+
use crate::ArrayRef;
9+
use crate::arrays::ChunkedArray;
10+
use crate::arrays::ChunkedVTable;
11+
use crate::compute::ZipKernel;
12+
use crate::compute::ZipKernelAdapter;
13+
use crate::compute::zip;
14+
use crate::register_kernel;
15+
16+
impl ZipKernel for ChunkedVTable {
17+
fn zip(
18+
&self,
19+
if_true: &ChunkedArray,
20+
if_false: &dyn Array,
21+
mask: &Mask,
22+
) -> VortexResult<Option<ArrayRef>> {
23+
let Some(if_false) = if_false.as_opt::<ChunkedVTable>() else {
24+
return Ok(None);
25+
};
26+
let dtype = if_true
27+
.dtype()
28+
.union_nullability(if_false.dtype().nullability());
29+
let mut out_chunks = Vec::with_capacity(if_true.nchunks() + if_false.nchunks());
30+
31+
let mut lhs_idx = 0;
32+
let mut rhs_idx = 0;
33+
let mut lhs_offset = 0;
34+
let mut rhs_offset = 0;
35+
let mut pos = 0;
36+
let total_len = if_true.len();
37+
38+
while pos < total_len {
39+
let lhs_chunk = if_true.chunk(lhs_idx);
40+
let rhs_chunk = if_false.chunk(rhs_idx);
41+
42+
let lhs_rem = lhs_chunk.len() - lhs_offset;
43+
let rhs_rem = rhs_chunk.len() - rhs_offset;
44+
let take_until = lhs_rem.min(rhs_rem);
45+
46+
let mask_slice = mask.slice(pos..pos + take_until);
47+
let lhs_slice = lhs_chunk.slice(lhs_offset..lhs_offset + take_until);
48+
let rhs_slice = rhs_chunk.slice(rhs_offset..rhs_offset + take_until);
49+
50+
out_chunks.push(zip(lhs_slice.as_ref(), rhs_slice.as_ref(), &mask_slice)?);
51+
52+
pos += take_until;
53+
lhs_offset += take_until;
54+
rhs_offset += take_until;
55+
56+
if lhs_offset == lhs_chunk.len() {
57+
lhs_idx += 1;
58+
lhs_offset = 0;
59+
}
60+
if rhs_offset == rhs_chunk.len() {
61+
rhs_idx += 1;
62+
rhs_offset = 0;
63+
}
64+
}
65+
66+
// SAFETY: chunks originate from zipping slices of inputs that share dtype/nullability.
67+
let chunked = unsafe { ChunkedArray::new_unchecked(out_chunks, dtype) };
68+
Ok(Some(chunked.to_array()))
69+
}
70+
}
71+
72+
register_kernel!(ZipKernelAdapter(ChunkedVTable).lift());
73+
74+
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_mask::Mask;

    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::ChunkedArray;
    use crate::arrays::ChunkedVTable;
    use crate::compute::zip;

    /// Zipping two chunked arrays whose chunk boundaries do not line up must keep the
    /// chunked encoding and split the output at every boundary of either input.
    #[test]
    fn test_chunked_zip_aligns_across_boundaries() {
        let i32_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);

        // Chunk lengths 2 / 1 / 2 -> boundaries after rows 2, 3 and 5.
        let true_chunks = vec![
            buffer![1i32, 2].into_array(),
            buffer![3i32].into_array(),
            buffer![4i32, 5].into_array(),
        ];
        let if_true = ChunkedArray::try_new(true_chunks, i32_dtype.clone()).unwrap();

        // Chunk lengths 1 / 2 / 2 -> boundaries after rows 1, 3 and 5.
        let false_chunks = vec![
            buffer![10i32].into_array(),
            buffer![11i32, 12].into_array(),
            buffer![13i32, 14].into_array(),
        ];
        let if_false = ChunkedArray::try_new(false_chunks, i32_dtype).unwrap();

        let mask = Mask::from_iter([true, false, true, false, true]);

        let result = zip(if_true.as_ref(), if_false.as_ref(), &mask).unwrap();
        let result = result
            .as_opt::<ChunkedVTable>()
            .expect("zip should keep chunked encoding");

        // Combined boundaries fall after rows 1, 2, 3 and 5, giving four output chunks.
        assert_eq!(result.nchunks(), 4);

        let mut flattened: Vec<i32> = Vec::new();
        for chunk in result.chunks() {
            let primitive = chunk.to_primitive();
            flattened.extend(primitive.as_slice::<i32>().iter().copied());
        }
        assert_eq!(flattened, [1, 11, 3, 13, 5]);
    }
}

0 commit comments

Comments
 (0)