Skip to content

Commit baafc71

Browse files
authored
Filter on lists (#5775)
Signed-off-by: Nicholas Gates <[email protected]>
1 parent c93b6a7 commit baafc71

File tree

14 files changed

+738
-101
lines changed

14 files changed

+738
-101
lines changed

vortex-array/src/arrays/dict/vtable/rules.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,45 @@ use crate::arrays::ConstantArray;
1313
use crate::arrays::ConstantVTable;
1414
use crate::arrays::DictArray;
1515
use crate::arrays::DictVTable;
16+
use crate::arrays::FilterArray;
17+
use crate::arrays::FilterVTable;
1618
use crate::arrays::ScalarFnArray;
1719
use crate::builtins::ArrayBuiltins;
1820
use crate::expr::Pack;
21+
use crate::matchers::Exact;
1922
use crate::optimizer::ArrayOptimizer;
2023
use crate::optimizer::rules::ArrayParentReduceRule;
2124
use crate::optimizer::rules::ParentRuleSet;
2225

2326
/// The set of parent-reduce rules registered for `DictVTable`.
///
/// Each entry is an `ArrayParentReduceRule<DictVTable>` wrapped with `ParentRuleSet::lift`.
pub(super) const PARENT_RULES: ParentRuleSet<DictVTable> = ParentRuleSet::new(&[
27+
ParentRuleSet::lift(&DictionaryFilterPushDownRule),
2428
ParentRuleSet::lift(&DictionaryScalarFnValuesPushDownRule),
2529
ParentRuleSet::lift(&DictionaryScalarFnCodesPullUpRule),
2630
]);
2731

32+
#[derive(Debug)]
33+
struct DictionaryFilterPushDownRule;
34+
35+
impl ArrayParentReduceRule<DictVTable> for DictionaryFilterPushDownRule {
36+
type Parent = Exact<FilterVTable>;
37+
38+
fn parent(&self) -> Self::Parent {
39+
Exact::from(&FilterVTable)
40+
}
41+
42+
fn reduce_parent(
43+
&self,
44+
array: &DictArray,
45+
parent: &FilterArray,
46+
_child_idx: usize,
47+
) -> VortexResult<Option<ArrayRef>> {
48+
let new_codes = array.codes().filter(parent.filter_mask().clone())?;
49+
let new_dict =
50+
unsafe { DictArray::new_unchecked(new_codes, array.values().clone()) }.into_array();
51+
Ok(Some(new_dict))
52+
}
53+
}
54+
2855
/// Push down a scalar function to run only over the values of a dictionary array.
2956
#[derive(Debug)]
3057
struct DictionaryScalarFnValuesPushDownRule;

vortex-array/src/arrays/filter/rules.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
use vortex_error::VortexResult;
55

6+
use crate::Array;
67
use crate::ArrayRef;
78
use crate::IntoArray;
89
use crate::arrays::FilterArray;
@@ -32,7 +33,7 @@ impl ArrayParentReduceRule<FilterVTable> for FilterFilterRule {
3233
_child_idx: usize,
3334
) -> VortexResult<Option<ArrayRef>> {
3435
let combined_mask = child.mask.intersect_by_rank(&parent.mask);
35-
let new_array = FilterArray::new(child.child.clone(), combined_mask);
36+
let new_array = child.child.filter(combined_mask)?;
3637
Ok(Some(new_array.into_array()))
3738
}
3839
}

vortex-array/src/arrays/list/compute/filter.rs

Lines changed: 52 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
use std::sync::Arc;
5+
46
use vortex_buffer::BitBufferMut;
57
use vortex_buffer::BufferMut;
68
use vortex_dtype::IntegerPType;
@@ -9,19 +11,18 @@ use vortex_error::VortexExpect;
911
use vortex_error::VortexResult;
1012
use vortex_mask::Mask;
1113
use vortex_mask::MaskIter;
14+
use vortex_mask::MaskValues;
1215

1316
use crate::Array;
1417
use crate::ArrayRef;
1518
use crate::IntoArray;
1619
use crate::ToCanonical;
1720
use crate::arrays::ListArray;
1821
use crate::arrays::ListVTable;
19-
use crate::arrays::PrimitiveArray;
2022
use crate::compute::FilterKernel;
2123
use crate::compute::FilterKernelAdapter;
2224
use crate::compute::filter;
2325
use crate::register_kernel;
24-
use crate::validity::Validity;
2526
use crate::vtable::ValidityHelper;
2627

