Skip to content

Commit aaf8553

Browse files
committed
Mooooaarrrrr generics
Signed-off-by: Nicholas Gates <[email protected]>
1 parent 66c4249 commit aaf8553

File tree

14 files changed

+310
-218
lines changed

14 files changed

+310
-218
lines changed

vortex-array/src/pipeline/filter/buffer.rs

Lines changed: 0 additions & 51 deletions
This file was deleted.

vortex-array/src/pipeline/filter/mod.rs

Lines changed: 0 additions & 52 deletions
This file was deleted.

vortex-array/src/pipeline/mod.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
pub mod driver;
5-
mod filter;
65

76
use vortex_error::{VortexExpect, VortexResult};
87
use vortex_vector::{Vector, VectorMut};

vortex-buffer/src/bit/view.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,10 @@ impl<const NB: usize> BitView<'static, NB> {
5252
}
5353

5454
impl<'a, const NB: usize> BitView<'a, NB> {
55-
const N: usize = NB * 8;
56-
const N_WORDS: usize = NB * 8 / (usize::BITS as usize);
55+
/// The number of bits in the view.
56+
pub const N: usize = NB * 8;
57+
/// The number of machine words in the view.
58+
pub const N_WORDS: usize = NB * 8 / (usize::BITS as usize);
5759

5860
const _ASSERT_MULTIPLE_OF_8: () = assert!(
5961
NB % 8 == 0,

vortex-compute/src/filter/bitbuffer.rs

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4-
use vortex_buffer::{BitBuffer, BitBufferMut, get_bit};
4+
use vortex_buffer::{
5+
BitBuffer, BitBufferMut, BitView, get_bit, get_bit_unchecked, set_bit_unchecked,
6+
unset_bit_unchecked,
7+
};
58
use vortex_mask::Mask;
69

710
use crate::filter::{Filter, MaskIndices};
@@ -70,6 +73,56 @@ fn filter_indices(bools: &[u8], bit_offset: usize, indices: &[usize]) -> BitBuff
7073
})
7174
}
7275

