Skip to content

Commit 6b514f2

Browse files
authored
varbinview zip kernel (#5526)
Signed-off-by: Onur Satici <[email protected]>
1 parent 8261ebf commit 6b514f2

File tree

6 files changed

+286
-32
lines changed

6 files changed

+286
-32
lines changed

vortex-array/src/arrays/struct_/compute/zip.rs

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,10 @@ impl ZipKernel for StructVTable {
2929
let Some(if_false) = if_false.as_opt::<StructVTable>() else {
3030
return Ok(None);
3131
};
32-
assert_eq!(
33-
if_true.len(),
34-
if_false.len(),
35-
"ComputeFn::invoke checks that arrays have the same size"
36-
);
3732
assert_eq!(
3833
if_true.names(),
3934
if_false.names(),
40-
"Zip checks that arrays type"
35+
"input arrays to zip must have the same field names",
4136
);
4237

4338
let fields = if_true

vortex-array/src/arrays/varbinview/compute/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ mod is_sorted;
88
mod mask;
99
mod min_max;
1010
mod take;
11+
mod zip;
1112

1213
#[cfg(test)]
1314
mod tests {
vortex-array/src/arrays/varbinview/compute/zip.rs

Lines changed: 268 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,268 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use std::ops::Range;
5+
6+
use vortex_buffer::BufferMut;
7+
use vortex_error::VortexResult;
8+
use vortex_error::vortex_bail;
9+
use vortex_mask::AllOr;
10+
use vortex_mask::Mask;
11+
use vortex_vector::binaryview::BinaryView;
12+
13+
use crate::Array;
14+
use crate::ArrayRef;
15+
use crate::arrays::VarBinViewArray;
16+
use crate::arrays::VarBinViewVTable;
17+
use crate::builders::DeduplicatedBuffers;
18+
use crate::builders::LazyBitBufferBuilder;
19+
use crate::compute::ZipKernel;
20+
use crate::compute::ZipKernelAdapter;
21+
use crate::register_kernel;
22+
23+
// A dedicated VarBinView zip kernel that builds the result directly by adjusting views and validity,
24+
// instead of routing through the generic builder (which would redo buffer lookups per mask slice).
25+
impl ZipKernel for VarBinViewVTable {
26+
fn zip(
27+
&self,
28+
if_true: &VarBinViewArray,
29+
if_false: &dyn Array,
30+
mask: &Mask,
31+
) -> VortexResult<Option<ArrayRef>> {
32+
let Some(if_false) = if_false.as_opt::<VarBinViewVTable>() else {
33+
return Ok(None);
34+
};
35+
36+
if !if_true.dtype().eq_ignore_nullability(if_false.dtype()) {
37+
vortex_bail!("input arrays to zip must have the same dtype");
38+
}
39+
40+
// compute fn already asserts if_true.len() == if_false.len()
41+
let len = if_true.len();
42+
let dtype = if_true
43+
.dtype()
44+
.union_nullability(if_false.dtype().nullability());
45+
46+
// build buffer lookup tables for both arrays, these map from the original buffer idx
47+
// to the new buffer index in the result array
48+
let mut buffers = DeduplicatedBuffers::default();
49+
let true_lookup = buffers.extend_from_slice(if_true.buffers());
50+
let false_lookup = buffers.extend_from_slice(if_false.buffers());
51+
52+
let mut views_builder = BufferMut::<BinaryView>::with_capacity(len);
53+
let mut validity_builder = LazyBitBufferBuilder::new(len);
54+
55+
let true_validity = if_true.validity_mask();
56+
let false_validity = if_false.validity_mask();
57+
58+
match mask.slices() {
59+
AllOr::All => push_range(
60+
if_true,
61+
&true_lookup,
62+
&true_validity,
63+
0..len,
64+
&mut views_builder,
65+
&mut validity_builder,
66+
),
67+
AllOr::None => push_range(
68+
if_false,
69+
&false_lookup,
70+
&false_validity,
71+
0..len,
72+
&mut views_builder,
73+
&mut validity_builder,
74+
),
75+
AllOr::Some(slices) => {
76+
let mut pos = 0;
77+
for (start, end) in slices {
78+
if pos < *start {
79+
push_range(
80+
if_false,
81+
&false_lookup,
82+
&false_validity,
83+
pos..*start,
84+
&mut views_builder,
85+
&mut validity_builder,
86+
);
87+
}
88+
push_range(
89+
if_true,
90+
&true_lookup,
91+
&true_validity,
92+
*start..*end,
93+
&mut views_builder,
94+
&mut validity_builder,
95+
);
96+
pos = *end;
97+
}
98+
if pos < len {
99+
push_range(
100+
if_false,
101+
&false_lookup,
102+
&false_validity,
103+
pos..len,
104+
&mut views_builder,
105+
&mut validity_builder,
106+
);
107+
}
108+
}
109+
}
110+
111+
let validity = validity_builder.finish_with_nullability(dtype.nullability());
112+
113+
// SAFETY: views are built with adjusted buffer indices, validity tracked alongside;
114+
// buffers come from `DeduplicatedBuffers`, dtype/nullability preserved.
115+
let array = unsafe {
116+
VarBinViewArray::new_unchecked(
117+
views_builder.freeze(),
118+
buffers.finish(),
119+
dtype,
120+
validity,
121+
)
122+
};
123+
124+
Ok(Some(array.to_array()))
125+
}
126+
}
127+
128+
fn push_range(
129+
array: &VarBinViewArray,
130+
buffer_lookup: &[u32],
131+
validity: &Mask,
132+
range: Range<usize>,
133+
views_builder: &mut BufferMut<BinaryView>,
134+
validity_builder: &mut LazyBitBufferBuilder,
135+
) {
136+
let views = array.views();
137+
138+
match validity.bit_buffer() {
139+
AllOr::All => {
140+
for idx in range {
141+
push_view(
142+
views[idx],
143+
buffer_lookup,
144+
true,
145+
views_builder,
146+
validity_builder,
147+
);
148+
}
149+
}
150+
AllOr::None => {
151+
for _ in range {
152+
push_view(
153+
BinaryView::empty_view(),
154+
buffer_lookup,
155+
false,
156+
views_builder,
157+
validity_builder,
158+
);
159+
}
160+
}
161+
AllOr::Some(bit_buffer) => {
162+
for idx in range {
163+
let is_valid = bit_buffer.value(idx);
164+
push_view(
165+
views[idx],
166+
buffer_lookup,
167+
is_valid,
168+
views_builder,
169+
validity_builder,
170+
);
171+
}
172+
}
173+
}
174+
}
175+
176+
#[inline]
177+
fn push_view(
178+
view: BinaryView,
179+
buffer_lookup: &[u32],
180+
is_valid: bool,
181+
views_builder: &mut BufferMut<BinaryView>,
182+
validity_builder: &mut LazyBitBufferBuilder,
183+
) {
184+
if !is_valid {
185+
views_builder.push(BinaryView::empty_view());
186+
validity_builder.append_null();
187+
return;
188+
}
189+
190+
let adjusted = if view.is_inlined() {
191+
view
192+
} else {
193+
let view_ref = view.as_view();
194+
view_ref
195+
.with_buffer_and_offset(
196+
buffer_lookup[view_ref.buffer_index as usize],
197+
view_ref.offset,
198+
)
199+
.into()
200+
};
201+
202+
views_builder.push(adjusted);
203+
validity_builder.append_non_null();
204+
}
205+
206+
register_kernel!(ZipKernelAdapter(VarBinViewVTable).lift());
207+
208+
#[cfg(test)]
mod tests {
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_mask::Mask;

    use crate::accessor::ArrayAccessor;
    use crate::arrays::VarBinViewArray;
    use crate::canonical::ToCanonical;
    use crate::compute::zip;

    /// Zips two nullable Utf8 arrays containing a mix of long strings, short strings and
    /// nulls, then checks the selected values, the length and the dtype of the result.
    #[test]
    fn zip_varbinview_kernel_zips() {
        let nullable_utf8 = DType::Utf8(Nullability::Nullable);

        let if_true = VarBinViewArray::from_iter(
            [
                Some("aaaaaaaaaaaaa_long"), // outlined
                Some("short"),
                None,
                Some("bbbbbbbbbbbbbbbb_long"),
                Some("tiny"),
                Some("cccccccccccccccc_long"),
            ],
            nullable_utf8.clone(),
        );

        let if_false = VarBinViewArray::from_iter(
            [
                Some("dddddddddddddddd_long"),
                Some("eeeeeeeeeeeeeeee_long"),
                Some("ffff"),
                Some("gggggggggggggggg_long"),
                None,
                Some("hhhhhhhhhhhhhhhh_long"),
            ],
            nullable_utf8.clone(),
        );

        let selection = Mask::from_iter([true, false, true, false, false, true]);

        let zipped = zip(if_true.as_ref(), if_false.as_ref(), &selection)
            .unwrap()
            .to_varbinview();

        let actual = zipped.with_iterator(|it| {
            it.map(|v| v.map(|bytes| String::from_utf8(bytes.to_vec()).unwrap()))
                .collect::<Vec<_>>()
        });

        // Rows 0, 2 and 5 come from `if_true`, the remaining rows from `if_false`.
        let expected = [
            Some("aaaaaaaaaaaaa_long"),
            Some("eeeeeeeeeeeeeeee_long"),
            None,
            Some("gggggggggggggggg_long"),
            None,
            Some("cccccccccccccccc_long"),
        ]
        .into_iter()
        .map(|value: Option<&str>| value.map(String::from))
        .collect::<Vec<_>>();

        assert_eq!(actual, expected);
        assert_eq!(zipped.len(), selection.len());
        assert_eq!(zipped.dtype(), &nullable_utf8);
    }
}

vortex-array/src/builders/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ use crate::ArrayRef;
4343
use crate::canonical::Canonical;
4444

4545
mod lazy_null_builder;
46-
use lazy_null_builder::LazyBitBufferBuilder;
46+
pub(crate) use lazy_null_builder::LazyBitBufferBuilder;
4747

4848
mod bool;
4949
mod decimal;

vortex-array/src/builders/varbinview.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,7 @@ impl DeduplicatedBuffers {
436436
}
437437

438438
/// Push a new block if not seen before. Returns the idx of the block.
439-
fn push(&mut self, block: ByteBuffer) -> u32 {
439+
pub(crate) fn push(&mut self, block: ByteBuffer) -> u32 {
440440
assert!(self.buffers.len() < u32::MAX as usize, "Too many blocks");
441441

442442
let initial_len = self.len();
@@ -452,21 +452,24 @@ impl DeduplicatedBuffers {
452452
}
453453
}
454454

455-
fn extend_from_option_slice(&mut self, buffers: &[Option<ByteBuffer>]) -> Vec<Option<u32>> {
455+
pub(crate) fn extend_from_option_slice(
456+
&mut self,
457+
buffers: &[Option<ByteBuffer>],
458+
) -> Vec<Option<u32>> {
456459
buffers
457460
.iter()
458461
.map(|buffer| buffer.as_ref().map(|buf| self.push(buf.clone())))
459462
.collect()
460463
}
461464

462-
fn extend_from_slice(&mut self, buffers: &[ByteBuffer]) -> Vec<u32> {
465+
pub(crate) fn extend_from_slice(&mut self, buffers: &[ByteBuffer]) -> Vec<u32> {
463466
buffers
464467
.iter()
465468
.map(|buffer| self.push(buffer.clone()))
466469
.collect()
467470
}
468471

469-
fn finish(self) -> Arc<[ByteBuffer]> {
472+
/// Consumes the collection and returns the gathered buffers as a shared slice.
pub(crate) fn finish(self) -> Arc<[ByteBuffer]> {
    self.buffers.into()
}
472475
}

vortex-array/src/compute/zip.rs

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ use super::cast;
1919
use crate::Array;
2020
use crate::ArrayRef;
2121
use crate::builders::ArrayBuilder;
22-
use crate::builders::VarBinViewBuilder;
2322
use crate::builders::builder_with_capacity;
2423
use crate::compute::ComputeFn;
2524
use crate::compute::Kernel;
@@ -210,27 +209,15 @@ fn zip_impl(if_true: &dyn Array, if_false: &dyn Array, mask: &Mask) -> VortexRes
210209
);
211210

212211
let return_type = zip_return_dtype(if_true, if_false);
213-
let capacity = if_true.len();
214-
215-
let builder = match return_type {
216-
// TODO(blaginin): once https://github.com/vortex-data/vortex/pull/4695 is merged, we can kill
217-
// these two special cases, but before that we need to manually use deduplicated buffers.
218-
// Otherwise, the same buffer will be appended multiple times causing fragmentation.
219-
DType::Utf8(n) => Box::new(VarBinViewBuilder::with_buffer_deduplication(
220-
DType::Utf8(n),
221-
capacity,
222-
)),
223-
DType::Binary(n) => Box::new(VarBinViewBuilder::with_buffer_deduplication(
224-
DType::Binary(n),
225-
capacity,
226-
)),
227-
_ => builder_with_capacity(&return_type, if_true.len()),
228-
};
229-
230-
zip_impl_with_builder(if_true, if_false, mask, builder)
212+
zip_impl_with_builder(
213+
if_true,
214+
if_false,
215+
mask,
216+
builder_with_capacity(&return_type, if_true.len()),
217+
)
231218
}
232219

233-
pub(crate) fn zip_impl_with_builder(
220+
fn zip_impl_with_builder(
234221
if_true: &dyn Array,
235222
if_false: &dyn Array,
236223
mask: &Mask,
@@ -272,8 +259,8 @@ mod tests {
272259
use crate::arrow::IntoArrowArray;
273260
use crate::builders::ArrayBuilder;
274261
use crate::builders::BufferGrowthStrategy;
262+
use crate::builders::VarBinViewBuilder;
275263
use crate::compute::zip;
276-
use crate::compute::zip::VarBinViewBuilder;
277264

278265
#[test]
279266
fn test_zip_basic() {

0 commit comments

Comments
 (0)