Skip to content

Commit 104ccad

Browse files
committed
move filter files around
Signed-off-by: Connor Tsui <[email protected]>
1 parent f7dcd01 commit 104ccad

File tree

9 files changed

+274
-205
lines changed

9 files changed

+274
-205
lines changed

vortex-compute/src/filter/bitbuffer.rs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,21 @@ const FILTER_SLICES_DENSITY_THRESHOLD: f64 = 0.8;
1313
impl Filter for &BitBuffer {
1414
type Output = BitBuffer;
1515

16-
fn filter(self, mask: &Mask) -> BitBuffer {
17-
assert_eq!(mask.len(), self.len());
18-
match mask {
16+
fn filter(self, selection_mask: &Mask) -> BitBuffer {
17+
assert_eq!(
18+
selection_mask.len(),
19+
self.len(),
20+
"Selection mask length must equal the mask length"
21+
);
22+
23+
match selection_mask {
1924
Mask::AllTrue(_) => self.clone(),
2025
Mask::AllFalse(_) => BitBuffer::empty(),
2126
Mask::Values(v) => match v.threshold_iter(FILTER_SLICES_DENSITY_THRESHOLD) {
2227
MaskIter::Indices(indices) => filter_indices(self, indices),
23-
MaskIter::Slices(slices) => filter_slices(self, mask.true_count(), slices),
28+
MaskIter::Slices(slices) => {
29+
filter_slices(self, selection_mask.true_count(), slices)
30+
}
2431
},
2532
}
2633
}

vortex-compute/src/filter/buffer.rs

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,43 @@ impl<T: Copy> Filter for &Buffer<T> {
3232
}
3333
}
3434

35+
impl<T: Copy> Filter for &mut BufferMut<T> {
36+
type Output = ();
37+
38+
fn filter(self, selection_mask: &Mask) {
39+
assert_eq!(
40+
selection_mask.len(),
41+
self.len(),
42+
"Selection mask length must equal the buffer length"
43+
);
44+
45+
match selection_mask {
46+
Mask::AllTrue(_) => {}
47+
Mask::AllFalse(_) => self.clear(),
48+
Mask::Values(values) => {
49+
// We choose to _always_ use slices here because iterating over indices will have
50+
// strictly more loop iterations than slices, and the overhead over batched
51+
// `ptr::copy(len)` is not worth it.
52+
let slices = values.slices();
53+
54+
// SAFETY: We checked above that the selection mask has the same length as the
55+
// buffer.
56+
let new_len = unsafe { filter_slices_in_place(self.as_mut_slice(), slices) };
57+
58+
debug_assert!(
59+
new_len <= self.len(),
60+
"The new length was somehow larger after filter"
61+
);
62+
63+
// Truncate the buffer to the new length.
64+
// SAFETY: The new length cannot be larger than the old length, so all values must
65+
// be initialized.
66+
unsafe { self.set_len(new_len) };
67+
}
68+
}
69+
}
70+
}
71+
3572
impl<T: Copy> Filter for Buffer<T> {
3673
type Output = Self;
3774

@@ -66,6 +103,39 @@ fn filter_slices<T>(values: &[T], output_len: usize, slices: &[(usize, usize)])
66103
out.freeze()
67104
}
68105

106+
/// Filters a buffer in-place using slice ranges to determine which values to keep.
107+
///
108+
/// Returns the new length of the buffer.
109+
///
110+
/// # Safety
111+
///
112+
/// The slice ranges must be in the range of the `buffer`.
113+
#[must_use = "The caller should set the new length of the buffer"]
114+
unsafe fn filter_slices_in_place<T: Copy>(buffer: &mut [T], slices: &[(usize, usize)]) -> usize {
115+
let mut write_pos = 0;
116+
117+
// For each range in the selection, copy all of the elements to the current write position.
118+
for &(start, end) in slices {
119+
// Note that we could add an if statement here that checks `if read_idx != write_idx`, but
120+
// it's probably better to just avoid the branch misprediction.
121+
122+
let len = end - start;
123+
124+
// SAFETY: The safety contract enforces that all ranges are within bounds.
125+
unsafe {
126+
core::ptr::copy(
127+
buffer.as_ptr().add(start),
128+
buffer.as_mut_ptr().add(write_pos),
129+
len,
130+
)
131+
};
132+
133+
write_pos += len;
134+
}
135+
136+
write_pos
137+
}
138+
69139
#[cfg(test)]
70140
mod tests {
71141
use vortex_buffer::buffer;
@@ -113,4 +183,112 @@ mod tests {
113183
let result = filter_slices(buf.as_slice(), 3, &[(0, 2), (4, 5)]);
114184
assert_eq!(result, buffer![1u32, 2, 5]);
115185
}
186+
187+
use vortex_buffer::{BufferMut, buffer_mut};
188+
189+
#[test]
190+
fn test_filter_all_true() {
191+
let mut buf = buffer_mut![1u32, 2, 3, 4, 5];
192+
let mask = Mask::new_true(5);
193+
194+
buf.filter(&mask);
195+
assert_eq!(buf.as_slice(), &[1, 2, 3, 4, 5]);
196+
}
197+
198+
#[test]
199+
fn test_filter_all_false() {
200+
let mut buf = buffer_mut![1u32, 2, 3, 4, 5];
201+
let mask = Mask::new_false(5);
202+
203+
buf.filter(&mask);
204+
assert!(buf.is_empty());
205+
}
206+
207+
#[test]
208+
fn test_filter_sparse() {
209+
let mut buf = buffer_mut![10u32, 20, 30, 40, 50];
210+
// Select indices 0, 2, 4 (sparse selection).
211+
let mask = Mask::from_iter([true, false, true, false, true]);
212+
213+
buf.filter(&mask);
214+
assert_eq!(buf.as_slice(), &[10, 30, 50]);
215+
}
216+
217+
#[test]
218+
fn test_filter_dense() {
219+
let mut buf = buffer_mut![1u32, 2, 3, 4, 5, 6, 7, 8, 9, 10];
220+
// Dense selection (80% selected).
221+
let mask = Mask::from_iter([true, true, true, true, false, true, true, true, false, true]);
222+
223+
buf.filter(&mask);
224+
assert_eq!(buf.as_slice(), &[1, 2, 3, 4, 6, 7, 8, 10]);
225+
}
226+
227+
#[test]
228+
fn test_filter_single_element_kept() {
229+
let mut buf = buffer_mut![1u32, 2, 3, 4, 5];
230+
let mask = Mask::from_iter([false, false, true, false, false]);
231+
232+
buf.filter(&mask);
233+
assert_eq!(buf.as_slice(), &[3]);
234+
}
235+
236+
#[test]
237+
fn test_filter_first_last() {
238+
let mut buf = buffer_mut![1u32, 2, 3, 4, 5];
239+
let mask = Mask::from_iter([true, false, false, false, true]);
240+
241+
buf.filter(&mask);
242+
assert_eq!(buf.as_slice(), &[1, 5]);
243+
}
244+
245+
#[test]
246+
fn test_filter_alternating() {
247+
let mut buf = buffer_mut![1u32, 2, 3, 4, 5, 6];
248+
let mask = Mask::from_iter([true, false, true, false, true, false]);
249+
250+
buf.filter(&mask);
251+
assert_eq!(buf.as_slice(), &[1, 3, 5]);
252+
}
253+
254+
#[test]
255+
fn test_filter_empty_buffer() {
256+
let mut buf: BufferMut<u32> = BufferMut::with_capacity(0);
257+
let mask = Mask::new_false(0);
258+
259+
buf.filter(&mask);
260+
assert!(buf.is_empty());
261+
}
262+
263+
#[test]
264+
fn test_filter_contiguous_regions() {
265+
let mut buf = buffer_mut![1u32, 2, 3, 4, 5, 6, 7, 8, 9, 10];
266+
// Two contiguous regions: [0..3] and [7..10].
267+
let mask = Mask::from_iter([
268+
true, true, true, false, false, false, false, true, true, true,
269+
]);
270+
271+
buf.filter(&mask);
272+
assert_eq!(buf.as_slice(), &[1, 2, 3, 8, 9, 10]);
273+
}
274+
275+
#[test]
276+
fn test_filter_large_buffer() {
277+
let mut buf: BufferMut<u32> = BufferMut::from_iter(0..1000);
278+
// Keep every third element.
279+
let mask = Mask::from_iter((0..1000).map(|i| i % 3 == 0));
280+
281+
buf.filter(&mask);
282+
let expected: Vec<u32> = (0..1000).filter(|i| i % 3 == 0).collect();
283+
assert_eq!(buf.as_slice(), &expected[..]);
284+
}
285+
286+
#[test]
287+
#[should_panic(expected = "Selection mask length must equal the buffer length")]
288+
fn test_filter_length_mismatch() {
289+
let mut buf = buffer_mut![1u32, 2, 3];
290+
let mask = Mask::new_true(5); // Wrong length.
291+
292+
buf.filter(&mask);
293+
}
116294
}

0 commit comments

Comments
 (0)