11// SPDX-License-Identifier: Apache-2.0
22// SPDX-FileCopyrightText: Copyright the Vortex contributors
33
4- use vortex_buffer:: { Buffer , BufferMut } ;
4+ use vortex_buffer:: { BitView , Buffer , BufferMut } ;
55use vortex_mask:: { Mask , MaskIter } ;
66
77use crate :: filter:: { Filter , MaskIndices } ;
88
9+ impl < M , T : Copy > Filter < M > for Buffer < T >
10+ where
11+ for < ' a > & ' a Buffer < T > : Filter < M , Output = Buffer < T > > ,
12+ for < ' a > & ' a mut BufferMut < T > : Filter < M , Output = ( ) > ,
13+ {
14+ type Output = Self ;
15+
16+ fn filter ( self , selection_mask : & M ) -> Self {
17+ // If we have exclusive access, we can perform the filter in place.
18+ match self . try_into_mut ( ) {
19+ Ok ( mut buffer_mut) => {
20+ ( & mut buffer_mut) . filter ( selection_mask) ;
21+ buffer_mut. freeze ( )
22+ }
23+ // Otherwise, allocate a new buffer and fill it in (delegate to the `&Buffer` impl).
24+ Err ( buffer) => ( & buffer) . filter ( selection_mask) ,
25+ }
26+ }
27+ }
28+
929// This is modeled after the constant with the equivalent name in arrow-rs.
1030const FILTER_SLICES_SELECTIVITY_THRESHOLD : f64 = 0.8 ;
1131
@@ -40,72 +60,29 @@ impl<T: Copy> Filter<MaskIndices<'_>> for &Buffer<T> {
4060 }
4161}
4262
43- impl < T : Copy > Filter < Mask > for & mut BufferMut < T > {
44- type Output = ( ) ;
45-
46- fn filter ( self , selection_mask : & Mask ) {
47- assert_eq ! (
48- selection_mask. len( ) ,
49- self . len( ) ,
50- "Selection mask length must equal the buffer length"
51- ) ;
52-
53- match selection_mask {
54- Mask :: AllTrue ( _) => { }
55- Mask :: AllFalse ( _) => self . clear ( ) ,
56- Mask :: Values ( values) => {
57- // We choose to _always_ use slices here because iterating over indices will have
58- // strictly more loop iterations than slices, and the overhead over batched
59- // `ptr::copy(len)` is not worth it.
60- let slices = values. slices ( ) ;
61-
62- // SAFETY: We checked above that the selection mask has the same length as the
63- // buffer.
64- let new_len = unsafe { filter_slices_in_place ( self . as_mut_slice ( ) , slices) } ;
65-
66- debug_assert ! (
67- new_len <= self . len( ) ,
68- "The new length was somehow larger after filter"
69- ) ;
70-
71- // Truncate the buffer to the new length.
72- // SAFETY: The new length cannot be larger than the old length, so all values must
73- // be initialized.
74- unsafe { self . set_len ( new_len) } ;
75- }
76- }
77- }
78- }
79-
80- impl < T : Copy > Filter < MaskIndices < ' _ > > for & mut BufferMut < T > {
81- type Output = ( ) ;
82-
83- fn filter ( self , indices : & MaskIndices ) -> Self :: Output {
84- for ( write_index, & read_index) in indices. iter ( ) . enumerate ( ) {
85- self [ write_index] = self [ read_index] ;
86- }
63+ impl < const NB : usize , T : Copy > Filter < BitView < ' _ , NB > > for & Buffer < T > {
64+ type Output = Buffer < T > ;
8765
88- self . truncate ( indices. len ( ) ) ;
66+ fn filter ( self , selection : & BitView < ' _ , NB > ) -> Self :: Output {
67+ // TODO(ngates): this is very very slow!
68+ let elems = self . as_slice ( ) ;
69+ let mut out = BufferMut :: < T > :: with_capacity ( selection. true_count ( ) ) ;
70+ selection. iter_ones ( |idx| {
71+ unsafe { out. push_unchecked ( elems[ idx] ) } ;
72+ } ) ;
73+ out. freeze ( )
8974 }
9075}
9176
92- impl < M , T : Copy > Filter < M > for Buffer < T >
77+ impl < M , T > Filter < M > for & mut BufferMut < T >
9378where
94- for < ' a > & ' a Buffer < T > : Filter < M , Output = Buffer < T > > ,
95- for < ' a > & ' a mut BufferMut < T > : Filter < M , Output = ( ) > ,
79+ for < ' a > & ' a mut [ T ] : Filter < M , Output = & ' a mut [ T ] > ,
9680{
97- type Output = Self ;
81+ type Output = ( ) ;
9882
99- fn filter ( self , selection_mask : & M ) -> Self {
100- // If we have exclusive access, we can perform the filter in place.
101- match self . try_into_mut ( ) {
102- Ok ( mut buffer_mut) => {
103- ( & mut buffer_mut) . filter ( selection_mask) ;
104- buffer_mut. freeze ( )
105- }
106- // Otherwise, allocate a new buffer and fill it in (delegate to the `&Buffer` impl).
107- Err ( buffer) => ( & buffer) . filter ( selection_mask) ,
108- }
83+ fn filter ( self , selection_mask : & M ) -> Self :: Output {
84+ let true_count = self . as_mut_slice ( ) . filter ( selection_mask) . len ( ) ;
85+ self . truncate ( true_count) ;
10986 }
11087}
11188
@@ -121,42 +98,9 @@ fn filter_slices<T>(values: &[T], output_len: usize, slices: &[(usize, usize)])
12198 out. freeze ( )
12299}
123100
124- /// Filters a buffer in-place using slice ranges to determine which values to keep.
125- ///
126- /// Returns the new length of the buffer.
127- ///
128- /// # Safety
129- ///
130- /// The slice ranges must be in the range of the `buffer`.
131- #[ must_use = "The caller should set the new length of the buffer" ]
132- unsafe fn filter_slices_in_place < T : Copy > ( buffer : & mut [ T ] , slices : & [ ( usize , usize ) ] ) -> usize {
133- let mut write_pos = 0 ;
134-
135- // For each range in the selection, copy all of the elements to the current write position.
136- for & ( start, end) in slices {
137- // Note that we could add an if statement here that checks `if read_idx != write_idx`, but
138- // it's probably better to just avoid the branch misprediction.
139-
140- let len = end - start;
141-
142- // SAFETY: The safety contract enforces that all ranges are within bounds.
143- unsafe {
144- core:: ptr:: copy (
145- buffer. as_ptr ( ) . add ( start) ,
146- buffer. as_mut_ptr ( ) . add ( write_pos) ,
147- len,
148- )
149- } ;
150-
151- write_pos += len;
152- }
153-
154- write_pos
155- }
156-
157101#[ cfg( test) ]
158102mod tests {
159- use vortex_buffer:: buffer;
103+ use vortex_buffer:: { BufferMut , buffer, buffer_mut } ;
160104 use vortex_mask:: Mask ;
161105
162106 use super :: * ;
@@ -202,8 +146,6 @@ mod tests {
202146 assert_eq ! ( result, buffer![ 1u32 , 2 , 5 ] ) ;
203147 }
204148
205- use vortex_buffer:: { BufferMut , buffer_mut} ;
206-
207149 #[ test]
208150 fn test_filter_all_true ( ) {
209151 let mut buf = buffer_mut ! [ 1u32 , 2 , 3 , 4 , 5 ] ;
0 commit comments