Skip to content

Commit 3a1235b

Browse files
committed
use the existential types luke
Signed-off-by: Andrew Duffy <[email protected]>
1 parent 22bee6b commit 3a1235b

File tree

7 files changed

+206
-251
lines changed

7 files changed

+206
-251
lines changed

vortex-compute/src/filter/buffer.rs

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -89,16 +89,14 @@ impl<T: Copy> Filter<MaskIndices<'_>> for &mut BufferMut<T> {
8989
}
9090
}
9191

92-
impl<T: Copy> Filter<Mask> for Buffer<T> {
92+
impl<M, T: Copy> Filter<M> for Buffer<T>
93+
where
94+
for<'a> &'a Buffer<T>: Filter<M, Output = Buffer<T>>,
95+
for<'a> &'a mut BufferMut<T>: Filter<M, Output = ()>,
96+
{
9397
type Output = Self;
9498

95-
fn filter(self, selection_mask: &Mask) -> Self {
96-
assert_eq!(
97-
selection_mask.len(),
98-
self.len(),
99-
"Selection mask length must equal the buffer length"
100-
);
101-
99+
fn filter(self, selection_mask: &M) -> Self {
102100
// If we have exclusive access, we can perform the filter in place.
103101
match self.try_into_mut() {
104102
Ok(mut buffer_mut) => {
@@ -111,20 +109,6 @@ impl<T: Copy> Filter<Mask> for Buffer<T> {
111109
}
112110
}
113111

114-
impl<T: Copy> Filter<MaskIndices<'_>> for Buffer<T> {
115-
type Output = Self;
116-
117-
fn filter(self, indices: &MaskIndices<'_>) -> Self {
118-
match self.try_into_mut() {
119-
Ok(mut buffer_mut) => {
120-
(&mut buffer_mut).filter(indices);
121-
buffer_mut.freeze()
122-
}
123-
Err(buffer) => (&buffer).filter(indices),
124-
}
125-
}
126-
}
127-
128112
fn filter_indices<T: Copy>(values: &[T], indices: &[usize]) -> Buffer<T> {
129113
Buffer::<T>::from_trusted_len_iter(indices.iter().map(|&idx| values[idx]))
130114
}

vortex-compute/src/filter/vector/binaryview.rs

Lines changed: 35 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,44 @@
1-
use vortex_mask::Mask;
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_buffer::{Buffer, BufferMut};
5+
use vortex_mask::{Mask, MaskMut};
26
use vortex_vector::VectorOps;
3-
use vortex_vector::binaryview::{BinaryViewType, BinaryViewVector, BinaryViewVectorMut};
7+
use vortex_vector::binaryview::{
8+
BinaryView, BinaryViewType, BinaryViewVector, BinaryViewVectorMut,
9+
};
410

5-
use crate::filter::{Filter, MaskIndices};
11+
use crate::filter::Filter;
612

7-
macro_rules! delegate_filter_impl {
8-
($mask_ty:ty) => {
9-
impl<T: BinaryViewType> Filter<$mask_ty> for &BinaryViewVector<T> {
10-
type Output = BinaryViewVector<T>;
13+
impl<M, T: BinaryViewType> Filter<M> for &BinaryViewVector<T>
14+
where
15+
for<'a> &'a Mask: Filter<M, Output = Mask>,
16+
for<'a> &'a Buffer<BinaryView>: Filter<M, Output = Buffer<BinaryView>>,
17+
{
18+
type Output = BinaryViewVector<T>;
1119

12-
fn filter(self, selection: &$mask_ty) -> Self::Output {
13-
let views = self.views().filter(selection);
14-
let validity = self.validity().filter(selection);
20+
fn filter(self, selection: &M) -> Self::Output {
21+
let views = self.views().filter(selection);
22+
let validity = self.validity().filter(selection);
1523

16-
// SAFETY: we filter the views and validity using the same mask
17-
unsafe {
18-
BinaryViewVector::<T>::new_unchecked(views, self.buffers().clone(), validity)
19-
}
20-
}
21-
}
24+
// SAFETY: we filter the views and validity using the same mask
25+
unsafe { BinaryViewVector::<T>::new_unchecked(views, self.buffers().clone(), validity) }
26+
}
27+
}
2228

23-
impl<T: BinaryViewType> Filter<$mask_ty> for &mut BinaryViewVectorMut<T> {
24-
type Output = ();
29+
impl<M, T: BinaryViewType> Filter<M> for &mut BinaryViewVectorMut<T>
30+
where
31+
for<'a> &'a mut MaskMut: Filter<M, Output = ()>,
32+
for<'a> &'a mut BufferMut<BinaryView>: Filter<M, Output = ()>,
33+
{
34+
type Output = ();
2535

26-
fn filter(self, selection: &$mask_ty) -> Self::Output {
27-
// SAFETY: views and validity filtered by the same mask will have
28-
// same resultant length.
29-
unsafe {
30-
self.views_mut().filter(selection);
31-
self.validity_mut().filter(selection);
32-
}
33-
}
36+
fn filter(self, selection: &M) -> Self::Output {
37+
// SAFETY: views and validity filtered by the same mask will have
38+
// same resultant length.
39+
unsafe {
40+
self.views_mut().filter(selection);
41+
self.validity_mut().filter(selection);
3442
}
35-
};
43+
}
3644
}
37-
38-
delegate_filter_impl!(Mask);
39-
delegate_filter_impl!(MaskIndices<'_>);

vortex-compute/src/filter/vector/bool.rs

Lines changed: 32 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,44 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4-
use vortex_mask::Mask;
4+
use vortex_buffer::BitBuffer;
5+
use vortex_mask::{Mask, MaskMut};
56
use vortex_vector::VectorOps;
67
use vortex_vector::bool::{BoolVector, BoolVectorMut};
78

8-
use crate::filter::{Filter, MaskIndices};
9+
use crate::filter::Filter;
910

10-
macro_rules! delegate_filter_impl {
11-
($mask_ty:ty) => {
12-
impl Filter<$mask_ty> for &BoolVector {
13-
type Output = BoolVector;
11+
impl<M> Filter<M> for &BoolVector
12+
where
13+
for<'a> &'a BitBuffer: Filter<M, Output = BitBuffer>,
14+
for<'a> &'a Mask: Filter<M, Output = Mask>,
15+
{
16+
type Output = BoolVector;
1417

15-
fn filter(self, selection: &$mask_ty) -> Self::Output {
16-
let filtered_bits = self.bits().filter(selection);
17-
let filtered_validity = self.validity().filter(selection);
18+
fn filter(self, selection: &M) -> Self::Output {
19+
let filtered_bits = self.bits().filter(selection);
20+
let filtered_validity = self.validity().filter(selection);
1821

19-
// SAFETY: We filter the bits and validity with the same mask, and since they came from an
20-
// existing and valid `BoolVector`, we know that the filtered output must have the same
21-
// length.
22-
unsafe { BoolVector::new_unchecked(filtered_bits, filtered_validity) }
23-
}
24-
}
25-
26-
impl Filter<$mask_ty> for &mut BoolVectorMut {
27-
type Output = ();
22+
// SAFETY: We filter the bits and validity with the same mask, and since they came from an
23+
// existing and valid `BoolVector`, we know that the filtered output must have the same
24+
// length.
25+
unsafe { BoolVector::new_unchecked(filtered_bits, filtered_validity) }
26+
}
27+
}
2828

29-
fn filter(self, selection: &$mask_ty) -> Self::Output {
30-
// TODO(aduffy): how can we do this faster in-place?
31-
unsafe {
32-
let bits = self.bits_mut();
33-
*bits = (*bits).clone().freeze().filter(selection).into_mut();
34-
self.validity_mut().filter(selection);
35-
}
36-
}
29+
impl<M> Filter<M> for &mut BoolVectorMut
30+
where
31+
for<'a> &'a BoolVector: Filter<M, Output = BoolVector>,
32+
for<'a> &'a mut MaskMut: Filter<M, Output = ()>,
33+
{
34+
type Output = ();
35+
36+
fn filter(self, selection: &M) -> Self::Output {
37+
// TODO(aduffy): how can we do this faster in-place?
38+
unsafe {
39+
let bits = self.bits_mut();
40+
*bits = (*bits).clone().freeze().filter(selection).into_mut();
41+
self.validity_mut().filter(selection);
3742
}
38-
};
43+
}
3944
}
40-
41-
delegate_filter_impl!(Mask);
42-
delegate_filter_impl!(MaskIndices<'_>);

vortex-compute/src/filter/vector/dvector.rs

Lines changed: 31 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,41 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_buffer::{Buffer, BufferMut};
15
use vortex_dtype::NativeDecimalType;
2-
use vortex_mask::Mask;
6+
use vortex_mask::{Mask, MaskMut};
37
use vortex_vector::VectorOps;
48
use vortex_vector::decimal::{DVector, DVectorMut};
59

6-
use crate::filter::{Filter, MaskIndices};
10+
use crate::filter::Filter;
711

8-
macro_rules! delegate_filter_impl {
9-
($mask_ty:ty) => {
10-
impl<D: NativeDecimalType> Filter<$mask_ty> for &DVector<D> {
11-
type Output = DVector<D>;
12+
impl<M, D: NativeDecimalType> Filter<M> for &DVector<D>
13+
where
14+
for<'a> &'a Buffer<D>: Filter<M, Output = Buffer<D>>,
15+
for<'a> &'a Mask: Filter<M, Output = Mask>,
16+
{
17+
type Output = DVector<D>;
1218

13-
fn filter(self, selection: &$mask_ty) -> Self::Output {
14-
let elements = self.elements().filter(selection);
15-
let validity = self.validity().filter(selection);
16-
// SAFETY: we're filtering the elements and validity with the same mask
17-
unsafe { DVector::<D>::new_unchecked(self.precision_scale(), elements, validity) }
18-
}
19-
}
19+
fn filter(self, selection: &M) -> Self::Output {
20+
let elements = self.elements().filter(selection);
21+
let validity = self.validity().filter(selection);
22+
// SAFETY: we're filtering the elements and validity with the same mask
23+
unsafe { DVector::<D>::new_unchecked(self.precision_scale(), elements, validity) }
24+
}
25+
}
2026

21-
impl<D: NativeDecimalType> Filter<$mask_ty> for &mut DVectorMut<D> {
22-
type Output = ();
27+
impl<M, D: NativeDecimalType> Filter<M> for &mut DVectorMut<D>
28+
where
29+
for<'a> &'a mut BufferMut<D>: Filter<M, Output = ()>,
30+
for<'a> &'a mut MaskMut: Filter<M, Output = ()>,
31+
{
32+
type Output = ();
2333

24-
fn filter(self, selection: &$mask_ty) -> Self::Output {
25-
// SAFETY: we filter elements and validity using the same mask
26-
unsafe {
27-
self.elements_mut().filter(selection);
28-
self.validity_mut().filter(selection);
29-
}
30-
}
34+
fn filter(self, selection: &M) -> Self::Output {
35+
// SAFETY: we filter elements and validity using the same mask
36+
unsafe {
37+
self.elements_mut().filter(selection);
38+
self.validity_mut().filter(selection);
3139
}
32-
};
40+
}
3341
}
34-
35-
delegate_filter_impl!(Mask);
36-
delegate_filter_impl!(MaskIndices<'_>);
Lines changed: 38 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,47 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
14
use std::sync::Arc;
25

3-
use vortex_mask::Mask;
6+
use vortex_mask::{Mask, MaskMut};
47
use vortex_vector::VectorOps;
58
use vortex_vector::listview::{ListViewVector, ListViewVectorMut};
9+
use vortex_vector::primitive::{PrimitiveVector, PrimitiveVectorMut};
610

7-
use crate::filter::{Filter, MaskIndices};
8-
9-
macro_rules! delegate_filter_impl {
10-
($mask_ty:ty) => {
11-
impl Filter<$mask_ty> for &ListViewVector {
12-
type Output = ListViewVector;
13-
14-
fn filter(self, selection: &$mask_ty) -> Self::Output {
15-
let offsets = self.offsets().filter(selection);
16-
let sizes = self.sizes().filter(selection);
17-
let validity = self.validity().filter(selection);
18-
19-
// SAFETY: all components filtered with same mask
20-
unsafe {
21-
ListViewVector::new_unchecked(
22-
Arc::clone(self.elements()),
23-
offsets,
24-
sizes,
25-
validity,
26-
)
27-
}
28-
}
29-
}
11+
use crate::filter::Filter;
12+
13+
impl<M> Filter<M> for &ListViewVector
14+
where
15+
for<'a> &'a PrimitiveVector: Filter<M, Output = PrimitiveVector>,
16+
for<'a> &'a Mask: Filter<M, Output = Mask>,
17+
{
18+
type Output = ListViewVector;
3019

31-
impl Filter<$mask_ty> for &mut ListViewVectorMut {
32-
type Output = ();
33-
34-
fn filter(self, selection: &$mask_ty) -> Self::Output {
35-
// SAFETY: offsets, sizes, validity all being filtered with same mask
36-
unsafe {
37-
self.offsets_mut().filter(selection);
38-
self.sizes_mut().filter(selection);
39-
self.validity_mut().filter(selection);
40-
}
41-
}
20+
fn filter(self, selection: &M) -> Self::Output {
21+
let offsets = self.offsets().filter(selection);
22+
let sizes = self.sizes().filter(selection);
23+
let validity = self.validity().filter(selection);
24+
25+
// SAFETY: all components filtered with same mask
26+
unsafe {
27+
ListViewVector::new_unchecked(Arc::clone(self.elements()), offsets, sizes, validity)
4228
}
43-
};
29+
}
4430
}
4531

46-
delegate_filter_impl!(Mask);
47-
delegate_filter_impl!(MaskIndices<'_>);
32+
impl<M> Filter<M> for &mut ListViewVectorMut
33+
where
34+
for<'a> &'a mut PrimitiveVectorMut: Filter<M, Output = ()>,
35+
for<'a> &'a mut MaskMut: Filter<M, Output = ()>,
36+
{
37+
type Output = ();
38+
39+
fn filter(self, selection: &M) -> Self::Output {
40+
// SAFETY: offsets, sizes, validity all being filtered with same mask
41+
unsafe {
42+
self.offsets_mut().filter(selection);
43+
self.sizes_mut().filter(selection);
44+
self.validity_mut().filter(selection);
45+
}
46+
}
47+
}

0 commit comments

Comments
 (0)