76+
/// Filters a shared `BitBuffer` by a fixed-size `BitView` selection.
///
/// For every index that `selection.iter_ones` passes to the closure (presumably the positions
/// of the set bits, in ascending order — TODO confirm), the bit of `self` at that index is
/// read and written to the next position of a freshly allocated output buffer.
///
/// NOTE(review): no check that `self` holds at least `BitView::<NB>::N` bits is visible in
/// this impl, yet the reads below are unchecked — confirm that callers guarantee it.
impl<const NB: usize> Filter<BitView<'_, NB>> for &BitBuffer {
77+
type Output = BitBuffer;
78+
79+
fn filter(self, selection: &BitView<'_, NB>) -> BitBuffer {
80+
// Raw pointer to the backing storage; `self.offset()` is added to every index on read.
let bits = self.inner().as_ptr();
81+
// NOTE(review): the output is created with `with_capacity` and only ever written through
// `set_to_unchecked`; nothing visible here sets its length before `freeze()`. Confirm that
// the frozen buffer really reports `selection.true_count()` bits.
let mut out = BitBufferMut::with_capacity(selection.true_count());
82+
// Next bit position to write in `out`.
let mut out_idx = 0;
83+
selection.iter_ones(|idx| {
84+
// SAFETY (assumed, TODO confirm): requires `self.offset() + idx` to lie inside the
// backing storage of `self`; see the length note on the impl.
let value = unsafe { get_bit_unchecked(bits, self.offset() + idx) };
85+
// SAFETY (assumed, TODO confirm): requires `out_idx` to stay below the capacity
// reserved above, i.e. the closure runs at most `selection.true_count()` times.
unsafe { out.set_to_unchecked(out_idx, value) };
86+
out_idx += 1;
87+
});
88+
out.freeze()
89+
}
90+
}
91+
92+
/// Filters a `BitBufferMut` in place by a fixed-size `BitView` selection.
///
/// The buffer is taken out of `self`, its selected bits are compacted towards the front of
/// the same backing storage, and `self` is then rebuilt from that storage.
///
/// # Panics
///
/// Panics if `self.len()` differs from `BitView::<NB>::N`.
impl<const NB: usize> Filter<BitView<'_, NB>> for &mut BitBufferMut {
93+
type Output = ();
94+
95+
fn filter(self, selection: &BitView<'_, NB>) {
96+
// The buffer being filtered must hold exactly as many bits as the view type covers.
// NOTE(review): the message below reads oddly ("mask length" twice); the values compared
// are the buffer length and the view's bit count.
assert_eq!(
97+
self.len(),
98+
BitView::<NB>::N,
99+
"Selection mask length must equal the mask length"
100+
);
101+
102+
// Move the buffer out of `*self` (leaving a default value behind) so that its backing
// storage can be consumed by `into_inner` below.
let this = std::mem::take(self);
103+
104+
let offset = this.offset();
105+
let mut buffer = this.into_inner();
106+
107+
let buffer_ptr = buffer.as_mut_ptr();
108+
// Next bit position to write; counted from bit 0 of the storage, without `offset`.
let mut out_idx = 0;
109+
selection.iter_ones(|idx| {
110+
// Reads apply the original bit offset. Since `out_idx <= idx` holds as long as
// `iter_ones` yields ascending indices (assumed, TODO confirm), a write never
// clobbers a bit that has not been read yet.
let value = unsafe { get_bit_unchecked(buffer_ptr, offset + idx) };
111+
112+
// NOTE(ngates): we don't call out.set_bit_unchecked here because it's nice that we
113+
// can shift away any non-zero offset by writing directly into the bits buffer.
114+
if value {
115+
unsafe { set_bit_unchecked(buffer_ptr, out_idx) };
116+
} else {
117+
unsafe { unset_bit_unchecked(buffer_ptr, out_idx) };
118+
}
119+
out_idx += 1;
120+
});
121+
122+
// Rebuild `self` over the same storage; the arguments are presumably
// (buffer, bit offset, bit length), so the offset is now zero (TODO confirm).
*self = BitBufferMut::from_buffer(buffer, 0, selection.true_count());
123+
}
124+
}
125+
73126
#[cfg(test)]
74127
mod test {
75128
use vortex_buffer::bitbuffer;

vortex-compute/src/filter/buffer.rs

Lines changed: 38 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,31 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4-
use vortex_buffer::{Buffer, BufferMut};
4+
use vortex_buffer::{BitView, Buffer, BufferMut};
55
use vortex_mask::{Mask, MaskIter};
66

77
use crate::filter::{Filter, MaskIndices};
88

9+
/// Filters an owned `Buffer<T>` by any selection type `M` for which both the by-reference
/// (`&Buffer<T>`) and the in-place (`&mut BufferMut<T>`) filters are implemented.
///
/// The owned buffer is first converted with `try_into_mut`; on success the filter runs in
/// place, otherwise it is delegated to the `&Buffer<T>` implementation.
impl<M, T: Copy> Filter<M> for Buffer<T>
10+
where
11+
for<'a> &'a Buffer<T>: Filter<M, Output = Buffer<T>>,
12+
for<'a> &'a mut BufferMut<T>: Filter<M, Output = ()>,
13+
{
14+
type Output = Self;
15+
16+
fn filter(self, selection_mask: &M) -> Self {
17+
// If we have exclusive access, we can perform the filter in place.
18+
match self.try_into_mut() {
19+
Ok(mut buffer_mut) => {
20+
// The in-place filter returns `()`; the filtered data stays in `buffer_mut`.
(&mut buffer_mut).filter(selection_mask);
21+
buffer_mut.freeze()
22+
}
23+
// Otherwise, allocate a new buffer and fill it in (delegate to the `&Buffer` impl).
24+
Err(buffer) => (&buffer).filter(selection_mask),
25+
}
26+
}
27+
}
28+
929
// This is modeled after the constant with the equivalent name in arrow-rs.
1030
const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8;
1131

@@ -40,72 +60,29 @@ impl<T: Copy> Filter<MaskIndices<'_>> for &Buffer<T> {
4060
}
4161
}
4262

43-
impl<T: Copy> Filter<Mask> for &mut BufferMut<T> {
44-
type Output = ();
45-
46-
fn filter(self, selection_mask: &Mask) {
47-
assert_eq!(
48-
selection_mask.len(),
49-
self.len(),
50-
"Selection mask length must equal the buffer length"
51-
);
52-
53-
match selection_mask {
54-
Mask::AllTrue(_) => {}
55-
Mask::AllFalse(_) => self.clear(),
56-
Mask::Values(values) => {
57-
// We choose to _always_ use slices here because iterating over indices will have
58-
// strictly more loop iterations than slices, and the overhead over batched
59-
// `ptr::copy(len)` is not worth it.
60-
let slices = values.slices();
61-
62-
// SAFETY: We checked above that the selection mask has the same length as the
63-
// buffer.
64-
let new_len = unsafe { filter_slices_in_place(self.as_mut_slice(), slices) };
65-
66-
debug_assert!(
67-
new_len <= self.len(),
68-
"The new length was somehow larger after filter"
69-
);
70-
71-
// Truncate the buffer to the new length.
72-
// SAFETY: The new length cannot be larger than the old length, so all values must
73-
// be initialized.
74-
unsafe { self.set_len(new_len) };
75-
}
76-
}
77-
}
78-
}
79-
80-
impl<T: Copy> Filter<MaskIndices<'_>> for &mut BufferMut<T> {
81-
type Output = ();
82-
83-
fn filter(self, indices: &MaskIndices) -> Self::Output {
84-
for (write_index, &read_index) in indices.iter().enumerate() {
85-
self[write_index] = self[read_index];
86-
}
63+
/// Filters a shared `Buffer<T>` by a fixed-size `BitView` selection, copying the element at
/// every index reported by `selection.iter_ones` into a newly allocated buffer, in order.
impl<const NB: usize, T: Copy> Filter<BitView<'_, NB>> for &Buffer<T> {
64+
type Output = Buffer<T>;
8765

88-
self.truncate(indices.len());
66+
fn filter(self, selection: &BitView<'_, NB>) -> Self::Output {
67+
// TODO(ngates): this is very very slow!
68+
let elems = self.as_slice();
69+
// Reserve room for one element per selected index.
let mut out = BufferMut::<T>::with_capacity(selection.true_count());
70+
selection.iter_ones(|idx| {
71+
// `elems[idx]` is an ordinary index expression (presumably bounds-checked, so it
// panics if `idx` is out of range for `self` — TODO confirm).
// SAFETY (assumed, TODO confirm): `push_unchecked` relies on the closure running
// at most `selection.true_count()` times, matching the capacity reserved above.
unsafe { out.push_unchecked(elems[idx]) };
72+
});
73+
out.freeze()
8974
}
9075
}
9176

