Skip to content

Commit 5c3322a

Browse files
committed
address comments
Signed-off-by: Alexander Droste <[email protected]>
1 parent a885487 commit 5c3322a

File tree

1 file changed

+40
-98
lines changed

1 file changed

+40
-98
lines changed

vortex-compute/src/expand/buffer.rs

Lines changed: 40 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
use vortex_buffer::{Buffer, BufferMut};
5-
use vortex_mask::Mask;
5+
use vortex_mask::{Mask, MaskValues};
66

77
use crate::expand::Expand;
88

@@ -19,15 +19,15 @@ impl<T: Copy> Expand for Buffer<T> {
1919
match mask {
2020
Mask::AllTrue(_) => self,
2121
Mask::AllFalse(_) => Buffer::empty(),
22-
Mask::Values(_) => {
22+
Mask::Values(mask_values) => {
2323
// Try to get exclusive access to expand in-place.
2424
match self.try_into_mut() {
2525
Ok(mut buf_mut) => {
2626
(&mut buf_mut).expand(mask);
2727
buf_mut.freeze()
2828
}
2929
// Otherwise, expand into a new buffer.
30-
Err(buffer) => expand_into_new_buffer(buffer.as_slice(), mask),
30+
Err(buffer) => expand_copy(buffer.as_slice(), mask_values),
3131
}
3232
}
3333
}
@@ -46,9 +46,9 @@ impl<T: Copy> Expand for &Buffer<T> {
4646

4747
match mask {
4848
Mask::AllTrue(_) => self.clone(),
49-
Mask::AllFalse(_) => Buffer::empty(),
49+
Mask::AllFalse(_) => self.clone(),
5050
// Expand into new buffer unconditionally as `try_into_mut` can never succeed on `&Buffer`.
51-
Mask::Values(_) => expand_into_new_buffer(self.as_slice(), mask),
51+
Mask::Values(mask_values) => expand_copy(self.as_slice(), mask_values),
5252
}
5353
}
5454
}
@@ -65,24 +65,18 @@ impl<T: Copy> Expand for &mut BufferMut<T> {
6565

6666
match mask {
6767
Mask::AllTrue(_) => {}
68-
Mask::AllFalse(_) => self.clear(),
68+
Mask::AllFalse(_) => {}
6969
Mask::Values(mask_values) => {
7070
let buf_len = self.len();
7171
let mask_len = mask_values.len();
7272

73-
if buf_len == 0 {
74-
return;
75-
}
76-
7773
self.reserve(mask_len - buf_len);
7874

79-
// SAFETY: We just reserved enough space above.
80-
unsafe {
81-
self.set_len(mask_len);
82-
}
75+
// SAFETY: Sufficient capacity has been reserved.
76+
unsafe { self.set_len(mask_len) };
8377

8478
let buf_slice = self.as_mut_slice();
85-
expand_into_slice_inplace(buf_slice, buf_len, mask_values);
79+
expand_inplace(buf_slice, buf_len, mask_values);
8680
}
8781
}
8882
}
@@ -96,11 +90,7 @@ impl<T: Copy> Expand for &mut BufferMut<T> {
9690
/// * `buf_slice` - The buffer slice to scatter into (already expanded to mask length)
9791
/// * `src_len` - The original length of the buffer before expansion
9892
/// * `mask_values` - The mask indicating where elements should be placed
99-
fn expand_into_slice_inplace<T: Copy>(
100-
buf_slice: &mut [T],
101-
src_len: usize,
102-
mask_values: &vortex_mask::MaskValues,
103-
) {
93+
fn expand_inplace<T: Copy>(buf_slice: &mut [T], src_len: usize, mask_values: &MaskValues) {
10494
let mask_len = buf_slice.len();
10595

10696
// Pick the first value as a default value. The buffer is not empty, and we
@@ -112,6 +102,7 @@ fn expand_into_slice_inplace<T: Copy>(
112102

113103
// Iterate backwards through the mask to avoid overwriting unprocessed elements.
114104
for mask_idx in (src_len..mask_len).rev() {
105+
// NOTE(0ax1): .value is slow => optimize
115106
if mask_values.value(mask_idx) {
116107
element_idx -= 1;
117108
buf_slice[mask_idx] = buf_slice[element_idx];
@@ -130,20 +121,27 @@ fn expand_into_slice_inplace<T: Copy>(
130121
}
131122
}
132123

133-
/// Scatters elements from a source buffer into a destination slice at positions marked true
134-
/// in the mask.
124+
/// Expands a slice into a new buffer at the target size, scattering elements to
125+
/// true positions in the mask.
135126
///
136127
/// # Arguments
137128
///
138-
/// * `src` - The source elements to scatter
139-
/// * `dest` - The destination buffer slice (already expanded to mask length)
129+
/// * `src` - The source slice containing elements to scatter
140130
/// * `mask_values` - The mask indicating where elements should be placed
141-
fn scatter_into_slice_from<T: Copy>(
142-
src: &[T],
143-
dest: &mut [T],
144-
mask_values: &vortex_mask::MaskValues,
145-
) {
146-
let mask_len = dest.len();
131+
///
132+
/// # Returns
133+
///
134+
/// A new `Buffer<T>` with length equal to `mask.len`, with elements from `src` scattered
135+
/// to positions marked true in the mask. Positions marked false can have arbitrary values.
136+
fn expand_copy<T: Copy>(src: &[T], mask_values: &MaskValues) -> Buffer<T> {
137+
let mask_len = mask_values.len();
138+
139+
let mut target_buf = BufferMut::<T>::with_capacity(mask_len);
140+
141+
// SAFETY: Preallocate full target capacity.
142+
unsafe { target_buf.set_len(mask_len) };
143+
144+
let buf_slice = target_buf.as_mut_slice();
147145

148146
// Pick the first value as a default value. The source buffer is not empty.
149147
let pseudo_default_value = src[0];
@@ -152,60 +150,19 @@ fn scatter_into_slice_from<T: Copy>(
152150
let mut element_idx = src_len;
153151

154152
// Iterate backwards through the mask to avoid any issues.
155-
for mask_idx in (src_len..mask_len).rev() {
153+
for mask_idx in (0..mask_len).rev() {
154+
// NOTE(0ax1): .value is slow => optimize
156155
if mask_values.value(mask_idx) {
157156
element_idx -= 1;
158-
dest[mask_idx] = src[element_idx];
157+
buf_slice[mask_idx] = src[element_idx];
159158
} else {
160-
// Initialize with a pseudo-default value.
161-
dest[mask_idx] = pseudo_default_value;
162-
}
163-
}
164-
165-
for mask_idx in (0..src_len).rev() {
166-
if mask_values.value(mask_idx) {
167-
element_idx -= 1;
168-
dest[mask_idx] = src[element_idx];
159+
// Initialize with a pseudo-default value. In case we expand into a
160+
// new buffer all false positions need to be initialized.
161+
buf_slice[mask_idx] = pseudo_default_value;
169162
}
170163
}
171-
}
172164

173-
/// Expands a slice into a new buffer at the target size, scattering elements to
174-
/// true positions in the mask.
175-
///
176-
/// # Arguments
177-
///
178-
/// * `src` - The source slice containing elements to scatter
179-
/// * `mask` - The mask indicating where elements should be placed
180-
///
181-
/// # Returns
182-
///
183-
/// A new `Buffer<T>` with length equal to `mask.len()`, with elements from `src` scattered
184-
/// to positions marked true in the mask. Positions marked false can have arbitrary values.
185-
fn expand_into_new_buffer<T: Copy>(src: &[T], mask: &Mask) -> Buffer<T> {
186-
let src_len = src.len();
187-
let mask_len = mask.len();
188-
189-
match mask {
190-
Mask::AllTrue(_) => Buffer::from_trusted_len_iter(src.iter().copied()),
191-
Mask::AllFalse(_) => Buffer::empty(),
192-
Mask::Values(mask_values) => {
193-
if src_len == 0 {
194-
return Buffer::empty();
195-
}
196-
197-
let mut buf_mut = BufferMut::<T>::with_capacity(mask_len);
198-
199-
// SAFETY: We're preallocating the full target capacity.
200-
unsafe {
201-
buf_mut.set_len(mask_len);
202-
}
203-
204-
let buf_slice = buf_mut.as_mut_slice();
205-
scatter_into_slice_from(src, buf_slice, mask_values);
206-
buf_mut.freeze()
207-
}
208-
}
165+
target_buf.freeze()
209166
}
210167

211168
#[cfg(test)]
@@ -383,31 +340,16 @@ mod tests {
383340
true, false, true, false, true, false, true, false, true, false,
384341
]);
385342

386-
let result = expand_into_new_buffer(&src, &mask);
343+
let Mask::Values(mask_values) = mask else {
344+
panic!("Expected Mask::Values");
345+
};
346+
347+
let result = expand_copy(&src, &mask_values);
387348
assert_eq!(result.len(), 10);
388349
assert_eq!(result.as_slice()[0], 10);
389350
assert_eq!(result.as_slice()[2], 20);
390351
assert_eq!(result.as_slice()[4], 30);
391352
assert_eq!(result.as_slice()[6], 40);
392353
assert_eq!(result.as_slice()[8], 50);
393354
}
394-
395-
#[test]
396-
fn test_scatter_into_slice_from() {
397-
let src = [1u32, 2, 3, 4, 5];
398-
let mut dest = vec![0u32; 8];
399-
let mask = Mask::from_iter([true, true, false, true, true, false, true, false]);
400-
401-
let mask_values = match &mask {
402-
Mask::Values(mv) => mv,
403-
_ => panic!("Expected Mask::Values"),
404-
};
405-
406-
scatter_into_slice_from(&src, &mut dest, mask_values);
407-
assert_eq!(dest[0], 1);
408-
assert_eq!(dest[1], 2);
409-
assert_eq!(dest[3], 3);
410-
assert_eq!(dest[4], 4);
411-
assert_eq!(dest[6], 5);
412-
}
413355
}

0 commit comments

Comments
 (0)