Skip to content

Commit 8f0ba91

Browse files
authored
feat: cache FilterMask iterators (#1351)
This PR uses a cute trick to pre-cache filter indices when a `FilterMask` is cloned. Unfortunately, we have to clone the mask every time we pass it into `with_dyn`.
1 parent fe50481 commit 8f0ba91

File tree

24 files changed

+352
-192
lines changed

24 files changed

+352
-192
lines changed

encodings/alp/src/alp/compute.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,9 @@ impl SliceFn for ALPArray {
8686
}
8787

8888
impl FilterFn for ALPArray {
89-
fn filter(&self, mask: &FilterMask) -> VortexResult<ArrayData> {
89+
fn filter(&self, mask: FilterMask) -> VortexResult<ArrayData> {
9090
Ok(Self::try_new(
91-
filter(&self.encoded(), mask)?,
91+
filter(&self.encoded(), mask.clone())?,
9292
self.exponents(),
9393
self.patches().map(|p| filter(&p, mask)).transpose()?,
9494
)?

encodings/alp/src/alp_rd/compute/filter.rs

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@ use vortex_error::VortexResult;
55
use crate::ALPRDArray;
66

77
impl FilterFn for ALPRDArray {
8-
fn filter(&self, mask: &FilterMask) -> VortexResult<ArrayData> {
8+
fn filter(&self, mask: FilterMask) -> VortexResult<ArrayData> {
99
let left_parts_exceptions = self
1010
.left_parts_exceptions()
11-
.map(|array| filter(&array, mask))
11+
.map(|array| filter(&array, mask.clone()))
1212
.transpose()?;
1313

1414
Ok(ALPRDArray::try_new(
1515
self.dtype().clone(),
16-
filter(&self.left_parts(), mask)?,
16+
filter(&self.left_parts(), mask.clone())?,
1717
self.left_parts_dict(),
1818
filter(&self.right_parts(), mask)?,
1919
self.right_bit_width(),
@@ -43,13 +43,10 @@ mod test {
4343
assert!(encoded.left_parts_exceptions().is_some());
4444

4545
// The first two values need no patching
46-
let filtered = filter(
47-
encoded.as_ref(),
48-
&FilterMask::from_iter([true, false, true]),
49-
)
50-
.unwrap()
51-
.into_primitive()
52-
.unwrap();
46+
let filtered = filter(encoded.as_ref(), FilterMask::from_iter([true, false, true]))
47+
.unwrap()
48+
.into_primitive()
49+
.unwrap();
5350
assert_eq!(filtered.maybe_null_slice::<T>(), &[a, outlier]);
5451
}
5552
}

encodings/dict/src/compute.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ impl TakeFn for DictArray {
8686
}
8787

8888
impl FilterFn for DictArray {
89-
fn filter(&self, mask: &FilterMask) -> VortexResult<ArrayData> {
89+
fn filter(&self, mask: FilterMask) -> VortexResult<ArrayData> {
9090
let codes = filter(&self.codes(), mask)?;
9191
Self::try_new(codes, self.values()).map(|a| a.into_array())
9292
}

encodings/fastlanes/src/for/compute.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ impl TakeFn for FoRArray {
4848
}
4949

5050
impl FilterFn for FoRArray {
51-
fn filter(&self, mask: &FilterMask) -> VortexResult<ArrayData> {
51+
fn filter(&self, mask: FilterMask) -> VortexResult<ArrayData> {
5252
Self::try_new(
5353
filter(&self.encoded(), mask)?,
5454
self.owned_reference_scalar(),

encodings/fsst/src/compute.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,12 +151,12 @@ impl ScalarAtFn for FSSTArray {
151151

152152
impl FilterFn for FSSTArray {
153153
// Filtering an FSSTArray filters the codes array, leaving the symbols array untouched
154-
fn filter(&self, mask: &FilterMask) -> VortexResult<ArrayData> {
154+
fn filter(&self, mask: FilterMask) -> VortexResult<ArrayData> {
155155
Ok(Self::try_new(
156156
self.dtype().clone(),
157157
self.symbols(),
158158
self.symbol_lengths(),
159-
filter(&self.codes(), mask)?,
159+
filter(&self.codes(), mask.clone())?,
160160
filter(&self.uncompressed_lengths(), mask)?,
161161
)?
162162
.into_array())

encodings/fsst/tests/fsst_tests.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ fn test_fsst_array_ops() {
8787
// test filter
8888
let mask = FilterMask::from_iter([false, true, false]);
8989

90-
let fsst_filtered = filter(&fsst_array, &mask).unwrap();
90+
let fsst_filtered = filter(&fsst_array, mask).unwrap();
9191
assert_eq!(fsst_filtered.encoding().id(), FSST::ENCODING.id());
9292
assert_eq!(fsst_filtered.len(), 1);
9393
assert_nth_scalar!(

encodings/runend/src/compute.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ impl ScalarAtFn for RunEndArray {
7777
}
7878

7979
impl TakeFn for RunEndArray {
80+
#[allow(deprecated)]
8081
fn take(&self, indices: &ArrayData, options: TakeOptions) -> VortexResult<ArrayData> {
8182
let primitive_indices = indices.clone().into_primitive()?;
8283
let u64_indices = match_each_integer_ptype!(primitive_indices.ptype(), |$P| {
@@ -111,7 +112,6 @@ impl TakeFn for RunEndArray {
111112
Validity::Array(original_validity) => {
112113
let dense_validity =
113114
FilterMask::try_from(take(&original_validity, indices, options)?)?;
114-
let filtered_values = filter(&dense_values, &dense_validity)?;
115115
let length = dense_validity.len();
116116
let dense_nonnull_indices = PrimitiveArray::from(
117117
dense_validity
@@ -120,6 +120,7 @@ impl TakeFn for RunEndArray {
120120
.collect::<Vec<_>>(),
121121
)
122122
.into_array();
123+
let filtered_values = filter(&dense_values, dense_validity)?;
123124

124125
SparseArray::try_new(
125126
dense_nonnull_indices,
@@ -150,13 +151,13 @@ impl SliceFn for RunEndArray {
150151
}
151152

152153
impl FilterFn for RunEndArray {
153-
fn filter(&self, mask: &FilterMask) -> VortexResult<ArrayData> {
154+
fn filter(&self, mask: FilterMask) -> VortexResult<ArrayData> {
154155
let primitive_run_ends = self.ends().into_primitive()?;
155156
let (run_ends, mask) = match_each_unsigned_integer_ptype!(primitive_run_ends.ptype(), |$P| {
156157
filter_run_ends(primitive_run_ends.maybe_null_slice::<$P>(), mask)?
157158
});
158-
let values = filter(&self.values(), &mask)?;
159159
let validity = self.validity().filter(&mask)?;
160+
let values = filter(&self.values(), mask)?;
160161

161162
RunEndArray::try_new(run_ends.into_array(), values, validity).map(|a| a.into_array())
162163
}
@@ -165,7 +166,7 @@ impl FilterFn for RunEndArray {
165166
// Code adapted from apache arrow-rs https://github.com/apache/arrow-rs/blob/b1f5c250ebb6c1252b4e7c51d15b8e77f4c361fa/arrow-select/src/filter.rs#L425
166167
fn filter_run_ends<R: NativePType + AddAssign + From<bool> + AsPrimitive<u64>>(
167168
run_ends: &[R],
168-
mask: &FilterMask,
169+
mask: FilterMask,
169170
) -> VortexResult<(PrimitiveArray, FilterMask)> {
170171
let mut new_run_ends = vec![R::zero(); run_ends.len()];
171172

@@ -444,7 +445,7 @@ mod test {
444445
let arr = ree_array();
445446
let filtered = filter(
446447
arr.as_ref(),
447-
&FilterMask::from_iter([
448+
FilterMask::from_iter([
448449
true, true, false, false, false, false, false, false, false, false, true, true,
449450
]),
450451
)

fuzz/fuzz_targets/array_ops.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ fuzz_target!(|fuzz_action: FuzzArrayAction| -> Corpus {
5858
assert_search_sorted(sorted, s, side, expected.search(), i)
5959
}
6060
Action::Filter(mask) => {
61-
current_array = filter(&current_array, &mask).unwrap();
61+
current_array = filter(&current_array, mask).unwrap();
6262
assert_array_eq(&expected.array(), &current_array, i);
6363
}
6464
}

pyvortex/src/array.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ impl PyArray {
268268
fn filter(&self, filter: &Bound<PyArray>) -> PyResult<PyArray> {
269269
let filter = filter.borrow();
270270
let inner =
271-
vortex::compute::filter(&self.inner, &FilterMask::try_from(filter.inner.clone())?)?;
271+
vortex::compute::filter(&self.inner, FilterMask::try_from(filter.inner.clone())?)?;
272272
Ok(PyArray { inner })
273273
}
274274

vortex-array/src/array/bool/compute/filter.rs

Lines changed: 67 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,67 +1,93 @@
1-
use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder};
2-
use vortex_error::VortexResult;
1+
use arrow_buffer::{bit_util, BooleanBuffer, BooleanBufferBuilder};
2+
use vortex_error::{VortexExpect, VortexResult};
33

44
use crate::array::BoolArray;
5-
use crate::compute::{FilterFn, FilterMask};
5+
use crate::compute::{FilterFn, FilterIter, FilterMask};
66
use crate::{ArrayData, IntoArrayData};
77

88
impl FilterFn for BoolArray {
9-
fn filter(&self, mask: &FilterMask) -> VortexResult<ArrayData> {
10-
filter_select_bool(self, mask).map(|a| a.into_array())
9+
fn filter(&self, mask: FilterMask) -> VortexResult<ArrayData> {
10+
let validity = self.validity().filter(&mask)?;
11+
12+
let buffer = match mask.iter()? {
13+
FilterIter::Indices(indices) => filter_indices_slice(&self.boolean_buffer(), indices),
14+
FilterIter::IndicesIter(iter) => {
15+
filter_indices(&self.boolean_buffer(), mask.true_count(), iter)
16+
}
17+
FilterIter::Slices(slices) => filter_slices(
18+
&self.boolean_buffer(),
19+
mask.true_count(),
20+
slices.iter().copied(),
21+
),
22+
FilterIter::SlicesIter(iter) => {
23+
filter_slices(&self.boolean_buffer(), mask.true_count(), iter)
24+
}
25+
};
26+
27+
Ok(Self::try_new(buffer, validity)?.into_array())
1128
}
1229
}
1330

14-
fn filter_select_bool(arr: &BoolArray, mask: &FilterMask) -> VortexResult<BoolArray> {
15-
let validity = arr.validity().filter(mask)?;
16-
17-
let selection_count = mask.true_count();
18-
let out = if selection_count * 2 > arr.len() {
19-
filter_select_bool_by_slice(&arr.boolean_buffer(), mask, selection_count)
20-
} else {
21-
filter_select_bool_by_index(&arr.boolean_buffer(), mask, selection_count)
22-
};
23-
BoolArray::try_new(out?, validity)
31+
/// Select indices from a boolean buffer.
32+
/// NOTE: it was benchmarked to be faster using collect_bool to index into a slice than to
33+
/// pass the indices as an iterator of usize. So we keep this alternate implementation.
34+
fn filter_indices_slice(buffer: &BooleanBuffer, indices: &[usize]) -> BooleanBuffer {
35+
let src = buffer.values().as_ptr();
36+
let offset = buffer.offset();
37+
BooleanBuffer::collect_bool(indices.len(), |idx| unsafe {
38+
bit_util::get_bit_raw(src, *indices.get_unchecked(idx) + offset)
39+
})
2440
}
2541

26-
fn filter_select_bool_by_slice(
27-
values: &BooleanBuffer,
28-
mask: &FilterMask,
29-
selection_count: usize,
30-
) -> VortexResult<BooleanBuffer> {
31-
let mut out_buf = BooleanBufferBuilder::new(selection_count);
32-
mask.iter_slices()?.for_each(|(start, end)| {
33-
out_buf.append_buffer(&values.slice(start, end - start));
34-
});
35-
Ok(out_buf.finish())
42+
pub fn filter_indices(
43+
buffer: &BooleanBuffer,
44+
indices_len: usize,
45+
mut indices: impl Iterator<Item = usize>,
46+
) -> BooleanBuffer {
47+
let src = buffer.values().as_ptr();
48+
let offset = buffer.offset();
49+
50+
BooleanBuffer::collect_bool(indices_len, |_idx| {
51+
let idx = indices
52+
.next()
53+
.vortex_expect("iterator is guaranteed to be within the length of the array.");
54+
unsafe { bit_util::get_bit_raw(src, idx + offset) }
55+
})
3656
}
3757

38-
fn filter_select_bool_by_index(
39-
values: &BooleanBuffer,
40-
mask: &FilterMask,
41-
selection_count: usize,
42-
) -> VortexResult<BooleanBuffer> {
43-
let mut out_buf = BooleanBufferBuilder::new(selection_count);
44-
mask.iter_indices()?
45-
.for_each(|idx| out_buf.append(values.value(idx)));
46-
Ok(out_buf.finish())
58+
pub fn filter_slices(
59+
buffer: &BooleanBuffer,
60+
indices_len: usize,
61+
slices: impl Iterator<Item = (usize, usize)>,
62+
) -> BooleanBuffer {
63+
let src = buffer.values();
64+
let offset = buffer.offset();
65+
66+
let mut builder = BooleanBufferBuilder::new(indices_len);
67+
for (start, end) in slices {
68+
builder.append_packed_range(start + offset..end + offset, src)
69+
}
70+
builder.into()
4771
}
4872

4973
#[cfg(test)]
5074
mod test {
5175
use itertools::Itertools;
5276

53-
use crate::array::bool::compute::filter::{
54-
filter_select_bool, filter_select_bool_by_index, filter_select_bool_by_slice,
55-
};
77+
use crate::array::bool::compute::filter::{filter_indices, filter_slices};
5678
use crate::array::BoolArray;
57-
use crate::compute::FilterMask;
79+
use crate::compute::{filter, FilterMask};
80+
use crate::{IntoArrayData, IntoArrayVariant};
5881

5982
#[test]
6083
fn filter_bool_test() {
6184
let arr = BoolArray::from_iter([true, true, false]);
6285
let mask = FilterMask::from_iter([true, false, true]);
6386

64-
let filtered = filter_select_bool(&arr, &mask).unwrap();
87+
let filtered = filter(&arr.into_array(), mask)
88+
.unwrap()
89+
.into_bool()
90+
.unwrap();
6591
assert_eq!(2, filtered.len());
6692

6793
assert_eq!(
@@ -73,9 +99,8 @@ mod test {
7399
#[test]
74100
fn filter_bool_by_slice_test() {
75101
let arr = BoolArray::from_iter([true, true, false]);
76-
let mask = FilterMask::from_iter([true, false, true]);
77102

78-
let filtered = filter_select_bool_by_slice(&arr.boolean_buffer(), &mask, 2).unwrap();
103+
let filtered = filter_slices(&arr.boolean_buffer(), 2, [(0, 1), (2, 3)].into_iter());
79104
assert_eq!(2, filtered.len());
80105

81106
assert_eq!(vec![true, false], filtered.iter().collect_vec())
@@ -84,9 +109,8 @@ mod test {
84109
#[test]
85110
fn filter_bool_by_index_test() {
86111
let arr = BoolArray::from_iter([true, true, false]);
87-
let mask = FilterMask::from_iter([true, false, true]);
88112

89-
let filtered = filter_select_bool_by_index(&arr.boolean_buffer(), &mask, 2).unwrap();
113+
let filtered = filter_indices(&arr.boolean_buffer(), 2, [0, 2].into_iter());
90114
assert_eq!(2, filtered.len());
91115

92116
assert_eq!(vec![true, false], filtered.iter().collect_vec())

0 commit comments

Comments
 (0)