Skip to content

Commit 22bee6b

Browse files
committed
chore: implement Filter for new MaskIndices type for all vectors
Signed-off-by: Andrew Duffy <[email protected]>
1 parent 31b5671 commit 22bee6b

File tree

22 files changed

+713
-45
lines changed

22 files changed

+713
-45
lines changed

vortex-compute/src/filter/bitbuffer.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
use vortex_buffer::{BitBuffer, BitBufferMut, get_bit};
55
use vortex_mask::{Mask, MaskIter};
66

7-
use crate::filter::Filter;
7+
use crate::filter::{Filter, MaskIndices};
88

99
/// If the filter density is above 80%, we use slices to filter the array instead of indices.
1010
// TODO(ngates): we need more experimentation to determine the best threshold here.
1111
const FILTER_SLICES_DENSITY_THRESHOLD: f64 = 0.8;
1212

13-
impl Filter for &BitBuffer {
13+
impl Filter<Mask> for &BitBuffer {
1414
type Output = BitBuffer;
1515

1616
fn filter(self, selection_mask: &Mask) -> BitBuffer {
@@ -33,6 +33,14 @@ impl Filter for &BitBuffer {
3333
}
3434
}
3535

36+
impl Filter<MaskIndices<'_>> for &BitBuffer {
37+
type Output = BitBuffer;
38+
39+
fn filter(self, indices: &MaskIndices) -> BitBuffer {
40+
filter_indices(self, indices)
41+
}
42+
}
43+
3644
fn filter_indices(bools: &BitBuffer, indices: &[usize]) -> BitBuffer {
3745
let buffer = bools.inner().as_ref();
3846
BitBuffer::collect_bool(indices.len(), |idx| {

vortex-compute/src/filter/buffer.rs

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
use vortex_buffer::{Buffer, BufferMut};
55
use vortex_mask::{Mask, MaskIter};
66

7-
use crate::filter::Filter;
7+
use crate::filter::{Filter, MaskIndices};
88

99
// This is modeled after the constant with the equivalent name in arrow-rs.
1010
const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8;
1111

12-
impl<T: Copy> Filter for &Buffer<T> {
12+
impl<T: Copy> Filter<Mask> for &Buffer<T> {
1313
type Output = Buffer<T>;
1414

1515
fn filter(self, selection_mask: &Mask) -> Buffer<T> {
@@ -32,7 +32,15 @@ impl<T: Copy> Filter for &Buffer<T> {
3232
}
3333
}
3434

35-
impl<T: Copy> Filter for &mut BufferMut<T> {
35+
impl<T: Copy> Filter<MaskIndices<'_>> for &Buffer<T> {
36+
type Output = Buffer<T>;
37+
38+
fn filter(self, indices: &MaskIndices) -> Buffer<T> {
39+
filter_indices(self, indices)
40+
}
41+
}
42+
43+
impl<T: Copy> Filter<Mask> for &mut BufferMut<T> {
3644
type Output = ();
3745

3846
fn filter(self, selection_mask: &Mask) {
@@ -69,7 +77,19 @@ impl<T: Copy> Filter for &mut BufferMut<T> {
6977
}
7078
}
7179

72-
impl<T: Copy> Filter for Buffer<T> {
80+
impl<T: Copy> Filter<MaskIndices<'_>> for &mut BufferMut<T> {
81+
type Output = ();
82+
83+
fn filter(self, indices: &MaskIndices) -> Self::Output {
84+
for (write_index, &read_index) in indices.iter().enumerate() {
85+
self[write_index] = self[read_index];
86+
}
87+
88+
self.truncate(indices.len());
89+
}
90+
}
91+
92+
impl<T: Copy> Filter<Mask> for Buffer<T> {
7393
type Output = Self;
7494

7595
fn filter(self, selection_mask: &Mask) -> Self {
@@ -91,6 +111,20 @@ impl<T: Copy> Filter for Buffer<T> {
91111
}
92112
}
93113

114+
impl<T: Copy> Filter<MaskIndices<'_>> for Buffer<T> {
115+
type Output = Self;
116+
117+
fn filter(self, indices: &MaskIndices<'_>) -> Self {
118+
match self.try_into_mut() {
119+
Ok(mut buffer_mut) => {
120+
(&mut buffer_mut).filter(indices);
121+
buffer_mut.freeze()
122+
}
123+
Err(buffer) => (&buffer).filter(indices),
124+
}
125+
}
126+
}
127+
94128
fn filter_indices<T: Copy>(values: &[T], indices: &[usize]) -> Buffer<T> {
95129
Buffer::<T>::from_trusted_len_iter(indices.iter().map(|&idx| values[idx]))
96130
}

vortex-compute/src/filter/mask.rs

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4-
use vortex_mask::Mask;
4+
use vortex_mask::{Mask, MaskMut};
55

6-
use crate::filter::Filter;
6+
use crate::filter::{Filter, MaskIndices};
77

8-
impl Filter for &Mask {
8+
impl Filter<Mask> for &Mask {
99
type Output = Mask;
1010

1111
fn filter(self, selection_mask: &Mask) -> Mask {
@@ -27,3 +27,41 @@ impl Filter for &Mask {
2727
}
2828
}
2929
}
30+
31+
impl Filter<MaskIndices<'_>> for &Mask {
32+
type Output = Mask;
33+
34+
fn filter(self, indices: &MaskIndices<'_>) -> Mask {
35+
match self {
36+
Mask::AllTrue(_) => Mask::AllTrue(indices.len()),
37+
Mask::AllFalse(_) => Mask::AllFalse(indices.len()),
38+
Mask::Values(mask_values) => Mask::from(mask_values.bit_buffer().filter(indices)),
39+
}
40+
}
41+
}
42+
43+
impl Filter<Mask> for &mut MaskMut {
44+
type Output = ();
45+
46+
fn filter(self, selection_mask: &Mask) {
47+
assert_eq!(
48+
selection_mask.len(),
49+
self.len(),
50+
"Selection mask length must equal the mask length"
51+
);
52+
53+
// TODO(connor): There is definitely a better way to do this (in place).
54+
let filtered = self.clone().freeze().filter(selection_mask).into_mut();
55+
*self = filtered;
56+
}
57+
}
58+
59+
impl Filter<MaskIndices<'_>> for &mut MaskMut {
60+
type Output = ();
61+
62+
fn filter(self, indices: &MaskIndices<'_>) -> Self::Output {
63+
// TODO(aduffy): Filter in-place
64+
let filtered = self.clone().freeze().filter(indices).into_mut();
65+
*self = filtered;
66+
}
67+
}

vortex-compute/src/filter/mod.rs

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@
33

44
//! Filter function.
55
6+
use std::ops::Deref;
7+
68
mod bitbuffer;
79
mod buffer;
810
mod mask;
911
mod vector;
1012

11-
use vortex_mask::Mask;
12-
1313
/// Function for filtering based on a selection mask.
14-
pub trait Filter {
14+
pub trait Filter<By: ?Sized> {
1515
/// The result type after performing the operation.
1616
type Output;
1717

@@ -22,5 +22,37 @@ pub trait Filter {
2222
/// # Panics
2323
///
2424
/// If the length of the mask does not equal the length of the value being filtered.
25-
fn filter(self, selection_mask: &Mask) -> Self::Output;
25+
fn filter(self, selection: &By) -> Self::Output;
26+
}
27+
28+
/// A view over a set of strictly sorted indices from a bit mask.
29+
///
30+
/// Unlike other indices, `MaskIndices` are always strict-sorted, meaning they are
31+
/// always unique and monotonic.
32+
///
33+
/// You can treat a `MaskIndices` just like a `&[usize]` by iterating or indexing
34+
/// into it just like you would a slice.
35+
pub struct MaskIndices<'a>(&'a [usize]);
36+
37+
impl<'a> MaskIndices<'a> {
38+
/// Create new indices from a slice of strict-sorted index values.
39+
///
40+
/// # Safety
41+
///
42+
/// The caller must ensure that the indices are strict-sorted, i.e. that they
43+
/// are montonic and unique.
44+
///
45+
/// Users of the `Indices` type assume this and failure to uphold this guarantee
46+
/// can result in UB downstream.
47+
pub unsafe fn new_unchecked(indices: &'a [usize]) -> Self {
48+
Self(indices)
49+
}
50+
}
51+
52+
impl Deref for MaskIndices<'_> {
53+
type Target = [usize];
54+
55+
fn deref(&self) -> &Self::Target {
56+
self.0
57+
}
2658
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
use vortex_mask::Mask;
2+
use vortex_vector::VectorOps;
3+
use vortex_vector::binaryview::{BinaryViewType, BinaryViewVector, BinaryViewVectorMut};
4+
5+
use crate::filter::{Filter, MaskIndices};
6+
7+
macro_rules! delegate_filter_impl {
8+
($mask_ty:ty) => {
9+
impl<T: BinaryViewType> Filter<$mask_ty> for &BinaryViewVector<T> {
10+
type Output = BinaryViewVector<T>;
11+
12+
fn filter(self, selection: &$mask_ty) -> Self::Output {
13+
let views = self.views().filter(selection);
14+
let validity = self.validity().filter(selection);
15+
16+
// SAFETY: we filter the views and validity using the same mask
17+
unsafe {
18+
BinaryViewVector::<T>::new_unchecked(views, self.buffers().clone(), validity)
19+
}
20+
}
21+
}
22+
23+
impl<T: BinaryViewType> Filter<$mask_ty> for &mut BinaryViewVectorMut<T> {
24+
type Output = ();
25+
26+
fn filter(self, selection: &$mask_ty) -> Self::Output {
27+
// SAFETY: views and validity filtered by the same mask will have
28+
// same resultant length.
29+
unsafe {
30+
self.views_mut().filter(selection);
31+
self.validity_mut().filter(selection);
32+
}
33+
}
34+
}
35+
};
36+
}
37+
38+
delegate_filter_impl!(Mask);
39+
delegate_filter_impl!(MaskIndices<'_>);

vortex-compute/src/filter/vector/bool.rs

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,40 @@
33

44
use vortex_mask::Mask;
55
use vortex_vector::VectorOps;
6-
use vortex_vector::bool::BoolVector;
6+
use vortex_vector::bool::{BoolVector, BoolVectorMut};
77

8-
use crate::filter::Filter;
8+
use crate::filter::{Filter, MaskIndices};
99

10-
impl Filter for &BoolVector {
11-
type Output = BoolVector;
10+
macro_rules! delegate_filter_impl {
11+
($mask_ty:ty) => {
12+
impl Filter<$mask_ty> for &BoolVector {
13+
type Output = BoolVector;
1214

13-
fn filter(self, mask: &Mask) -> BoolVector {
14-
let filtered_bits = self.bits().filter(mask);
15-
let filtered_validity = self.validity().filter(mask);
15+
fn filter(self, selection: &$mask_ty) -> Self::Output {
16+
let filtered_bits = self.bits().filter(selection);
17+
let filtered_validity = self.validity().filter(selection);
1618

17-
// SAFETY: We filter the bits and validity with the same mask, and since they came from an
18-
// existing and valid `BoolVector`, we know that the filtered output must have the same
19-
// length.
20-
unsafe { BoolVector::new_unchecked(filtered_bits, filtered_validity) }
21-
}
19+
// SAFETY: We filter the bits and validity with the same mask, and since they came from an
20+
// existing and valid `BoolVector`, we know that the filtered output must have the same
21+
// length.
22+
unsafe { BoolVector::new_unchecked(filtered_bits, filtered_validity) }
23+
}
24+
}
25+
26+
impl Filter<$mask_ty> for &mut BoolVectorMut {
27+
type Output = ();
28+
29+
fn filter(self, selection: &$mask_ty) -> Self::Output {
30+
// TODO(aduffy): how can we do this faster in-place?
31+
unsafe {
32+
let bits = self.bits_mut();
33+
*bits = (*bits).clone().freeze().filter(selection).into_mut();
34+
self.validity_mut().filter(selection);
35+
}
36+
}
37+
}
38+
};
2239
}
40+
41+
delegate_filter_impl!(Mask);
42+
delegate_filter_impl!(MaskIndices<'_>);
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
use vortex_mask::Mask;
2+
use vortex_vector::decimal::{DecimalVector, DecimalVectorMut};
3+
use vortex_vector::{match_each_dvector, match_each_dvector_mut};
4+
5+
use crate::filter::{Filter, MaskIndices};
6+
7+
impl Filter<Mask> for &DecimalVector {
8+
type Output = DecimalVector;
9+
10+
fn filter(self, selection: &Mask) -> Self::Output {
11+
match_each_dvector!(self, |d| { d.filter(selection).into() })
12+
}
13+
}
14+
15+
impl Filter<MaskIndices<'_>> for &DecimalVector {
16+
type Output = DecimalVector;
17+
18+
fn filter(self, selection: &MaskIndices) -> Self::Output {
19+
match_each_dvector!(self, |d| { d.filter(selection).into() })
20+
}
21+
}
22+
23+
impl Filter<Mask> for &mut DecimalVectorMut {
24+
type Output = ();
25+
26+
fn filter(self, selection: &Mask) -> Self::Output {
27+
match_each_dvector_mut!(self, |d| { d.filter(selection) });
28+
}
29+
}
30+
31+
impl Filter<MaskIndices<'_>> for &mut DecimalVectorMut {
32+
type Output = ();
33+
34+
fn filter(self, selection: &MaskIndices) -> Self::Output {
35+
match_each_dvector_mut!(self, |d| { d.filter(selection) });
36+
}
37+
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
use vortex_dtype::NativeDecimalType;
2+
use vortex_mask::Mask;
3+
use vortex_vector::VectorOps;
4+
use vortex_vector::decimal::{DVector, DVectorMut};
5+
6+
use crate::filter::{Filter, MaskIndices};
7+
8+
macro_rules! delegate_filter_impl {
9+
($mask_ty:ty) => {
10+
impl<D: NativeDecimalType> Filter<$mask_ty> for &DVector<D> {
11+
type Output = DVector<D>;
12+
13+
fn filter(self, selection: &$mask_ty) -> Self::Output {
14+
let elements = self.elements().filter(selection);
15+
let validity = self.validity().filter(selection);
16+
// SAFETY: we're filtering the elements and validity with the same mask
17+
unsafe { DVector::<D>::new_unchecked(self.precision_scale(), elements, validity) }
18+
}
19+
}
20+
21+
impl<D: NativeDecimalType> Filter<$mask_ty> for &mut DVectorMut<D> {
22+
type Output = ();
23+
24+
fn filter(self, selection: &$mask_ty) -> Self::Output {
25+
// SAFETY: we filter elements and validity using the same mask
26+
unsafe {
27+
self.elements_mut().filter(selection);
28+
self.validity_mut().filter(selection);
29+
}
30+
}
31+
}
32+
};
33+
}
34+
35+
delegate_filter_impl!(Mask);
36+
delegate_filter_impl!(MaskIndices<'_>);

0 commit comments

Comments
 (0)