Skip to content

Commit d3a1932

Browse files
committed
perf: bit iter opt
Signed-off-by: Alexander Droste <[email protected]>
1 parent ed532f6 commit d3a1932

File tree

2 files changed

+117
-32
lines changed

2 files changed

+117
-32
lines changed

vortex-compute/benches/expand_buffer.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ fn expand_inplace<T: Copy + Default + From<u8> + Send + 'static>(
8989
(buffer, mask)
9090
})
9191
.bench_values(|(mut buffer, mask)| {
92-
let result = buffer.expand(&mask);
93-
divan::black_box(result);
92+
(&mut buffer).expand(&mask);
93+
divan::black_box(buffer);
9494
});
9595
}

vortex-compute/src/expand/buffer.rs

Lines changed: 115 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
use std::ops::Range;
45
use vortex_buffer::{Buffer, BufferMut};
56
use vortex_mask::{Mask, MaskValues};
67

@@ -96,37 +97,17 @@ fn expand_inplace<T: Copy>(buf_mut: &mut BufferMut<T>, mask_values: &MaskValues)
9697
let pseudo_default_value = buf_slice[0];
9798

9899
let mut element_idx = buf_len;
100+
let bit_buffer = mask_values.bit_buffer();
99101

100102
// Iterate backwards through the mask to avoid overwriting unprocessed elements.
101-
for (mask_idx, is_valid) in mask_values
102-
.bit_buffer()
103-
.slice(buf_len..)
104-
.iter()
105-
.rev()
106-
.enumerate()
107-
{
103+
iter_bits_reverse(bit_buffer, 0..mask_len, |idx, is_valid| {
108104
if is_valid {
109105
element_idx -= 1;
110-
unsafe { *buf_slice.get_unchecked_mut(mask_idx) = buf_slice[element_idx] };
106+
unsafe { *buf_slice.get_unchecked_mut(idx) = buf_slice[element_idx] };
111107
} else {
112-
// Initialize with a pseudo-default value.
113-
unsafe { *buf_slice.get_unchecked_mut(mask_idx) = pseudo_default_value };
108+
unsafe { *buf_slice.get_unchecked_mut(idx) = pseudo_default_value };
114109
}
115-
}
116-
117-
for (mask_idx, is_valid) in mask_values
118-
.bit_buffer()
119-
.slice(..buf_len)
120-
.iter()
121-
.rev()
122-
.enumerate()
123-
{
124-
if is_valid {
125-
element_idx -= 1;
126-
unsafe { *buf_slice.get_unchecked_mut(mask_idx) = buf_slice[element_idx] };
127-
}
128-
// For the range up to buffer length, all positions are already initialized.
129-
}
110+
});
130111
}
131112

