Skip to content

Commit e1a4298

Browse files
authored
fix[vortex-array]: fix take on varbinviews with NULL indices (#5626)
1 parent 4e98558 commit e1a4298

File tree

2 files changed

+41
-12
lines changed

2 files changed

+41
-12
lines changed

vortex-array/src/arrays/varbinview/compute/take.rs

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
use std::iter;
45
use std::ops::Deref;
56

67
use num_traits::AsPrimitive;
78
use vortex_buffer::Buffer;
89
use vortex_dtype::match_each_integer_ptype;
910
use vortex_error::VortexResult;
11+
use vortex_mask::AllOr;
12+
use vortex_mask::Mask;
1013
use vortex_vector::binaryview::BinaryView;
1114

1215
use crate::Array;
@@ -23,16 +26,16 @@ use crate::vtable::ValidityHelper;
2326
/// Take involves creating a new array that references the old array, just with the given set of views.
2427
impl TakeKernel for VarBinViewVTable {
2528
fn take(&self, array: &VarBinViewArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
26-
// Compute the new validity
27-
28-
// This is valid since all elements (of all arrays) even null values must be inside
29-
// min-max valid range.
29+
// Compute the new validity.
3030
let validity = array.validity().take(indices)?;
3131
let indices = indices.to_primitive();
3232

3333
let views_buffer = match_each_integer_ptype!(indices.ptype(), |I| {
34-
// This is valid since all elements even null values are inside the min-max valid range.
35-
take_views(array.views(), indices.as_slice::<I>())
34+
take_views(
35+
array.views(),
36+
indices.as_slice::<I>(),
37+
&indices.validity_mask(),
38+
)
3639
});
3740

3841
// SAFETY: taking all components at same indices maintains invariants
@@ -55,15 +58,36 @@ register_kernel!(TakeKernelAdapter(VarBinViewVTable).lift());
5558
/// Builds a new views buffer by looking up `views` at each position in `indices`.
///
/// `mask` is consulted per position: where it marks a position as valid the view at that
/// index is copied; where it marks it invalid a `BinaryView::default()` is emitted instead
/// and the index value at that position is never used to index `views`.
///
/// The returned buffer has one view per entry of `indices`.
fn take_views<I: AsPrimitive<usize>>(
5659
views: &Buffer<BinaryView>,
5760
indices: &[I],
61+
mask: &Mask,
5862
) -> Buffer<BinaryView> {
5963
// NOTE(ngates): this deref is not actually trivial, so we run it once.
6064
let views_ref = views.deref();
61-
Buffer::<BinaryView>::from_trusted_len_iter(indices.iter().map(|i| views_ref[i.as_()]))
65+
// We do not use iter_bools directly, since the resulting dyn iterator cannot
66+
// implement TrustedLen.
67+
match mask.bit_buffer() {
68+
// Every position is valid: index `views_ref` with each index value directly.
AllOr::All => {
69+
Buffer::<BinaryView>::from_trusted_len_iter(indices.iter().map(|i| views_ref[i.as_()]))
70+
}
71+
// No position is valid: emit one default view per index without reading any index value.
AllOr::None => Buffer::<BinaryView>::from_trusted_len_iter(iter::repeat_n(
72+
BinaryView::default(),
73+
indices.len(),
74+
)),
75+
// Mixed validity: pair each validity bit with its index and only index `views_ref`
// when the bit is set, so index values at invalid positions are ignored.
AllOr::Some(buffer) => Buffer::<BinaryView>::from_trusted_len_iter(
76+
buffer.iter().zip(indices.iter()).map(|(valid, idx)| {
77+
if valid {
78+
views_ref[idx.as_()]
79+
} else {
80+
BinaryView::default()
81+
}
82+
}),
83+
),
84+
}
6285
}
6386

6487
#[cfg(test)]
6588
mod tests {
6689
use rstest::rstest;
90+
use vortex_buffer::BitBuffer;
6791
use vortex_buffer::buffer;
6892
use vortex_dtype::DType;
6993
use vortex_dtype::Nullability::NonNullable;
@@ -76,6 +100,7 @@ mod tests {
76100
use crate::canonical::ToCanonical;
77101
use crate::compute::conformance::take::test_take_conformance;
78102
use crate::compute::take;
103+
use crate::validity::Validity;
79104

80105
#[test]
81106
fn take_nullable() {
@@ -103,11 +128,13 @@ mod tests {
103128
fn take_nullable_indices() {
104129
let arr = VarBinViewArray::from_iter(["one", "two"].map(Some), DType::Utf8(NonNullable));
105130

106-
let taken = take(
107-
arr.as_ref(),
108-
PrimitiveArray::from_option_iter(vec![Some(1), None]).as_ref(),
109-
)
110-
.unwrap();
131+
let indices = PrimitiveArray::new(
132+
// Verify that garbage values at NULL indices are ignored.
133+
buffer![1u64, 999],
134+
Validity::from(BitBuffer::from(vec![true, false])),
135+
);
136+
137+
let taken = take(arr.as_ref(), indices.as_ref()).unwrap();
111138

112139
assert!(taken.dtype().is_nullable());
113140
assert_eq!(

vortex-buffer/src/trusted_len.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,8 @@ where
161161
{
162162
}
163163

164+
// SAFETY: NOTE(review): presumably sound because `RepeatN` yields exactly the requested
// number of items and reports that count via `size_hint` — confirm against the std docs.
unsafe impl<T: Clone> TrustedLen for std::iter::RepeatN<T> {}
165+
164166
// Arrow bit iterators
165167
unsafe impl<'a> TrustedLen for crate::bit::BitIterator<'a> {}
166168
unsafe impl<'a> TrustedLen for crate::bit::BitChunkIterator<'a> {}

0 commit comments

Comments
 (0)