11// SPDX-License-Identifier: Apache-2.0
22// SPDX-FileCopyrightText: Copyright the Vortex contributors
33
4- use vortex_buffer:: Buffer ;
5- use vortex_mask:: { Mask , MaskValues } ;
4+ use vortex_buffer:: { Buffer , BufferMut } ;
5+ use vortex_mask:: Mask ;
66
77use crate :: expand:: Expand ;
88
@@ -19,7 +19,17 @@ impl<T: Copy> Expand for Buffer<T> {
1919 match mask {
2020 Mask :: AllTrue ( _) => self ,
2121 Mask :: AllFalse ( _) => Buffer :: empty ( ) ,
22- Mask :: Values ( mask_values) => expand_indices ( self , mask_values) ,
22+ Mask :: Values ( _) => {
23+ // Try to get exclusive access to expand in-place.
24+ match self . try_into_mut ( ) {
25+ Ok ( mut buf_mut) => {
26+ ( & mut buf_mut) . expand ( mask) ;
27+ buf_mut. freeze ( )
28+ }
29+ // Otherwise, expand into a new buffer at the target size.
30+ Err ( buffer) => expand_into_new_buffer ( buffer. as_slice ( ) , mask) ,
31+ }
32+ }
2333 }
2434 }
2535}
@@ -37,53 +47,72 @@ impl<T: Copy> Expand for &Buffer<T> {
3747 match mask {
3848 Mask :: AllTrue ( _) => self . clone ( ) ,
3949 Mask :: AllFalse ( _) => Buffer :: empty ( ) ,
40- Mask :: Values ( mask_values) => expand_indices ( self . clone ( ) , mask_values) ,
50+ // Expand into new buffer unconditionally as `try_into_mut` can never succeed on `&Buffer`.
51+ Mask :: Values ( _) => expand_into_new_buffer ( self . as_slice ( ) , mask) ,
4152 }
4253 }
4354}
4455
45- /// Expands a buffer by placing its elements at positions marked as `true` in the mask.
46- ///
47- /// # Arguments
48- ///
49- /// * `buf` - The buffer containing elements to scatter
50- /// * `mask_values` - The mask indicating where elements should be placed
51- ///
52- /// # Panics
53- ///
54- /// Panics if the number of `true` values in the mask does not equal the buffer length.
55- fn expand_indices < T : Copy > ( buf : Buffer < T > , mask_values : & MaskValues ) -> Buffer < T > {
56- let buf_len = buf. len ( ) ;
56+ impl < T : Copy > Expand for & mut BufferMut < T > {
57+ type Output = ( ) ;
58+
59+ fn expand ( self , mask : & Mask ) {
60+ assert_eq ! (
61+ mask. true_count( ) ,
62+ self . len( ) ,
63+ "Expand mask true count must equal the buffer length"
64+ ) ;
5765
58- assert_eq ! (
59- mask_values. true_count( ) ,
60- buf_len,
61- "Mask true count must equal buffer length"
62- ) ;
66+ match mask {
67+ Mask :: AllTrue ( _) => { }
68+ Mask :: AllFalse ( _) => self . clear ( ) ,
69+ Mask :: Values ( mask_values) => {
70+ let buf_len = self . len ( ) ;
71+ let mask_len = mask_values. len ( ) ;
6372
64- if buf . is_empty ( ) {
65- return Buffer :: empty ( ) ;
66- }
73+ if buf_len == 0 {
74+ return ;
75+ }
6776
68- let mut buf_mut = buf. into_mut ( ) ;
69- let mask_len = mask_values. len ( ) ;
70- buf_mut. reserve ( mask_len - buf_len) ;
77+ // Expand to the new buffer size which equals the length of the mask.
78+ self . reserve ( mask_len - buf_len) ;
7179
72- // Expand to the new buffer size which is equals the length of the mask.
73- unsafe {
74- buf_mut. set_len ( mask_len) ;
80+ // SAFETY: We just reserved enough space above.
81+ unsafe {
82+ self . set_len ( mask_len) ;
83+ }
84+
85+ let buf_slice = self . as_mut_slice ( ) ;
86+ scatter_into_slice ( buf_slice, buf_len, mask_values) ;
87+ }
88+ }
7589 }
90+ }
7691
77- let buf_slice = buf_mut. as_mut_slice ( ) ;
78- let mut element_idx = buf_len;
92+ /// Scatters elements from a mutable slice into itself at positions marked true in the mask.
93+ /// Used for in-place expansion where source and destination are the same buffer.
94+ ///
95+ /// # Arguments
96+ ///
97+ /// * `buf_slice` - The buffer slice to scatter into (already expanded to mask length)
98+ /// * `src_len` - The original length of the buffer before expansion
99+ /// * `mask_values` - The mask indicating where elements should be placed
100+ fn scatter_into_slice < T : Copy > (
101+ buf_slice : & mut [ T ] ,
102+ src_len : usize ,
103+ mask_values : & vortex_mask:: MaskValues ,
104+ ) {
105+ let mask_len = buf_slice. len ( ) ;
79106
80107 // Pick the first value as a default value. The buffer is not empty, and we
81108 // know that the first value is guaranteed to be initialized. By doing this
82- // T does does not require to implement `Default`.
109+ // T does not require to implement `Default`.
83110 let pseudo_default_value = buf_slice[ 0 ] ;
84111
112+ let mut element_idx = src_len;
113+
85114 // Iterate backwards through the mask to avoid overwriting unprocessed elements.
86- for mask_idx in ( buf_len ..mask_len) . rev ( ) {
115+ for mask_idx in ( src_len ..mask_len) . rev ( ) {
87116 if mask_values. value ( mask_idx) {
88117 element_idx -= 1 ;
89118 buf_slice[ mask_idx] = buf_slice[ element_idx] ;
@@ -93,20 +122,97 @@ fn expand_indices<T: Copy>(buf: Buffer<T>, mask_values: &MaskValues) -> Buffer<T
93122 }
94123 }
95124
96- for mask_idx in ( 0 ..buf_len ) . rev ( ) {
125+ for mask_idx in ( 0 ..src_len ) . rev ( ) {
97126 if mask_values. value ( mask_idx) {
98127 element_idx -= 1 ;
99128 buf_slice[ mask_idx] = buf_slice[ element_idx] ;
100129 }
101130 // For the range up to buffer length, all positions are already initialized.
102131 }
132+ }
103133
104- buf_mut. freeze ( )
134+ /// Scatters elements from a source buffer into a destination slice at positions marked true
135+ /// in the mask.
136+ ///
137+ /// # Arguments
138+ ///
139+ /// * `dest` - The destination buffer slice (already expanded to mask length)
140+ /// * `src` - The source elements to scatter
141+ /// * `src_len` - The length of the source buffer
142+ /// * `mask_values` - The mask indicating where elements should be placed
143+ fn scatter_into_slice_from < T : Copy > (
144+ dest : & mut [ T ] ,
145+ src : & [ T ] ,
146+ src_len : usize ,
147+ mask_values : & vortex_mask:: MaskValues ,
148+ ) {
149+ let mask_len = dest. len ( ) ;
150+
151+ // Pick the first value as a default value. The source buffer is not empty.
152+ let pseudo_default_value = src[ 0 ] ;
153+
154+ let mut element_idx = src_len;
155+
156+ // Iterate backwards through the mask to avoid any issues.
157+ for mask_idx in ( src_len..mask_len) . rev ( ) {
158+ if mask_values. value ( mask_idx) {
159+ element_idx -= 1 ;
160+ dest[ mask_idx] = src[ element_idx] ;
161+ } else {
162+ // Initialize with a pseudo-default value.
163+ dest[ mask_idx] = pseudo_default_value;
164+ }
165+ }
166+
167+ for mask_idx in ( 0 ..src_len) . rev ( ) {
168+ if mask_values. value ( mask_idx) {
169+ element_idx -= 1 ;
170+ dest[ mask_idx] = src[ element_idx] ;
171+ }
172+ }
173+ }
174+
175+ /// Expands a slice into a new buffer at the target size, scattering elements to
176+ /// true positions in the mask.
177+ ///
178+ /// # Arguments
179+ ///
180+ /// * `src` - The source slice containing elements to scatter
181+ /// * `mask` - The mask indicating where elements should be placed
182+ ///
183+ /// # Returns
184+ ///
185+ /// A new `Buffer<T>` with length equal to `mask.len()`, with elements from `src` scattered
186+ /// to positions marked true in the mask. Positions marked false can have arbitrary values.
187+ fn expand_into_new_buffer < T : Copy > ( src : & [ T ] , mask : & Mask ) -> Buffer < T > {
188+ let src_len = src. len ( ) ;
189+ let mask_len = mask. len ( ) ;
190+
191+ match mask {
192+ Mask :: AllTrue ( _) => Buffer :: from_trusted_len_iter ( src. iter ( ) . copied ( ) ) ,
193+ Mask :: AllFalse ( _) => Buffer :: empty ( ) ,
194+ Mask :: Values ( mask_values) => {
195+ if src_len == 0 {
196+ return Buffer :: empty ( ) ;
197+ }
198+
199+ let mut buf_mut = BufferMut :: < T > :: with_capacity ( mask_len) ;
200+
201+ // SAFETY: We're preallocating the full target capacity.
202+ unsafe {
203+ buf_mut. set_len ( mask_len) ;
204+ }
205+
206+ let buf_slice = buf_mut. as_mut_slice ( ) ;
207+ scatter_into_slice_from ( buf_slice, src, src_len, mask_values) ;
208+ buf_mut. freeze ( )
209+ }
210+ }
105211}
106212
107213#[ cfg( test) ]
108214mod tests {
109- use vortex_buffer:: buffer;
215+ use vortex_buffer:: { buffer, buffer_mut } ;
110216 use vortex_mask:: Mask ;
111217
112218 use super :: * ;
@@ -172,4 +278,101 @@ mod tests {
172278 let mask = Mask :: from_iter ( [ true , true , true , false ] ) ;
173279 buf. expand ( & mask) ;
174280 }
281+
282+ // Tests for &Buffer<T> impl
283+ #[ test]
284+ fn test_expand_ref_scattered ( ) {
285+ let buf = buffer ! [ 100u32 , 200 , 300 ] ;
286+ let mask = Mask :: from_iter ( [ true , false , true , false , true ] ) ;
287+
288+ let result = ( & buf) . expand ( & mask) ;
289+ assert_eq ! ( result. len( ) , 5 ) ;
290+ assert_eq ! ( result. as_slice( ) [ 0 ] , 100 ) ;
291+ assert_eq ! ( result. as_slice( ) [ 2 ] , 200 ) ;
292+ assert_eq ! ( result. as_slice( ) [ 4 ] , 300 ) ;
293+ }
294+
295+ #[ test]
296+ fn test_expand_ref_all_true ( ) {
297+ let buf = buffer ! [ 10u32 , 20 , 30 ] ;
298+ let mask = Mask :: new_true ( 3 ) ;
299+
300+ let result = ( & buf) . expand ( & mask) ;
301+ assert_eq ! ( result, buffer![ 10u32 , 20 , 30 ] ) ;
302+ }
303+
304+ // Tests for &mut BufferMut<T> impl
305+ #[ test]
306+ fn test_expand_mut_scattered ( ) {
307+ let mut buf = buffer_mut ! [ 100u32 , 200 , 300 ] ;
308+ let mask = Mask :: from_iter ( [ true , false , true , false , true ] ) ;
309+
310+ ( & mut buf) . expand ( & mask) ;
311+ assert_eq ! ( buf. len( ) , 5 ) ;
312+ assert_eq ! ( buf. as_slice( ) [ 0 ] , 100 ) ;
313+ assert_eq ! ( buf. as_slice( ) [ 2 ] , 200 ) ;
314+ assert_eq ! ( buf. as_slice( ) [ 4 ] , 300 ) ;
315+ }
316+
317+ #[ test]
318+ fn test_expand_mut_all_true ( ) {
319+ let mut buf = buffer_mut ! [ 10u32 , 20 , 30 ] ;
320+ let mask = Mask :: new_true ( 3 ) ;
321+
322+ ( & mut buf) . expand ( & mask) ;
323+ assert_eq ! ( buf. as_slice( ) , & [ 10 , 20 , 30 ] ) ;
324+ }
325+
326+ #[ test]
327+ fn test_expand_mut_all_false ( ) {
328+ let mut buf: BufferMut < u32 > = BufferMut :: with_capacity ( 0 ) ;
329+ let mask = Mask :: new_false ( 0 ) ;
330+
331+ ( & mut buf) . expand ( & mask) ;
332+ assert ! ( buf. is_empty( ) ) ;
333+ }
334+
335+ #[ test]
336+ fn test_expand_mut_contiguous_start ( ) {
337+ let mut buf = buffer_mut ! [ 10u32 , 20 , 30 , 40 ] ;
338+ let mask = Mask :: from_iter ( [ true , true , true , true , false , false , false ] ) ;
339+
340+ ( & mut buf) . expand ( & mask) ;
341+ assert_eq ! ( buf. len( ) , 7 ) ;
342+ assert_eq ! ( buf. as_slice( ) [ 0 ..4 ] , [ 10u32 , 20 , 30 , 40 ] ) ;
343+ }
344+
345+ #[ test]
346+ fn test_expand_mut_contiguous_end ( ) {
347+ let mut buf = buffer_mut ! [ 100u32 , 200 , 300 ] ;
348+ let mask = Mask :: from_iter ( [ false , false , false , false , true , true , true ] ) ;
349+
350+ ( & mut buf) . expand ( & mask) ;
351+ assert_eq ! ( buf. len( ) , 7 ) ;
352+ assert_eq ! ( buf. as_slice( ) [ 4 ..7 ] , [ 100u32 , 200 , 300 ] ) ;
353+ }
354+
355+ #[ test]
356+ fn test_expand_mut_dense ( ) {
357+ let mut buf = buffer_mut ! [ 1u32 , 2 , 3 , 4 , 5 ] ;
358+ let mask = Mask :: from_iter ( [
359+ true , false , true , true , false , true , true , false , false , false ,
360+ ] ) ;
361+
362+ ( & mut buf) . expand ( & mask) ;
363+ assert_eq ! ( buf. len( ) , 10 ) ;
364+ assert_eq ! ( buf. as_slice( ) [ 0 ] , 1 ) ;
365+ assert_eq ! ( buf. as_slice( ) [ 2 ] , 2 ) ;
366+ assert_eq ! ( buf. as_slice( ) [ 3 ] , 3 ) ;
367+ assert_eq ! ( buf. as_slice( ) [ 5 ] , 4 ) ;
368+ assert_eq ! ( buf. as_slice( ) [ 6 ] , 5 ) ;
369+ }
370+
371+ #[ test]
372+ #[ should_panic( expected = "Expand mask true count must equal the buffer length" ) ]
373+ fn test_expand_mut_mismatch_true_count ( ) {
374+ let mut buf = buffer_mut ! [ 10u32 , 20 ] ;
375+ let mask = Mask :: from_iter ( [ true , true , true , false ] ) ;
376+ ( & mut buf) . expand ( & mask) ;
377+ }
175378}
0 commit comments