Skip to content

Commit 8eb7e58

Browse files
authored
Fix: ListArray take incorrect validity (#5762)
Fixes #5743 Bug was here: ```rust indices_array .validity() .clone() .and(array.validity().clone()), ``` This one is sort of a yikes (it might originally been my code? not sure 😱), so I did some refactoring to make writing the obviously correct code more easily Signed-off-by: Connor Tsui <[email protected]>
1 parent dcaa135 commit 8eb7e58

File tree

1 file changed

+49
-45
lines changed
  • vortex-array/src/arrays/list/compute

1 file changed

+49
-45
lines changed

vortex-array/src/arrays/list/compute/take.rs

Lines changed: 49 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4-
use vortex_buffer::BitBufferMut;
54
use vortex_dtype::IntegerPType;
65
use vortex_dtype::Nullability;
76
use vortex_dtype::match_each_integer_ptype;
87
use vortex_dtype::match_smallest_offset_type;
98
use vortex_error::VortexExpect;
109
use vortex_error::VortexResult;
11-
use vortex_mask::Mask;
1210

1311
use crate::Array;
1412
use crate::ArrayRef;
@@ -22,7 +20,6 @@ use crate::compute::TakeKernel;
2220
use crate::compute::TakeKernelAdapter;
2321
use crate::compute::take;
2422
use crate::register_kernel;
25-
use crate::validity::Validity;
2623
use crate::vtable::ValidityHelper;
2724

2825
// TODO(connor)[ListView]: Re-revert to the version where we simply convert to a `ListView` and call
@@ -37,21 +34,13 @@ impl TakeKernel for ListVTable {
3734
#[expect(clippy::cognitive_complexity)]
3835
fn take(&self, array: &ListArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
3936
let indices = indices.to_primitive();
40-
let offsets = array.offsets().to_primitive();
4137
// This is an over-approximation of the total number of elements in the resulting array.
4238
let total_approx = array.elements().len().saturating_mul(indices.len());
4339

44-
match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| {
45-
let offsets_slice = offsets.as_slice::<O>();
40+
match_each_integer_ptype!(array.offsets().dtype().as_ptype(), |O| {
4641
match_each_integer_ptype!(indices.ptype(), |I| {
4742
match_smallest_offset_type!(total_approx, |OutputOffsetType| {
48-
_take::<I, O, OutputOffsetType>(
49-
array,
50-
offsets_slice,
51-
&indices,
52-
array.validity_mask(),
53-
indices.validity_mask(),
54-
)
43+
_take::<I, O, OutputOffsetType>(array, &indices)
5544
})
5645
})
5746
})
@@ -62,23 +51,19 @@ register_kernel!(TakeKernelAdapter(ListVTable).lift());
6251

6352
fn _take<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
6453
array: &ListArray,
65-
offsets: &[O],
6654
indices_array: &PrimitiveArray,
67-
data_validity: Mask,
68-
indices_validity_mask: Mask,
6955
) -> VortexResult<ArrayRef> {
70-
let indices: &[I] = indices_array.as_slice::<I>();
71-
72-
if !indices_validity_mask.all_true() || !data_validity.all_true() {
73-
return _take_nullable::<I, O, OutputOffsetType>(
74-
array,
75-
offsets,
76-
indices,
77-
data_validity,
78-
indices_validity_mask,
79-
);
56+
let data_validity = array.validity_mask();
57+
let indices_validity = indices_array.validity_mask();
58+
59+
if !indices_validity.all_true() || !data_validity.all_true() {
60+
return _take_nullable::<I, O, OutputOffsetType>(array, indices_array);
8061
}
8162

63+
let offsets_array = array.offsets().to_primitive();
64+
let offsets: &[O] = offsets_array.as_slice();
65+
let indices: &[I] = indices_array.as_slice();
66+
8267
let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
8368
Nullability::NonNullable,
8469
indices.len(),
@@ -120,21 +105,21 @@ fn _take<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
120105
Ok(ListArray::try_new(
121106
new_elements,
122107
new_offsets,
123-
indices_array
124-
.validity()
125-
.clone()
126-
.and(array.validity().clone()),
108+
array.validity().clone().take(indices_array.as_ref())?,
127109
)?
128110
.to_array())
129111
}
130112

131113
fn _take_nullable<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
132114
array: &ListArray,
133-
offsets: &[O],
134-
indices: &[I],
135-
data_validity: Mask,
136-
indices_validity: Mask,
115+
indices_array: &PrimitiveArray,
137116
) -> VortexResult<ArrayRef> {
117+
let offsets_array = array.offsets().to_primitive();
118+
let offsets: &[O] = offsets_array.as_slice();
119+
let indices: &[I] = indices_array.as_slice();
120+
let data_validity = array.validity_mask();
121+
let indices_validity = indices_array.validity_mask();
122+
138123
let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
139124
Nullability::NonNullable,
140125
indices.len(),
@@ -153,28 +138,23 @@ fn _take_nullable<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPTy
153138
let mut current_offset = OutputOffsetType::zero();
154139
new_offsets.append_zero();
155140

156-
// Set all bits to invalid and selectively set which values are valid.
157-
let mut new_validity = BitBufferMut::new_unset(indices.len());
158-
159141
for (idx, data_idx) in indices.iter().enumerate() {
160142
if !indices_validity.value(idx) {
161143
new_offsets.append_value(current_offset);
162-
// Bit buffer already has this set to invalid.
163144
continue;
164145
}
165146

166147
let data_idx: usize = data_idx.as_();
167148

168149
if !data_validity.value(data_idx) {
169150
new_offsets.append_value(current_offset);
170-
// Bit buffer already has this set to invalid.
171151
continue;
172152
}
173153

174154
let start = offsets[data_idx];
175155
let stop = offsets[data_idx + 1];
176156

177-
// See the note it the `take` on the reasoning
157+
// See the note in `_take` on the reasoning.
178158
let additional: usize = (stop - start).as_();
179159

180160
elements_to_take.reserve_exact(additional);
@@ -184,17 +164,18 @@ fn _take_nullable<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPTy
184164
current_offset +=
185165
OutputOffsetType::from_usize((stop - start).as_()).vortex_expect("offset conversion");
186166
new_offsets.append_value(current_offset);
187-
new_validity.set(idx);
188167
}
189168

190169
let elements_to_take = elements_to_take.finish();
191170
let new_offsets = new_offsets.finish();
192171
let new_elements = take(array.elements(), elements_to_take.as_ref())?;
193172

194-
let new_validity = Validity::from(new_validity.freeze());
195-
// data are indexes are nullable, so the final result is also nullable.
196-
197-
Ok(ListArray::try_new(new_elements, new_offsets, new_validity)?.to_array())
173+
Ok(ListArray::try_new(
174+
new_elements,
175+
new_offsets,
176+
array.validity().clone().take(indices_array.as_ref())?,
177+
)?
178+
.to_array())
198179
}
199180

200181
#[cfg(test)]
@@ -460,4 +441,27 @@ mod test {
460441
assert!(result_view.is_invalid(1));
461442
assert!(result_view.is_valid(2));
462443
}
444+
445+
/// Regression test for validity length mismatch bug.
446+
///
447+
/// When source array has `Validity::Array(...)` and indices are non-nullable,
448+
/// the result validity must have length equal to indices.len(), not source.len().
449+
#[test]
450+
fn test_take_validity_length_mismatch_regression() {
451+
// Source array with explicit validity array (length 2).
452+
let list = ListArray::try_new(
453+
buffer![1i32, 2, 3, 4].into_array(),
454+
buffer![0, 2, 4].into_array(),
455+
Validity::Array(BoolArray::from_iter(vec![true, true]).to_array()),
456+
)
457+
.unwrap()
458+
.to_array();
459+
460+
// Take more indices than source length (4 vs 2) with non-nullable indices.
461+
let idx = buffer![0u32, 1, 0, 1].into_array();
462+
463+
// This should not panic - result should have length 4.
464+
let result = take(&list, &idx).unwrap();
465+
assert_eq!(result.len(), 4);
466+
}
463467
}

0 commit comments

Comments
 (0)