92-
impl<M, T: Copy> Filter<M> for Buffer<T>
77+
/// Filters a `BufferMut<T>` in place by any selection type `M` for which a filter over
/// `&mut [T]` returning `&mut [T]` is implemented.
impl<M, T> Filter<M> for &mut BufferMut<T>
9378
where
94-
for<'a> &'a Buffer<T>: Filter<M, Output = Buffer<T>>,
95-
for<'a> &'a mut BufferMut<T>: Filter<M, Output = ()>,
79+
for<'a> &'a mut [T]: Filter<M, Output = &'a mut [T]>,
9680
{
97-
type Output = Self;
81+
type Output = ();
9882

99-
fn filter(self, selection_mask: &M) -> Self {
100-
// If we have exclusive access, we can perform the filter in place.
101-
match self.try_into_mut() {
102-
Ok(mut buffer_mut) => {
103-
(&mut buffer_mut).filter(selection_mask);
104-
buffer_mut.freeze()
105-
}
106-
// Otherwise, allocate a new buffer and fill it in (delegate to the `&Buffer` impl).
107-
Err(buffer) => (&buffer).filter(selection_mask),
108-
}
83+
fn filter(self, selection_mask: &M) -> Self::Output {
84+
// Delegate to the slice filter and use the length of the slice it returns as the new
// buffer length (presumably the retained elements packed at the front — TODO confirm).
let true_count = self.as_mut_slice().filter(selection_mask).len();
85+
self.truncate(true_count);
10986
}
11087
}
11188

@@ -121,42 +98,9 @@ fn filter_slices<T>(values: &[T], output_len: usize, slices: &[(usize, usize)])
12198
out.freeze()
12299
}
123100

124-
/// Filters a buffer in-place using slice ranges to determine which values to keep.
125-
///
126-
/// Returns the new length of the buffer.
127-
///
128-
/// # Safety
129-
///
130-
/// The slice ranges must be in the range of the `buffer`.
131-
#[must_use = "The caller should set the new length of the buffer"]
132-
unsafe fn filter_slices_in_place<T: Copy>(buffer: &mut [T], slices: &[(usize, usize)]) -> usize {
133-
let mut write_pos = 0;
134-
135-
// For each range in the selection, copy all of the elements to the current write position.
136-
for &(start, end) in slices {
137-
// Note that we could add an if statement here that checks `if read_idx != write_idx`, but
138-
// it's probably better to just avoid the branch misprediction.
139-
140-
let len = end - start;
141-
142-
// SAFETY: The safety contract enforces that all ranges are within bounds.
143-
unsafe {
144-
core::ptr::copy(
145-
buffer.as_ptr().add(start),
146-
buffer.as_mut_ptr().add(write_pos),
147-
len,
148-
)
149-
};
150-
151-
write_pos += len;
152-
}
153-
154-
write_pos
155-
}
156-
157101
#[cfg(test)]
158102
mod tests {
159-
use vortex_buffer::buffer;
103+
use vortex_buffer::{BufferMut, buffer, buffer_mut};
160104
use vortex_mask::Mask;
161105

162106
use super::*;
@@ -202,8 +146,6 @@ mod tests {
202146
assert_eq!(result, buffer![1u32, 2, 5]);
203147
}
204148

205-
use vortex_buffer::{BufferMut, buffer_mut};
206-
207149
#[test]
208150
fn test_filter_all_true() {
209151
let mut buf = buffer_mut![1u32, 2, 3, 4, 5];

0 commit comments

Comments
 (0)