@@ -32,6 +32,43 @@ impl<T: Copy> Filter for &Buffer<T> {
3232 }
3333}
3434
35+ impl < T : Copy > Filter for & mut BufferMut < T > {
36+ type Output = ( ) ;
37+
38+ fn filter ( self , selection_mask : & Mask ) {
39+ assert_eq ! (
40+ selection_mask. len( ) ,
41+ self . len( ) ,
42+ "Selection mask length must equal the buffer length"
43+ ) ;
44+
45+ match selection_mask {
46+ Mask :: AllTrue ( _) => { }
47+ Mask :: AllFalse ( _) => self . clear ( ) ,
48+ Mask :: Values ( values) => {
49+ // We choose to _always_ use slices here because iterating over indices will have
50+ // strictly more loop iterations than slices, and the overhead over batched
51+ // `ptr::copy(len)` is not worth it.
52+ let slices = values. slices ( ) ;
53+
54+ // SAFETY: We checked above that the selection mask has the same length as the
55+ // buffer.
56+ let new_len = unsafe { filter_slices_in_place ( self . as_mut_slice ( ) , slices) } ;
57+
58+ debug_assert ! (
59+ new_len <= self . len( ) ,
60+ "The new length was somehow larger after filter"
61+ ) ;
62+
63+ // Truncate the buffer to the new length.
64+ // SAFETY: The new length cannot be larger than the old length, so all values must
65+ // be initialized.
66+ unsafe { self . set_len ( new_len) } ;
67+ }
68+ }
69+ }
70+ }
71+
3572impl < T : Copy > Filter for Buffer < T > {
3673 type Output = Self ;
3774
@@ -66,6 +103,39 @@ fn filter_slices<T>(values: &[T], output_len: usize, slices: &[(usize, usize)])
66103 out. freeze ( )
67104}
68105
106+ /// Filters a buffer in-place using slice ranges to determine which values to keep.
107+ ///
108+ /// Returns the new length of the buffer.
109+ ///
110+ /// # Safety
111+ ///
112+ /// The slice ranges must be in the range of the `buffer`.
113+ #[ must_use = "The caller should set the new length of the buffer" ]
114+ unsafe fn filter_slices_in_place < T : Copy > ( buffer : & mut [ T ] , slices : & [ ( usize , usize ) ] ) -> usize {
115+ let mut write_pos = 0 ;
116+
117+ // For each range in the selection, copy all of the elements to the current write position.
118+ for & ( start, end) in slices {
119+ // Note that we could add an if statement here that checks `if read_idx != write_idx`, but
120+ // it's probably better to just avoid the branch misprediction.
121+
122+ let len = end - start;
123+
124+ // SAFETY: The safety contract enforces that all ranges are within bounds.
125+ unsafe {
126+ core:: ptr:: copy (
127+ buffer. as_ptr ( ) . add ( start) ,
128+ buffer. as_mut_ptr ( ) . add ( write_pos) ,
129+ len,
130+ )
131+ } ;
132+
133+ write_pos += len;
134+ }
135+
136+ write_pos
137+ }
138+
69139#[ cfg( test) ]
70140mod tests {
71141 use vortex_buffer:: buffer;
@@ -113,4 +183,112 @@ mod tests {
113183 let result = filter_slices ( buf. as_slice ( ) , 3 , & [ ( 0 , 2 ) , ( 4 , 5 ) ] ) ;
114184 assert_eq ! ( result, buffer![ 1u32 , 2 , 5 ] ) ;
115185 }
186+
187+ use vortex_buffer:: { BufferMut , buffer_mut} ;
188+
189+ #[ test]
190+ fn test_filter_all_true ( ) {
191+ let mut buf = buffer_mut ! [ 1u32 , 2 , 3 , 4 , 5 ] ;
192+ let mask = Mask :: new_true ( 5 ) ;
193+
194+ buf. filter ( & mask) ;
195+ assert_eq ! ( buf. as_slice( ) , & [ 1 , 2 , 3 , 4 , 5 ] ) ;
196+ }
197+
198+ #[ test]
199+ fn test_filter_all_false ( ) {
200+ let mut buf = buffer_mut ! [ 1u32 , 2 , 3 , 4 , 5 ] ;
201+ let mask = Mask :: new_false ( 5 ) ;
202+
203+ buf. filter ( & mask) ;
204+ assert ! ( buf. is_empty( ) ) ;
205+ }
206+
207+ #[ test]
208+ fn test_filter_sparse ( ) {
209+ let mut buf = buffer_mut ! [ 10u32 , 20 , 30 , 40 , 50 ] ;
210+ // Select indices 0, 2, 4 (sparse selection).
211+ let mask = Mask :: from_iter ( [ true , false , true , false , true ] ) ;
212+
213+ buf. filter ( & mask) ;
214+ assert_eq ! ( buf. as_slice( ) , & [ 10 , 30 , 50 ] ) ;
215+ }
216+
217+ #[ test]
218+ fn test_filter_dense ( ) {
219+ let mut buf = buffer_mut ! [ 1u32 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 ] ;
220+ // Dense selection (80% selected).
221+ let mask = Mask :: from_iter ( [ true , true , true , true , false , true , true , true , false , true ] ) ;
222+
223+ buf. filter ( & mask) ;
224+ assert_eq ! ( buf. as_slice( ) , & [ 1 , 2 , 3 , 4 , 6 , 7 , 8 , 10 ] ) ;
225+ }
226+
227+ #[ test]
228+ fn test_filter_single_element_kept ( ) {
229+ let mut buf = buffer_mut ! [ 1u32 , 2 , 3 , 4 , 5 ] ;
230+ let mask = Mask :: from_iter ( [ false , false , true , false , false ] ) ;
231+
232+ buf. filter ( & mask) ;
233+ assert_eq ! ( buf. as_slice( ) , & [ 3 ] ) ;
234+ }
235+
236+ #[ test]
237+ fn test_filter_first_last ( ) {
238+ let mut buf = buffer_mut ! [ 1u32 , 2 , 3 , 4 , 5 ] ;
239+ let mask = Mask :: from_iter ( [ true , false , false , false , true ] ) ;
240+
241+ buf. filter ( & mask) ;
242+ assert_eq ! ( buf. as_slice( ) , & [ 1 , 5 ] ) ;
243+ }
244+
245+ #[ test]
246+ fn test_filter_alternating ( ) {
247+ let mut buf = buffer_mut ! [ 1u32 , 2 , 3 , 4 , 5 , 6 ] ;
248+ let mask = Mask :: from_iter ( [ true , false , true , false , true , false ] ) ;
249+
250+ buf. filter ( & mask) ;
251+ assert_eq ! ( buf. as_slice( ) , & [ 1 , 3 , 5 ] ) ;
252+ }
253+
254+ #[ test]
255+ fn test_filter_empty_buffer ( ) {
256+ let mut buf: BufferMut < u32 > = BufferMut :: with_capacity ( 0 ) ;
257+ let mask = Mask :: new_false ( 0 ) ;
258+
259+ buf. filter ( & mask) ;
260+ assert ! ( buf. is_empty( ) ) ;
261+ }
262+
263+ #[ test]
264+ fn test_filter_contiguous_regions ( ) {
265+ let mut buf = buffer_mut ! [ 1u32 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 ] ;
266+ // Two contiguous regions: [0..3] and [7..10].
267+ let mask = Mask :: from_iter ( [
268+ true , true , true , false , false , false , false , true , true , true ,
269+ ] ) ;
270+
271+ buf. filter ( & mask) ;
272+ assert_eq ! ( buf. as_slice( ) , & [ 1 , 2 , 3 , 8 , 9 , 10 ] ) ;
273+ }
274+
275+ #[ test]
276+ fn test_filter_large_buffer ( ) {
277+ let mut buf: BufferMut < u32 > = BufferMut :: from_iter ( 0 ..1000 ) ;
278+ // Keep every third element.
279+ let mask = Mask :: from_iter ( ( 0 ..1000 ) . map ( |i| i % 3 == 0 ) ) ;
280+
281+ buf. filter ( & mask) ;
282+ let expected: Vec < u32 > = ( 0 ..1000 ) . filter ( |i| i % 3 == 0 ) . collect ( ) ;
283+ assert_eq ! ( buf. as_slice( ) , & expected[ ..] ) ;
284+ }
285+
286+ #[ test]
287+ #[ should_panic( expected = "Selection mask length must equal the buffer length" ) ]
288+ fn test_filter_length_mismatch ( ) {
289+ let mut buf = buffer_mut ! [ 1u32 , 2 , 3 ] ;
290+ let mask = Mask :: new_true ( 5 ) ; // Wrong length.
291+
292+ buf. filter ( & mask) ;
293+ }
116294}
0 commit comments