2728
/// Density threshold for choosing between indices and slices representation when expanding masks.
@@ -44,21 +45,22 @@ impl FilterKernel for ListVTable {
4445
);
4546

4647
let (new_elements, new_offsets) = match_each_integer_ptype!(offsets.ptype(), |O| {
47-
compute_filtered_elements_and_offsets::<O>(
48+
let (new_elements, new_offsets) = compute_filtered_elements_and_offsets::<O>(
4849
elements.as_ref(),
4950
offsets.as_slice::<O>(),
5051
selection_mask,
51-
)?
52+
)?;
53+
(new_elements, new_offsets.into_array())
5254
});
5355

5456
// SAFETY: Filter operation maintains all ListArray invariants:
5557
// - Offsets are monotonically increasing (built correctly above).
5658
// - Elements are properly filtered to match the offsets.
5759
// - Validity matches the original array's nullability.
58-
Ok(unsafe {
59-
ListArray::new_unchecked(new_elements, new_offsets.into_array(), new_validity)
60-
}
61-
.into_array())
60+
Ok(
61+
unsafe { ListArray::new_unchecked(new_elements, new_offsets, new_validity) }
62+
.into_array(),
63+
)
6264
}
6365
}
6466

@@ -74,7 +76,7 @@ fn compute_filtered_elements_and_offsets<O: IntegerPType>(
7476
elements: &dyn Array,
7577
offsets: &[O],
7678
selection_mask: &Mask,
77-
) -> VortexResult<(ArrayRef, PrimitiveArray)> {
79+
) -> VortexResult<(ArrayRef, BufferMut<O>)> {
7880
let values = selection_mask
7981
.values()
8082
.vortex_expect("`AllTrue` and `AllFalse` are handled by filter entry point");
@@ -132,11 +134,50 @@ fn compute_filtered_elements_and_offsets<O: IntegerPType>(
132134
// The `Mask` can determine the best representation based on the buffer's density in the future.
133135
let new_elements = filter(elements, &Mask::from_buffer(new_mask_builder.freeze()))?;
134136

135-
let new_offsets = PrimitiveArray::new(new_offsets, Validity::NonNullable);
136-
137137
Ok((new_elements, new_offsets))
138138
}
139139

140+
/// Construct an element mask from contiguous list offsets and a selection mask.
141+
pub fn element_mask_from_offsets<O: IntegerPType>(
142+
offsets: &[O],
143+
selection: &Arc<MaskValues>,
144+
) -> Mask {
145+
let first_offset = offsets.first().map_or(0, |first_offset| first_offset.as_());
146+
let last_offset = offsets.last().map_or(0, |last_offset| last_offset.as_());
147+
let len = last_offset - first_offset;
148+
149+
let mut mask_builder = BitBufferMut::with_capacity(len);
150+
151+
match selection.threshold_iter(MASK_EXPANSION_DENSITY_THRESHOLD) {
152+
MaskIter::Slices(slices) => {
153+
// Dense iteration: process ranges of consecutive selected lists.
154+
for &(start, end) in slices {
155+
// Optimization: for dense ranges, we can process the elements mask more efficiently.
156+
let elems_start = offsets[start].as_() - first_offset;
157+
let elems_end = offsets[end].as_() - first_offset;
158+
159+
// Process the entire range of elements at once.
160+
process_element_range(elems_start, elems_end, &mut mask_builder);
161+
}
162+
}
163+
MaskIter::Indices(indices) => {
164+
// Sparse iteration: process individual selected lists.
165+
for &idx in indices {
166+
let list_start = offsets[idx].as_() - first_offset;
167+
let list_end = offsets[idx + 1].as_() - first_offset;
168+
169+
// Process the elements for this list.
170+
process_element_range(list_start, list_end, &mut mask_builder);
171+
}
172+
}
173+
}
174+
175+
// Pad to full length if necessary.
176+
mask_builder.append_n(false, last_offset - mask_builder.len());
177+
178+
Mask::from_buffer(mask_builder.freeze())
179+
}
180+
140181
/// Process a range of elements for filtering.
141182
fn process_element_range(
142183
elems_start: usize,

vortex-array/src/arrays/list/compute/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ mod mask;
99
mod min_max;
1010
mod take;
1111

12+
pub(super) use filter::element_mask_from_offsets;
13+
1214
#[cfg(test)]
1315
mod tests {
1416
use rstest::rstest;
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use std::sync::Arc;
5+
6+
use vortex_buffer::BufferMut;
7+
use vortex_dtype::PTypeDowncastExt;
8+
use vortex_dtype::match_each_integer_ptype;
9+
use vortex_error::VortexResult;
10+
use vortex_mask::Mask;
11+
use vortex_vector::Vector;
12+
use vortex_vector::VectorMutOps;
13+
use vortex_vector::listview::ListViewVector;
14+
use vortex_vector::listview::ListViewVectorMut;
15+
use vortex_vector::primitive::PVector;
16+
use vortex_vector::primitive::PrimitiveVector;
17+
18+
use crate::Array;
19+
use crate::ExecutionCtx;
20+
use crate::VectorExecutor;
21+
use crate::arrays::FilterArray;
22+
use crate::arrays::FilterVTable;
23+
use crate::arrays::ListArray;
24+
use crate::arrays::ListVTable;
25+
use crate::arrays::list::compute::element_mask_from_offsets;
26+
use crate::kernel::ExecuteParentKernel;
27+
use crate::mask::MaskExecutor;
28+
use crate::matchers::Exact;
29+
use crate::validity::Validity;
30+
use crate::vtable::ValidityHelper;
31+
32+
#[derive(Debug)]
33+
pub(super) struct ListFilterKernel;
34+
35+
impl ExecuteParentKernel<ListVTable> for ListFilterKernel {
36+
type Parent = Exact<FilterVTable>;
37+
38+
fn parent(&self) -> Self::Parent {
39+
Exact::from(&FilterVTable)
40+
}
41+
42+
fn execute_parent(
43+
&self,
44+
array: &ListArray,
45+
parent: &FilterArray,
46+
_child_idx: usize,
47+
ctx: &mut ExecutionCtx,
48+
) -> VortexResult<Option<Vector>> {
49+
let selection = match parent.filter_mask() {
50+
Mask::AllTrue(_) | Mask::AllFalse(_) => return Ok(None),
51+
Mask::Values(v) => v,
52+
};
53+
54+
// TODO(ngates): for ultra-sparse masks, we don't need to optimize the entire offsets.
55+
let offsets = array
56+
.offsets()
57+
.execute_vector(ctx.session())?
58+
.into_primitive();
59+
60+
let new_validity = match array.validity() {
61+
Validity::NonNullable | Validity::AllValid => Mask::new_true(selection.true_count()),
62+
Validity::AllInvalid => {
63+
let mut vec = ListViewVectorMut::with_capacity(
64+
array.elements().dtype(),
65+
selection.true_count(),
66+
);
67+
vec.append_nulls(selection.true_count());
68+
return Ok(Some(vec.freeze().into()));
69+
}
70+
Validity::Array(a) => a
71+
.filter(parent.filter_mask().clone())?
72+
.execute_mask(ctx.session())?,
73+
};
74+
75+
let (new_offsets, new_sizes) = match_each_integer_ptype!(offsets.ptype(), |O| {
76+
let offsets = (&offsets).downcast::<O>().elements().as_slice();
77+
let mut new_offsets = BufferMut::<O>::with_capacity(selection.true_count());
78+
let mut new_sizes = BufferMut::<O>::with_capacity(selection.true_count());
79+
80+
let mut offset = 0;
81+
for idx in selection.indices() {
82+
let start = offsets[*idx];
83+
let end = offsets[idx + 1];
84+
let size = end - start;
85+
unsafe { new_offsets.push_unchecked(offset) };
86+
unsafe { new_sizes.push_unchecked(size) };
87+
offset += size;
88+
}
89+
90+
let new_offsets = PrimitiveVector::from(PVector::<O>::new(
91+
new_offsets.freeze(),
92+
Mask::new_true(selection.true_count()),
93+
));
94+
let new_sizes = PrimitiveVector::from(PVector::<O>::new(
95+
new_sizes.freeze(),
96+
Mask::new_true(selection.true_count()),
97+
));
98+
99+
(new_offsets, new_sizes)
100+
});
101+
102+
// TODO(ngates): for very dense masks, there may be no point in filtering the elements,
103+
// and instead we should construct a view against the unfiltered elements.
104+
let element_mask = match_each_integer_ptype!(offsets.ptype(), |O| {
105+
element_mask_from_offsets::<O>((&offsets).downcast::<O>().elements(), selection)
106+
});
107+
108+
let new_elements = array
109+
.sliced_elements()
110+
.filter(element_mask)?
111+
.execute_vector(ctx.session())?;
112+
113+
Ok(Some(
114+
unsafe {
115+
ListViewVector::new_unchecked(
116+
Arc::new(new_elements),
117+
new_offsets,
118+
new_sizes,
119+
new_validity,
120+
)
121+
}
122+
.into(),
123+
))
124+
}
125+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
mod filter;
5+
6+
use crate::arrays::ListVTable;
7+
use crate::arrays::list::vtable::kernel::filter::ListFilterKernel;
8+
use crate::kernel::ParentKernelSet;
9+
10+
/// The set of parent kernels registered for `ListVTable`; currently only `ListFilterKernel`.
pub(super) const PARENT_KERNELS: ParentKernelSet<ListVTable> =
11+
ParentKernelSet::new(&[ParentKernelSet::lift(&ListFilterKernel)]);

vortex-array/src/arrays/list/vtable/mod.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,13 @@ use vortex_error::VortexExpect;
99
use vortex_error::VortexResult;
1010
use vortex_error::vortex_bail;
1111
use vortex_error::vortex_ensure;
12+
use vortex_vector::Vector;
1213

1314
use crate::ArrayRef;
15+
use crate::ExecutionCtx;
1416
use crate::ProstMetadata;
1517
use crate::arrays::ListArray;
18+
use crate::arrays::list::vtable::kernel::PARENT_KERNELS;
1619
use crate::metadata::DeserializeMetadata;
1720
use crate::metadata::SerializeMetadata;
1821
use crate::serde::ArrayChildren;
@@ -27,6 +30,7 @@ use crate::vtable::ValidityVTableFromValidityHelper;
2730

2831
mod array;
2932
mod canonical;
33+
mod kernel;
3034
mod operations;
3135
mod validity;
3236
mod visitor;
@@ -138,6 +142,15 @@ impl VTable for ListVTable {
138142
*array = new_array;
139143
Ok(())
140144
}
145+
146+
/// Forwards all arguments unchanged to `PARENT_KERNELS.execute` and returns its result.
fn execute_parent(
147+
array: &Self::Array,
148+
parent: &ArrayRef,
149+
child_idx: usize,
150+
ctx: &mut ExecutionCtx,
151+
) -> VortexResult<Option<Vector>> {
152+
PARENT_KERNELS.execute(array, parent, child_idx, ctx)
153+
}
141154
}
142155

143156
#[derive(Debug)]

vortex-array/src/arrays/listview/vtable/mod.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ use crate::ProstMetadata;
2020
use crate::SerializeMetadata;
2121
use crate::VectorExecutor;
2222
use crate::arrays::ListViewArray;
23+
use crate::arrays::listview::vtable::rules::PARENT_RULES;
2324
use crate::executor::ExecutionCtx;
2425
use crate::serde::ArrayChildren;
2526
use crate::validity::Validity;
@@ -34,6 +35,7 @@ use crate::vtable::ValidityVTableFromValidityHelper;
3435
mod array;
3536
mod canonical;
3637
mod operations;
38+
mod rules;
3739
mod validity;
3840
mod visitor;
3941

@@ -187,4 +189,12 @@ impl VTable for ListViewVTable {
187189
}
188190
.into())
189191
}
192+
193+
/// Forwards all arguments unchanged to `PARENT_RULES.evaluate` and returns its result.
fn reduce_parent(
194+
array: &Self::Array,
195+
parent: &ArrayRef,
196+
child_idx: usize,
197+
) -> VortexResult<Option<ArrayRef>> {
198+
PARENT_RULES.evaluate(array, parent, child_idx)
199+
}
190200
}

0 commit comments

Comments
 (0)