132113
/// Expands a slice into a new buffer at the target size, scattering elements to
@@ -151,7 +132,9 @@ fn expand_copy<T: Copy>(src: &[T], mask_values: &MaskValues) -> Buffer<T> {
151132
let pseudo_default_value = src[0];
152133
let mut element_idx = 0;
153134

154-
for (mask_idx, is_valid) in mask_values.bit_buffer().iter().enumerate() {
135+
let bit_buffer = mask_values.bit_buffer();
136+
137+
iter_bits(bit_buffer, 0..mask_len, |mask_idx, is_valid| {
155138
if is_valid {
156139
unsafe {
157140
target_slice
@@ -160,22 +143,124 @@ fn expand_copy<T: Copy>(src: &[T], mask_values: &MaskValues) -> Buffer<T> {
160143
};
161144
element_idx += 1;
162145
} else {
163-
// Initialize with a pseudo-default value. In case we expand
164-
// into a new buffer all false positions need to be initialized.
165146
unsafe {
166147
target_slice
167148
.get_unchecked_mut(mask_idx)
168-
.write(pseudo_default_value)
149+
.write(pseudo_default_value);
169150
};
170151
}
171-
}
152+
});
172153

173154
// SAFETY: Buffer has sufficient capacity and all elements have been initialized.
174155
unsafe { target_buf.set_len(mask_len) };
175156

176157
target_buf.freeze()
177158
}
178159

160+
/// Iterate through bits in a buffer.
161+
///
162+
/// # Arguments
163+
///
164+
/// * `bit_buffer` - The bit buffer to iterate through
165+
/// * `range` - Bit range to iterate through
166+
/// * `f` - Callback function taking (bit_index, is_set)
167+
///
168+
/// # Safety
169+
///
170+
/// The caller must ensure that the range is within valid bounds of the bit buffer.
171+
#[inline]
172+
fn iter_bits<F>(bit_buffer: &vortex_buffer::BitBuffer, range: Range<usize>, mut f: F)
173+
where
174+
F: FnMut(usize, bool),
175+
{
176+
let start = range.start;
177+
let end = range.end;
178+
179+
assert!(start <= end);
180+
assert!(end <= bit_buffer.len());
181+
182+
let buffer_ptr = bit_buffer.inner().as_ptr();
183+
let offset = bit_buffer.offset();
184+
185+
let full_bytes = (end - start) / 8;
186+
let remaining_bits = (end - start) % 8;
187+
188+
for byte_idx in 0..full_bytes {
189+
let bit_offset = offset + start + byte_idx * 8;
190+
let byte_offset = bit_offset / 8;
191+
let byte = unsafe { *buffer_ptr.add(byte_offset) };
192+
193+
for bit_idx in 0..8 {
194+
let is_set = (byte & (1 << bit_idx)) != 0;
195+
f(start + byte_idx * 8 + bit_idx, is_set);
196+
}
197+
}
198+
199+
if remaining_bits > 0 {
200+
let bit_idx_start = start + full_bytes * 8;
201+
let bit_offset = offset + bit_idx_start;
202+
let byte_offset = bit_offset / 8;
203+
let byte = unsafe { *buffer_ptr.add(byte_offset) };
204+
205+
for i in 0..remaining_bits {
206+
let is_set = (byte & (1 << i)) != 0;
207+
f(bit_idx_start + i, is_set);
208+
}
209+
}
210+
}
211+
212+
/// Iterate through bits in a buffer in reverse.
213+
///
214+
/// # Arguments
215+
///
216+
/// * `bit_buffer` - The bit buffer to iterate through
217+
/// * `range` - Bit range to iterate through in reverse (start inclusive, end exclusive)
218+
/// * `f` - Callback function taking (bit_index, is_set)
219+
///
220+
/// # Safety
221+
///
222+
/// The caller must ensure that the range is within valid bounds of the bit buffer.
223+
#[inline]
224+
fn iter_bits_reverse<F>(bit_buffer: &vortex_buffer::BitBuffer, range: Range<usize>, mut f: F)
225+
where
226+
F: FnMut(usize, bool),
227+
{
228+
let start = range.start;
229+
let end = range.end;
230+
231+
assert!(start <= end);
232+
assert!(end <= bit_buffer.len());
233+
234+
let buffer_ptr = bit_buffer.inner().as_ptr();
235+
let offset = bit_buffer.offset();
236+
237+
let full_bytes = (end - start) / 8;
238+
let remaining_bits = (end - start) % 8;
239+
240+
if remaining_bits > 0 {
241+
let bit_idx_start = start + full_bytes * 8;
242+
let bit_offset = offset + bit_idx_start;
243+
let byte_offset = bit_offset / 8;
244+
let byte = unsafe { *buffer_ptr.add(byte_offset) };
245+
246+
for bit_idx in (0..remaining_bits).rev() {
247+
let is_set = (byte & (1 << bit_idx)) != 0;
248+
f(bit_idx_start + bit_idx, is_set);
249+
}
250+
}
251+
252+
for byte_idx in (0..full_bytes).rev() {
253+
let bit_offset = offset + start + byte_idx * 8;
254+
let byte_offset = bit_offset / 8;
255+
let byte = unsafe { *buffer_ptr.add(byte_offset) };
256+
257+
for bit_idx in (0..8).rev() {
258+
let is_set = (byte & (1 << bit_idx)) != 0;
259+
f(start + byte_idx * 8 + bit_idx, is_set);
260+
}
261+
}
262+
}
263+
179264
#[cfg(test)]
180265
mod tests {
181266
use vortex_buffer::{buffer, buffer_mut};

0 commit comments

Comments
 (0)