11// SPDX-License-Identifier: Apache-2.0
22// SPDX-FileCopyrightText: Copyright the Vortex contributors
33
4+ use std:: ops:: Range ;
45use vortex_buffer:: { Buffer , BufferMut } ;
56use vortex_mask:: { Mask , MaskValues } ;
67
@@ -96,37 +97,17 @@ fn expand_inplace<T: Copy>(buf_mut: &mut BufferMut<T>, mask_values: &MaskValues)
9697 let pseudo_default_value = buf_slice[ 0 ] ;
9798
9899 let mut element_idx = buf_len;
100+ let bit_buffer = mask_values. bit_buffer ( ) ;
99101
100102 // Iterate backwards through the mask to avoid overwriting unprocessed elements.
101- for ( mask_idx, is_valid) in mask_values
102- . bit_buffer ( )
103- . slice ( buf_len..)
104- . iter ( )
105- . rev ( )
106- . enumerate ( )
107- {
103+ iter_bits_reverse ( bit_buffer, 0 ..mask_len, |idx, is_valid| {
108104 if is_valid {
109105 element_idx -= 1 ;
110- unsafe { * buf_slice. get_unchecked_mut ( mask_idx ) = buf_slice[ element_idx] } ;
106+ unsafe { * buf_slice. get_unchecked_mut ( idx ) = buf_slice[ element_idx] } ;
111107 } else {
112- // Initialize with a pseudo-default value.
113- unsafe { * buf_slice. get_unchecked_mut ( mask_idx) = pseudo_default_value } ;
108+ unsafe { * buf_slice. get_unchecked_mut ( idx) = pseudo_default_value } ;
114109 }
115- }
116-
117- for ( mask_idx, is_valid) in mask_values
118- . bit_buffer ( )
119- . slice ( ..buf_len)
120- . iter ( )
121- . rev ( )
122- . enumerate ( )
123- {
124- if is_valid {
125- element_idx -= 1 ;
126- unsafe { * buf_slice. get_unchecked_mut ( mask_idx) = buf_slice[ element_idx] } ;
127- }
128- // For the range up to buffer length, all positions are already initialized.
129- }
110+ } ) ;
130111}
131112
132113/// Expands a slice into a new buffer at the target size, scattering elements to
@@ -151,7 +132,9 @@ fn expand_copy<T: Copy>(src: &[T], mask_values: &MaskValues) -> Buffer<T> {
151132 let pseudo_default_value = src[ 0 ] ;
152133 let mut element_idx = 0 ;
153134
154- for ( mask_idx, is_valid) in mask_values. bit_buffer ( ) . iter ( ) . enumerate ( ) {
135+ let bit_buffer = mask_values. bit_buffer ( ) ;
136+
137+ iter_bits ( bit_buffer, 0 ..mask_len, |mask_idx, is_valid| {
155138 if is_valid {
156139 unsafe {
157140 target_slice
@@ -160,22 +143,124 @@ fn expand_copy<T: Copy>(src: &[T], mask_values: &MaskValues) -> Buffer<T> {
160143 } ;
161144 element_idx += 1 ;
162145 } else {
163- // Initialize with a pseudo-default value. In case we expand
164- // into a new buffer all false positions need to be initialized.
165146 unsafe {
166147 target_slice
167148 . get_unchecked_mut ( mask_idx)
168- . write ( pseudo_default_value)
149+ . write ( pseudo_default_value) ;
169150 } ;
170151 }
171- }
152+ } ) ;
172153
173154 // SAFETY: Buffer has sufficient capacity and all elements have been initialized.
174155 unsafe { target_buf. set_len ( mask_len) } ;
175156
176157 target_buf. freeze ( )
177158}
178159
160+ /// Iterate through bits in a buffer.
161+ ///
162+ /// # Arguments
163+ ///
164+ /// * `bit_buffer` - The bit buffer to iterate through
165+ /// * `range` - Bit range to iterate through
166+ /// * `f` - Callback function taking (bit_index, is_set)
167+ ///
168+ /// # Safety
169+ ///
170+ /// The caller must ensure that the range is within valid bounds of the bit buffer.
171+ #[ inline]
172+ fn iter_bits < F > ( bit_buffer : & vortex_buffer:: BitBuffer , range : Range < usize > , mut f : F )
173+ where
174+ F : FnMut ( usize , bool ) ,
175+ {
176+ let start = range. start ;
177+ let end = range. end ;
178+
179+ assert ! ( start <= end) ;
180+ assert ! ( end <= bit_buffer. len( ) ) ;
181+
182+ let buffer_ptr = bit_buffer. inner ( ) . as_ptr ( ) ;
183+ let offset = bit_buffer. offset ( ) ;
184+
185+ let full_bytes = ( end - start) / 8 ;
186+ let remaining_bits = ( end - start) % 8 ;
187+
188+ for byte_idx in 0 ..full_bytes {
189+ let bit_offset = offset + start + byte_idx * 8 ;
190+ let byte_offset = bit_offset / 8 ;
191+ let byte = unsafe { * buffer_ptr. add ( byte_offset) } ;
192+
193+ for bit_idx in 0 ..8 {
194+ let is_set = ( byte & ( 1 << bit_idx) ) != 0 ;
195+ f ( start + byte_idx * 8 + bit_idx, is_set) ;
196+ }
197+ }
198+
199+ if remaining_bits > 0 {
200+ let bit_idx_start = start + full_bytes * 8 ;
201+ let bit_offset = offset + bit_idx_start;
202+ let byte_offset = bit_offset / 8 ;
203+ let byte = unsafe { * buffer_ptr. add ( byte_offset) } ;
204+
205+ for i in 0 ..remaining_bits {
206+ let is_set = ( byte & ( 1 << i) ) != 0 ;
207+ f ( bit_idx_start + i, is_set) ;
208+ }
209+ }
210+ }
211+
212+ /// Iterate through bits in a buffer in reverse.
213+ ///
214+ /// # Arguments
215+ ///
216+ /// * `bit_buffer` - The bit buffer to iterate through
217+ /// * `range` - Bit range to iterate through in reverse (start inclusive, end exclusive)
218+ /// * `f` - Callback function taking (bit_index, is_set)
219+ ///
220+ /// # Safety
221+ ///
222+ /// The caller must ensure that the range is within valid bounds of the bit buffer.
223+ #[ inline]
224+ fn iter_bits_reverse < F > ( bit_buffer : & vortex_buffer:: BitBuffer , range : Range < usize > , mut f : F )
225+ where
226+ F : FnMut ( usize , bool ) ,
227+ {
228+ let start = range. start ;
229+ let end = range. end ;
230+
231+ assert ! ( start <= end) ;
232+ assert ! ( end <= bit_buffer. len( ) ) ;
233+
234+ let buffer_ptr = bit_buffer. inner ( ) . as_ptr ( ) ;
235+ let offset = bit_buffer. offset ( ) ;
236+
237+ let full_bytes = ( end - start) / 8 ;
238+ let remaining_bits = ( end - start) % 8 ;
239+
240+ if remaining_bits > 0 {
241+ let bit_idx_start = start + full_bytes * 8 ;
242+ let bit_offset = offset + bit_idx_start;
243+ let byte_offset = bit_offset / 8 ;
244+ let byte = unsafe { * buffer_ptr. add ( byte_offset) } ;
245+
246+ for bit_idx in ( 0 ..remaining_bits) . rev ( ) {
247+ let is_set = ( byte & ( 1 << bit_idx) ) != 0 ;
248+ f ( bit_idx_start + bit_idx, is_set) ;
249+ }
250+ }
251+
252+ for byte_idx in ( 0 ..full_bytes) . rev ( ) {
253+ let bit_offset = offset + start + byte_idx * 8 ;
254+ let byte_offset = bit_offset / 8 ;
255+ let byte = unsafe { * buffer_ptr. add ( byte_offset) } ;
256+
257+ for bit_idx in ( 0 ..8 ) . rev ( ) {
258+ let is_set = ( byte & ( 1 << bit_idx) ) != 0 ;
259+ f ( start + byte_idx * 8 + bit_idx, is_set) ;
260+ }
261+ }
262+ }
263+
179264#[ cfg( test) ]
180265mod tests {
181266 use vortex_buffer:: { buffer, buffer_mut} ;
0 commit comments