Skip to content

Commit 6e49a95

Browse files
authored
chore: implement Filter for new MaskIndices type for all vectors (#5333)
1 parent 3615760 commit 6e49a95

File tree

31 files changed

+1408
-53
lines changed

31 files changed

+1408
-53
lines changed

vortex-compute/src/filter/bitbuffer.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
use vortex_buffer::{BitBuffer, BitBufferMut, get_bit};
55
use vortex_mask::{Mask, MaskIter};
66

7-
use crate::filter::Filter;
7+
use crate::filter::{Filter, MaskIndices};
88

99
/// If the filter density is above 80%, we use slices to filter the array instead of indices.
1010
// TODO(ngates): we need more experimentation to determine the best threshold here.
1111
const FILTER_SLICES_DENSITY_THRESHOLD: f64 = 0.8;
1212

13-
impl Filter for &BitBuffer {
13+
impl Filter<Mask> for &BitBuffer {
1414
type Output = BitBuffer;
1515

1616
fn filter(self, selection_mask: &Mask) -> BitBuffer {
@@ -33,6 +33,14 @@ impl Filter for &BitBuffer {
3333
}
3434
}
3535

36+
impl Filter<MaskIndices<'_>> for &BitBuffer {
37+
type Output = BitBuffer;
38+
39+
fn filter(self, indices: &MaskIndices) -> BitBuffer {
40+
filter_indices(self, indices)
41+
}
42+
}
43+
3644
fn filter_indices(bools: &BitBuffer, indices: &[usize]) -> BitBuffer {
3745
let buffer = bools.inner().as_ref();
3846
BitBuffer::collect_bool(indices.len(), |idx| {

vortex-compute/src/filter/buffer.rs

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
use vortex_buffer::{Buffer, BufferMut};
55
use vortex_mask::{Mask, MaskIter};
66

7-
use crate::filter::Filter;
7+
use crate::filter::{Filter, MaskIndices};
88

99
// This is modeled after the constant with the equivalent name in arrow-rs.
1010
const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8;
1111

12-
impl<T: Copy> Filter for &Buffer<T> {
12+
impl<T: Copy> Filter<Mask> for &Buffer<T> {
1313
type Output = Buffer<T>;
1414

1515
fn filter(self, selection_mask: &Mask) -> Buffer<T> {
@@ -32,7 +32,15 @@ impl<T: Copy> Filter for &Buffer<T> {
3232
}
3333
}
3434

35-
impl<T: Copy> Filter for &mut BufferMut<T> {
35+
impl<T: Copy> Filter<MaskIndices<'_>> for &Buffer<T> {
36+
type Output = Buffer<T>;
37+
38+
fn filter(self, indices: &MaskIndices) -> Buffer<T> {
39+
filter_indices(self, indices)
40+
}
41+
}
42+
43+
impl<T: Copy> Filter<Mask> for &mut BufferMut<T> {
3644
type Output = ();
3745

3846
fn filter(self, selection_mask: &Mask) {
@@ -69,16 +77,26 @@ impl<T: Copy> Filter for &mut BufferMut<T> {
6977
}
7078
}
7179

72-
impl<T: Copy> Filter for Buffer<T> {
73-
type Output = Self;
80+
// In-place filter of a mutable buffer by mask indices: the element at each listed index is
// copied forward to the next output slot, then the buffer is truncated to `indices.len()`.
impl<T: Copy> Filter<MaskIndices<'_>> for &mut BufferMut<T> {
81+
type Output = ();
7482

75-
fn filter(self, selection_mask: &Mask) -> Self {
76-
assert_eq!(
77-
selection_mask.len(),
78-
self.len(),
79-
"Selection mask length must equal the buffer length"
80-
);
83+
fn filter(self, indices: &MaskIndices) -> Self::Output {
84+
// `MaskIndices` is documented as strictly sorted, so `read_index >= write_index` on
// every iteration and this forward pass never overwrites an element it still has to read.
// Indexing is bounds-checked, so an index past the end of the buffer panics.
for (write_index, &read_index) in indices.iter().enumerate() {
85+
self[write_index] = self[read_index];
86+
}
87+
88+
// Drop everything after the compacted prefix.
self.truncate(indices.len());
89+
}
90+
}
91+
92+
impl<M, T: Copy> Filter<M> for Buffer<T>
93+
where
94+
for<'a> &'a Buffer<T>: Filter<M, Output = Buffer<T>>,
95+
for<'a> &'a mut BufferMut<T>: Filter<M, Output = ()>,
96+
{
97+
type Output = Self;
8198

99+
fn filter(self, selection_mask: &M) -> Self {
82100
// If we have exclusive access, we can perform the filter in place.
83101
match self.try_into_mut() {
84102
Ok(mut buffer_mut) => {

vortex-compute/src/filter/mask.rs

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4-
use vortex_mask::Mask;
4+
use vortex_mask::{Mask, MaskMut};
55

6-
use crate::filter::Filter;
6+
use crate::filter::{Filter, MaskIndices};
77

8-
impl Filter for &Mask {
8+
impl Filter<Mask> for &Mask {
99
type Output = Mask;
1010

1111
fn filter(self, selection_mask: &Mask) -> Mask {
@@ -27,3 +27,41 @@ impl Filter for &Mask {
2727
}
2828
}
2929
}
30+
31+
impl Filter<MaskIndices<'_>> for &Mask {
32+
type Output = Mask;
33+
34+
fn filter(self, indices: &MaskIndices<'_>) -> Mask {
35+
match self {
36+
Mask::AllTrue(_) => Mask::AllTrue(indices.len()),
37+
Mask::AllFalse(_) => Mask::AllFalse(indices.len()),
38+
Mask::Values(mask_values) => Mask::from(mask_values.bit_buffer().filter(indices)),
39+
}
40+
}
41+
}
42+
43+
/// Filters a [`MaskMut`] by a selection mask, replacing its contents with the result.
///
/// # Panics
///
/// Panics if `selection_mask` and the mask being filtered have different lengths.
impl Filter<Mask> for &mut MaskMut {
    type Output = ();

    fn filter(self, selection_mask: &Mask) {
        assert_eq!(
            selection_mask.len(),
            self.len(),
            "Selection mask length must equal the mask length"
        );

        // TODO(connor): There is definitely a better way to do this (in place).
        // For now: copy, freeze into an immutable mask, filter that, and thaw the result.
        let frozen = self.clone().freeze();
        *self = frozen.filter(selection_mask).into_mut();
    }
}
58+
59+
/// Filters a [`MaskMut`] by an explicit list of mask indices, replacing its contents with
/// the result.
impl Filter<MaskIndices<'_>> for &mut MaskMut {
    type Output = ();

    fn filter(self, indices: &MaskIndices<'_>) -> Self::Output {
        // TODO(aduffy): Filter in-place
        // For now: copy, freeze into an immutable mask, filter that, and thaw the result.
        let frozen = self.clone().freeze();
        *self = frozen.filter(indices).into_mut();
    }
}

vortex-compute/src/filter/mod.rs

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@
33

44
//! Filter function.
55
6+
use std::ops::Deref;
7+
68
mod bitbuffer;
79
mod buffer;
810
mod mask;
911
mod vector;
1012

11-
use vortex_mask::Mask;
12-
1313
/// Function for filtering based on a selection mask.
14-
pub trait Filter {
14+
pub trait Filter<By: ?Sized> {
1515
/// The result type after performing the operation.
1616
type Output;
1717

@@ -22,5 +22,37 @@ pub trait Filter {
2222
/// # Panics
2323
///
2424
/// If the length of the mask does not equal the length of the value being filtered.
25-
fn filter(self, selection_mask: &Mask) -> Self::Output;
25+
fn filter(self, selection: &By) -> Self::Output;
26+
}
27+
28+
/// A borrowed view over a set of strictly sorted indices from a bit mask.
///
/// Unlike arbitrary index lists, `MaskIndices` are always strictly sorted, meaning every
/// index is unique and the sequence is monotonically increasing.
///
/// The view dereferences to `[usize]`, so you can iterate over it or index into it just
/// like you would a `&[usize]`. It only wraps a shared slice reference, so it is cheap to
/// copy.
#[derive(Debug, Clone, Copy)]
pub struct MaskIndices<'a>(&'a [usize]);

impl<'a> MaskIndices<'a> {
    /// Create new indices from a slice of strictly sorted index values.
    ///
    /// # Safety
    ///
    /// The caller must ensure that the indices are strictly sorted, i.e. that they
    /// are monotonically increasing and unique.
    ///
    /// Users of the `MaskIndices` type assume this, and failure to uphold this guarantee
    /// can result in UB downstream.
    ///
    /// # Panics
    ///
    /// In builds with debug assertions enabled, panics if the indices are not strictly
    /// sorted. Release builds perform no check.
    pub unsafe fn new_unchecked(indices: &'a [usize]) -> Self {
        // Cheap insurance while developing: catch contract violations before they can turn
        // into undefined behaviour in code that trusts the ordering.
        debug_assert!(
            indices.windows(2).all(|pair| pair[0] < pair[1]),
            "MaskIndices must be strictly sorted (unique and increasing)"
        );
        Self(indices)
    }
}

impl Deref for MaskIndices<'_> {
    type Target = [usize];

    /// Expose the underlying index slice.
    fn deref(&self) -> &Self::Target {
        self.0
    }
}

0 commit comments

Comments
 (0)