Skip to content

Commit 1c7c759

Browse files
authored
Feature: add in-place filter for BufferMut + benchmarks (#5238)
1 parent 6a12a20 commit 1c7c759

File tree

6 files changed

+310
-5
lines changed

6 files changed

+310
-5
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

vortex-compute/Cargo.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,10 @@ num-traits = { workspace = true }
3434
[features]
3535
default = ["arrow"]
3636
arrow = ["dep:arrow-array", "dep:arrow-buffer", "dep:arrow-schema"]
37+
38+
[dev-dependencies]
39+
divan = { workspace = true }
40+
41+
[[bench]]
42+
name = "filter_buffer_mut"
43+
harness = false
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
//! In-place filter benchmarks for `BufferMut`.
5+
6+
use divan::Bencher;
7+
use vortex_buffer::BufferMut;
8+
use vortex_compute::filter::Filter;
9+
use vortex_mask::Mask;
10+
11+
fn main() {
12+
divan::main();
13+
}
14+
15+
const BUFFER_SIZE: usize = 1024;
16+
17+
const SELECTIVITIES: &[f64] = &[
18+
0.01, 0.10, 0.20, 0.30, 0.40, 0.50, 0.60, 0.70, 0.80, 0.90, 0.99,
19+
];
20+
21+
fn create_test_buffer<T>(size: usize) -> BufferMut<T>
22+
where
23+
T: Copy + Default + From<u8>,
24+
{
25+
let mut buffer = BufferMut::with_capacity(size);
26+
for i in 0..size {
27+
#[expect(clippy::cast_possible_truncation)]
28+
buffer.push(T::from((i % 256) as u8));
29+
}
30+
buffer
31+
}
32+
33+
fn generate_mask(len: usize, selectivity: f64) -> Mask {
34+
#[expect(clippy::cast_possible_truncation)]
35+
#[expect(clippy::cast_sign_loss)]
36+
let num_selected = ((len as f64) * selectivity).round() as usize;
37+
38+
let mut selection = vec![false; len];
39+
let mut indices: Vec<usize> = (0..len).collect();
40+
41+
// Simple deterministic shuffle.
42+
for i in (1..len).rev() {
43+
let j = (i * 7 + 13) % (i + 1);
44+
indices.swap(i, j);
45+
}
46+
47+
for i in 0..num_selected.min(len) {
48+
selection[indices[i]] = true;
49+
}
50+
51+
Mask::from_iter(selection)
52+
}
53+
54+
#[derive(Copy, Clone, Default)]
55+
#[allow(dead_code)]
56+
struct LargeElement([u8; 32]);
57+
58+
impl From<u8> for LargeElement {
59+
fn from(value: u8) -> Self {
60+
LargeElement([value; 32])
61+
}
62+
}
63+
64+
#[divan::bench(types = [u8, u32, u64, LargeElement], args = SELECTIVITIES, sample_count = 1000)]
65+
fn filter_selectivity<T: Copy + Default + From<u8>>(bencher: Bencher, selectivity: f64) {
66+
let mask = generate_mask(BUFFER_SIZE, selectivity);
67+
bencher
68+
.with_inputs(|| {
69+
let buffer = create_test_buffer::<T>(BUFFER_SIZE);
70+
(buffer, mask.clone())
71+
})
72+
.bench_values(|(mut buffer, mask)| {
73+
buffer.filter(&mask);
74+
divan::black_box(buffer);
75+
});
76+
}

vortex-compute/src/filter/buffer.rs

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,48 @@ const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8;
1212
impl<T: Copy> Filter for &Buffer<T> {
1313
type Output = Buffer<T>;
1414

15-
fn filter(self, mask: &Mask) -> Buffer<T> {
16-
assert_eq!(mask.len(), self.len());
17-
match mask {
15+
fn filter(self, selection_mask: &Mask) -> Buffer<T> {
16+
assert_eq!(
17+
selection_mask.len(),
18+
self.len(),
19+
"Selection mask length must equal the buffer length"
20+
);
21+
22+
match selection_mask {
1823
Mask::AllTrue(_) => self.clone(),
1924
Mask::AllFalse(_) => Buffer::empty(),
2025
Mask::Values(v) => match v.threshold_iter(FILTER_SLICES_SELECTIVITY_THRESHOLD) {
2126
MaskIter::Indices(indices) => filter_indices(self.as_slice(), indices),
2227
MaskIter::Slices(slices) => {
23-
filter_slices(self.as_slice(), mask.true_count(), slices)
28+
filter_slices(self.as_slice(), selection_mask.true_count(), slices)
2429
}
2530
},
2631
}
2732
}
2833
}
2934

35+
impl<T: Copy> Filter for Buffer<T> {
36+
type Output = Self;
37+
38+
fn filter(self, selection_mask: &Mask) -> Self {
39+
assert_eq!(
40+
selection_mask.len(),
41+
self.len(),
42+
"Selection mask length must equal the buffer length"
43+
);
44+
45+
// If we have exclusive access, we can perform the filter in place.
46+
match self.try_into_mut() {
47+
Ok(mut buffer_mut) => {
48+
(&mut buffer_mut).filter(selection_mask);
49+
buffer_mut.freeze()
50+
}
51+
// Otherwise, allocate a new buffer and fill it in (delegate to the `&Buffer` impl).
52+
Err(buffer) => (&buffer).filter(selection_mask),
53+
}
54+
}
55+
}
56+
3057
fn filter_indices<T: Copy>(values: &[T], indices: &[usize]) -> Buffer<T> {
3158
Buffer::<T>::from_trusted_len_iter(indices.iter().map(|&idx| values[idx]))
3259
}
Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use core::ptr;
5+
6+
use vortex_buffer::BufferMut;
7+
use vortex_mask::Mask;
8+
9+
use crate::filter::Filter;
10+
11+
impl<T: Copy> Filter for &mut BufferMut<T> {
12+
type Output = ();
13+
14+
fn filter(self, selection_mask: &Mask) {
15+
assert_eq!(
16+
selection_mask.len(),
17+
self.len(),
18+
"Selection mask length must equal the buffer length"
19+
);
20+
21+
match selection_mask {
22+
Mask::AllTrue(_) => {}
23+
Mask::AllFalse(_) => self.clear(),
24+
Mask::Values(values) => {
25+
// We choose to _always_ use slices here because iterating over indices will have
26+
// strictly more loop iterations than slices, and the overhead over batched
27+
// `ptr::copy(len)` is not worth it.
28+
let slices = values.slices();
29+
30+
// SAFETY: We checked above that the selection mask has the same length as the
31+
// buffer.
32+
let new_len = unsafe { filter_slices_in_place(self.as_mut_slice(), slices) };
33+
34+
debug_assert!(
35+
new_len <= self.len(),
36+
"The new length was somehow larger after filter"
37+
);
38+
39+
// Truncate the buffer to the new length.
40+
// SAFETY: The new length cannot be larger than the old length, so all values must
41+
// be initialized.
42+
unsafe { self.set_len(new_len) };
43+
}
44+
}
45+
}
46+
}
47+
48+
/// Filters a buffer in-place using slice ranges to determine which values to keep.
49+
///
50+
/// Returns the new length of the buffer.
51+
///
52+
/// # Safety
53+
///
54+
/// The slice ranges must be in the range of the `buffer`.
55+
#[must_use = "The caller should set the new length of the buffer"]
56+
unsafe fn filter_slices_in_place<T: Copy>(buffer: &mut [T], slices: &[(usize, usize)]) -> usize {
57+
let mut write_pos = 0;
58+
59+
// For each range in the selection, copy all of the elements to the current write position.
60+
for &(start, end) in slices {
61+
// Note that we could add an if statement here that checks `if read_idx != write_idx`, but
62+
// it's probably better to just avoid the branch misprediction.
63+
64+
let len = end - start;
65+
66+
// SAFETY: The safety contract enforces that all ranges are within bounds.
67+
unsafe {
68+
ptr::copy(
69+
buffer.as_ptr().add(start),
70+
buffer.as_mut_ptr().add(write_pos),
71+
len,
72+
)
73+
};
74+
75+
write_pos += len;
76+
}
77+
78+
write_pos
79+
}
80+
81+
#[cfg(test)]
82+
mod tests {
83+
use vortex_buffer::{BufferMut, buffer_mut};
84+
use vortex_mask::Mask;
85+
86+
use super::*;
87+
88+
#[test]
89+
fn test_filter_all_true() {
90+
let mut buf = buffer_mut![1u32, 2, 3, 4, 5];
91+
let mask = Mask::new_true(5);
92+
93+
buf.filter(&mask);
94+
assert_eq!(buf.as_slice(), &[1, 2, 3, 4, 5]);
95+
}
96+
97+
#[test]
98+
fn test_filter_all_false() {
99+
let mut buf = buffer_mut![1u32, 2, 3, 4, 5];
100+
let mask = Mask::new_false(5);
101+
102+
buf.filter(&mask);
103+
assert!(buf.is_empty());
104+
}
105+
106+
#[test]
107+
fn test_filter_sparse() {
108+
let mut buf = buffer_mut![10u32, 20, 30, 40, 50];
109+
// Select indices 0, 2, 4 (sparse selection).
110+
let mask = Mask::from_iter([true, false, true, false, true]);
111+
112+
buf.filter(&mask);
113+
assert_eq!(buf.as_slice(), &[10, 30, 50]);
114+
}
115+
116+
#[test]
117+
fn test_filter_dense() {
118+
let mut buf = buffer_mut![1u32, 2, 3, 4, 5, 6, 7, 8, 9, 10];
119+
// Dense selection (80% selected).
120+
let mask = Mask::from_iter([true, true, true, true, false, true, true, true, false, true]);
121+
122+
buf.filter(&mask);
123+
assert_eq!(buf.as_slice(), &[1, 2, 3, 4, 6, 7, 8, 10]);
124+
}
125+
126+
#[test]
127+
fn test_filter_single_element_kept() {
128+
let mut buf = buffer_mut![1u32, 2, 3, 4, 5];
129+
let mask = Mask::from_iter([false, false, true, false, false]);
130+
131+
buf.filter(&mask);
132+
assert_eq!(buf.as_slice(), &[3]);
133+
}
134+
135+
#[test]
136+
fn test_filter_first_last() {
137+
let mut buf = buffer_mut![1u32, 2, 3, 4, 5];
138+
let mask = Mask::from_iter([true, false, false, false, true]);
139+
140+
buf.filter(&mask);
141+
assert_eq!(buf.as_slice(), &[1, 5]);
142+
}
143+
144+
#[test]
145+
fn test_filter_alternating() {
146+
let mut buf = buffer_mut![1u32, 2, 3, 4, 5, 6];
147+
let mask = Mask::from_iter([true, false, true, false, true, false]);
148+
149+
buf.filter(&mask);
150+
assert_eq!(buf.as_slice(), &[1, 3, 5]);
151+
}
152+
153+
#[test]
154+
fn test_filter_empty_buffer() {
155+
let mut buf: BufferMut<u32> = BufferMut::with_capacity(0);
156+
let mask = Mask::new_false(0);
157+
158+
buf.filter(&mask);
159+
assert!(buf.is_empty());
160+
}
161+
162+
#[test]
163+
fn test_filter_contiguous_regions() {
164+
let mut buf = buffer_mut![1u32, 2, 3, 4, 5, 6, 7, 8, 9, 10];
165+
// Two contiguous regions: [0..3] and [7..10].
166+
let mask = Mask::from_iter([
167+
true, true, true, false, false, false, false, true, true, true,
168+
]);
169+
170+
buf.filter(&mask);
171+
assert_eq!(buf.as_slice(), &[1, 2, 3, 8, 9, 10]);
172+
}
173+
174+
#[test]
175+
fn test_filter_large_buffer() {
176+
let mut buf: BufferMut<u32> = BufferMut::from_iter(0..1000);
177+
// Keep every third element.
178+
let mask = Mask::from_iter((0..1000).map(|i| i % 3 == 0));
179+
180+
buf.filter(&mask);
181+
let expected: Vec<u32> = (0..1000).filter(|i| i % 3 == 0).collect();
182+
assert_eq!(buf.as_slice(), &expected[..]);
183+
}
184+
185+
#[test]
186+
#[should_panic(expected = "Selection mask length must equal the buffer length")]
187+
fn test_filter_length_mismatch() {
188+
let mut buf = buffer_mut![1u32, 2, 3];
189+
let mask = Mask::new_true(5); // Wrong length.
190+
191+
buf.filter(&mask);
192+
}
193+
}

vortex-compute/src/filter/mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
mod bitbuffer;
77
mod bool;
88
mod buffer;
9+
mod buffer_mut;
910
mod mask;
1011

1112
use vortex_mask::Mask;
@@ -22,5 +23,5 @@ pub trait Filter {
2223
/// # Panics
2324
///
2425
/// If the length of the mask does not equal the length of the value being filtered.
25-
fn filter(self, mask: &Mask) -> Self::Output;
26+
fn filter(self, selection_mask: &Mask) -> Self::Output;
2627
}

0 commit comments

Comments
 (0)