|
| 1 | +// SPDX-License-Identifier: Apache-2.0 |
| 2 | +// SPDX-FileCopyrightText: Copyright the Vortex contributors |
| 3 | + |
| 4 | +use vortex_buffer::{Buffer, BufferMut}; |
| 5 | +use vortex_mask::{Mask, MaskIter}; |
| 6 | + |
| 7 | +use crate::filter::Filter; |
| 8 | + |
| 9 | +// This is modeled after the constant with the equivalent name in arrow-rs. |
| 10 | +const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8; |
| 11 | + |
| 12 | +impl<T: Copy> Filter for Buffer<T> { |
| 13 | + fn filter(&self, mask: &Mask) -> Self { |
| 14 | + assert_eq!(mask.len(), self.len()); |
| 15 | + match mask { |
| 16 | + Mask::AllTrue(_) => self.clone(), |
| 17 | + Mask::AllFalse(_) => Self::empty(), |
| 18 | + Mask::Values(v) => match v.threshold_iter(FILTER_SLICES_SELECTIVITY_THRESHOLD) { |
| 19 | + MaskIter::Indices(indices) => filter_indices(self.as_slice(), indices), |
| 20 | + MaskIter::Slices(slices) => { |
| 21 | + filter_slices(self.as_slice(), mask.true_count(), slices) |
| 22 | + } |
| 23 | + }, |
| 24 | + } |
| 25 | + } |
| 26 | +} |
| 27 | + |
| 28 | +fn filter_indices<T: Copy>(values: &[T], indices: &[usize]) -> Buffer<T> { |
| 29 | + Buffer::<T>::from_trusted_len_iter(indices.iter().map(|&idx| values[idx])) |
| 30 | +} |
| 31 | + |
| 32 | +fn filter_slices<T>(values: &[T], output_len: usize, slices: &[(usize, usize)]) -> Buffer<T> { |
| 33 | + let mut out = BufferMut::<T>::with_capacity(output_len); |
| 34 | + for (start, end) in slices { |
| 35 | + out.extend_from_slice(&values[*start..*end]); |
| 36 | + } |
| 37 | + out.freeze() |
| 38 | +} |
| 39 | + |
| 40 | +#[cfg(test)] |
| 41 | +mod tests { |
| 42 | + use vortex_buffer::buffer; |
| 43 | + use vortex_mask::Mask; |
| 44 | + |
| 45 | + use super::*; |
| 46 | + |
| 47 | + #[test] |
| 48 | + fn test_filter_buffer_by_indices() { |
| 49 | + let buf = buffer![10u32, 20, 30, 40, 50]; |
| 50 | + let mask = Mask::from_iter([true, false, true, false, true]); |
| 51 | + |
| 52 | + let result = buf.filter(&mask); |
| 53 | + assert_eq!(result, buffer![10u32, 30, 50]); |
| 54 | + } |
| 55 | + |
| 56 | + #[test] |
| 57 | + fn test_filter_buffer_all_true() { |
| 58 | + let buf = buffer![1u64, 2, 3]; |
| 59 | + let mask = Mask::new_true(3); |
| 60 | + |
| 61 | + let result = buf.filter(&mask); |
| 62 | + assert_eq!(result, buffer![1u64, 2, 3]); |
| 63 | + } |
| 64 | + |
| 65 | + #[test] |
| 66 | + fn test_filter_buffer_all_false() { |
| 67 | + let buf = buffer![1i32, 2, 3, 4]; |
| 68 | + let mask = Mask::new_false(4); |
| 69 | + |
| 70 | + let result = buf.filter(&mask); |
| 71 | + assert!(result.is_empty()); |
| 72 | + } |
| 73 | + |
| 74 | + #[test] |
| 75 | + fn test_filter_indices_direct() { |
| 76 | + let buf = buffer![100u32, 200, 300, 400]; |
| 77 | + let result = filter_indices(buf.as_slice(), &[0, 2, 3]); |
| 78 | + assert_eq!(result, buffer![100u32, 300, 400]); |
| 79 | + } |
| 80 | + |
| 81 | + #[test] |
| 82 | + fn test_filter_slices_direct() { |
| 83 | + let buf = buffer![1u32, 2, 3, 4, 5]; |
| 84 | + let result = filter_slices(buf.as_slice(), 3, &[(0, 2), (4, 5)]); |
| 85 | + assert_eq!(result, buffer![1u32, 2, 5]); |
| 86 | + } |
| 87 | +} |
0 commit comments