Skip to content

Commit 1ccf92b

Browse files
committed
optimize
Signed-off-by: Alexander Droste <[email protected]>
1 parent 570b252 commit 1ccf92b

File tree

2 files changed

+244
-38
lines changed

2 files changed

+244
-38
lines changed

vortex-compute/src/expand/buffer.rs

Lines changed: 240 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4-
use vortex_buffer::Buffer;
5-
use vortex_mask::{Mask, MaskValues};
4+
use vortex_buffer::{Buffer, BufferMut};
5+
use vortex_mask::Mask;
66

77
use crate::expand::Expand;
88

@@ -19,7 +19,17 @@ impl<T: Copy> Expand for Buffer<T> {
1919
match mask {
2020
Mask::AllTrue(_) => self,
2121
Mask::AllFalse(_) => Buffer::empty(),
22-
Mask::Values(mask_values) => expand_indices(self, mask_values),
22+
Mask::Values(_) => {
23+
// Try to get exclusive access to expand in-place.
24+
match self.try_into_mut() {
25+
Ok(mut buf_mut) => {
26+
(&mut buf_mut).expand(mask);
27+
buf_mut.freeze()
28+
}
29+
// Otherwise, expand into a new buffer at the target size.
30+
Err(buffer) => expand_into_new_buffer(buffer.as_slice(), mask),
31+
}
32+
}
2333
}
2434
}
2535
}
@@ -37,53 +47,72 @@ impl<T: Copy> Expand for &Buffer<T> {
3747
match mask {
3848
Mask::AllTrue(_) => self.clone(),
3949
Mask::AllFalse(_) => Buffer::empty(),
40-
Mask::Values(mask_values) => expand_indices(self.clone(), mask_values),
50+
// Expand into new buffer unconditionally as `try_into_mut` can never succeed on `&Buffer`.
51+
Mask::Values(_) => expand_into_new_buffer(self.as_slice(), mask),
4152
}
4253
}
4354
}
4455

45-
/// Expands a buffer by placing its elements at positions marked as `true` in the mask.
46-
///
47-
/// # Arguments
48-
///
49-
/// * `buf` - The buffer containing elements to scatter
50-
/// * `mask_values` - The mask indicating where elements should be placed
51-
///
52-
/// # Panics
53-
///
54-
/// Panics if the number of `true` values in the mask does not equal the buffer length.
55-
fn expand_indices<T: Copy>(buf: Buffer<T>, mask_values: &MaskValues) -> Buffer<T> {
56-
let buf_len = buf.len();
56+
impl<T: Copy> Expand for &mut BufferMut<T> {
57+
type Output = ();
58+
59+
fn expand(self, mask: &Mask) {
60+
assert_eq!(
61+
mask.true_count(),
62+
self.len(),
63+
"Expand mask true count must equal the buffer length"
64+
);
5765

58-
assert_eq!(
59-
mask_values.true_count(),
60-
buf_len,
61-
"Mask true count must equal buffer length"
62-
);
66+
match mask {
67+
Mask::AllTrue(_) => {}
68+
Mask::AllFalse(_) => self.clear(),
69+
Mask::Values(mask_values) => {
70+
let buf_len = self.len();
71+
let mask_len = mask_values.len();
6372

64-
if buf.is_empty() {
65-
return Buffer::empty();
66-
}
73+
if buf_len == 0 {
74+
return;
75+
}
6776

68-
let mut buf_mut = buf.into_mut();
69-
let mask_len = mask_values.len();
70-
buf_mut.reserve(mask_len - buf_len);
77+
// Expand to the new buffer size which equals the length of the mask.
78+
self.reserve(mask_len - buf_len);
7179

72-
// Expand to the new buffer size which is equals the length of the mask.
73-
unsafe {
74-
buf_mut.set_len(mask_len);
80+
// SAFETY: We just reserved enough space above.
81+
unsafe {
82+
self.set_len(mask_len);
83+
}
84+
85+
let buf_slice = self.as_mut_slice();
86+
scatter_into_slice(buf_slice, buf_len, mask_values);
87+
}
88+
}
7589
}
90+
}
7691

77-
let buf_slice = buf_mut.as_mut_slice();
78-
let mut element_idx = buf_len;
92+
/// Scatters elements from a mutable slice into itself at positions marked true in the mask.
93+
/// Used for in-place expansion where source and destination are the same buffer.
94+
///
95+
/// # Arguments
96+
///
97+
/// * `buf_slice` - The buffer slice to scatter into (already expanded to mask length)
98+
/// * `src_len` - The original length of the buffer before expansion
99+
/// * `mask_values` - The mask indicating where elements should be placed
100+
fn scatter_into_slice<T: Copy>(
101+
buf_slice: &mut [T],
102+
src_len: usize,
103+
mask_values: &vortex_mask::MaskValues,
104+
) {
105+
let mask_len = buf_slice.len();
79106

80107
// Pick the first value as a default value. The buffer is not empty, and we
81108
// know that the first value is guaranteed to be initialized. By doing this
82-
// T does does not require to implement `Default`.
109+
// T does not require to implement `Default`.
83110
let pseudo_default_value = buf_slice[0];
84111

112+
let mut element_idx = src_len;
113+
85114
// Iterate backwards through the mask to avoid overwriting unprocessed elements.
86-
for mask_idx in (buf_len..mask_len).rev() {
115+
for mask_idx in (src_len..mask_len).rev() {
87116
if mask_values.value(mask_idx) {
88117
element_idx -= 1;
89118
buf_slice[mask_idx] = buf_slice[element_idx];
@@ -93,20 +122,97 @@ fn expand_indices<T: Copy>(buf: Buffer<T>, mask_values: &MaskValues) -> Buffer<T
93122
}
94123
}
95124

96-
for mask_idx in (0..buf_len).rev() {
125+
for mask_idx in (0..src_len).rev() {
97126
if mask_values.value(mask_idx) {
98127
element_idx -= 1;
99128
buf_slice[mask_idx] = buf_slice[element_idx];
100129
}
101130
// For the range up to buffer length, all positions are already initialized.
102131
}
132+
}
103133

104-
buf_mut.freeze()
134+
/// Scatters elements from a source buffer into a destination slice at positions marked true
135+
/// in the mask.
136+
///
137+
/// # Arguments
138+
///
139+
/// * `dest` - The destination buffer slice (already expanded to mask length)
140+
/// * `src` - The source elements to scatter
141+
/// * `src_len` - The length of the source buffer
142+
/// * `mask_values` - The mask indicating where elements should be placed
143+
fn scatter_into_slice_from<T: Copy>(
144+
dest: &mut [T],
145+
src: &[T],
146+
src_len: usize,
147+
mask_values: &vortex_mask::MaskValues,
148+
) {
149+
let mask_len = dest.len();
150+
151+
// Pick the first value as a default value. The source buffer is not empty.
152+
let pseudo_default_value = src[0];
153+
154+
let mut element_idx = src_len;
155+
156+
// Iterate backwards through the mask to avoid any issues.
157+
for mask_idx in (src_len..mask_len).rev() {
158+
if mask_values.value(mask_idx) {
159+
element_idx -= 1;
160+
dest[mask_idx] = src[element_idx];
161+
} else {
162+
// Initialize with a pseudo-default value.
163+
dest[mask_idx] = pseudo_default_value;
164+
}
165+
}
166+
167+
for mask_idx in (0..src_len).rev() {
168+
if mask_values.value(mask_idx) {
169+
element_idx -= 1;
170+
dest[mask_idx] = src[element_idx];
171+
}
172+
}
173+
}
174+
175+
/// Expands a slice into a new buffer at the target size, scattering elements to
176+
/// true positions in the mask.
177+
///
178+
/// # Arguments
179+
///
180+
/// * `src` - The source slice containing elements to scatter
181+
/// * `mask` - The mask indicating where elements should be placed
182+
///
183+
/// # Returns
184+
///
185+
/// A new `Buffer<T>` with length equal to `mask.len()`, with elements from `src` scattered
186+
/// to positions marked true in the mask. Positions marked false can have arbitrary values.
187+
fn expand_into_new_buffer<T: Copy>(src: &[T], mask: &Mask) -> Buffer<T> {
188+
let src_len = src.len();
189+
let mask_len = mask.len();
190+
191+
match mask {
192+
Mask::AllTrue(_) => Buffer::from_trusted_len_iter(src.iter().copied()),
193+
Mask::AllFalse(_) => Buffer::empty(),
194+
Mask::Values(mask_values) => {
195+
if src_len == 0 {
196+
return Buffer::empty();
197+
}
198+
199+
let mut buf_mut = BufferMut::<T>::with_capacity(mask_len);
200+
201+
// SAFETY: We're preallocating the full target capacity.
202+
unsafe {
203+
buf_mut.set_len(mask_len);
204+
}
205+
206+
let buf_slice = buf_mut.as_mut_slice();
207+
scatter_into_slice_from(buf_slice, src, src_len, mask_values);
208+
buf_mut.freeze()
209+
}
210+
}
105211
}
106212

107213
#[cfg(test)]
108214
mod tests {
109-
use vortex_buffer::buffer;
215+
use vortex_buffer::{buffer, buffer_mut};
110216
use vortex_mask::Mask;
111217

112218
use super::*;
@@ -172,4 +278,101 @@ mod tests {
172278
let mask = Mask::from_iter([true, true, true, false]);
173279
buf.expand(&mask);
174280
}
281+
282+
// Tests for &Buffer<T> impl
283+
#[test]
284+
fn test_expand_ref_scattered() {
285+
let buf = buffer![100u32, 200, 300];
286+
let mask = Mask::from_iter([true, false, true, false, true]);
287+
288+
let result = (&buf).expand(&mask);
289+
assert_eq!(result.len(), 5);
290+
assert_eq!(result.as_slice()[0], 100);
291+
assert_eq!(result.as_slice()[2], 200);
292+
assert_eq!(result.as_slice()[4], 300);
293+
}
294+
295+
#[test]
296+
fn test_expand_ref_all_true() {
297+
let buf = buffer![10u32, 20, 30];
298+
let mask = Mask::new_true(3);
299+
300+
let result = (&buf).expand(&mask);
301+
assert_eq!(result, buffer![10u32, 20, 30]);
302+
}
303+
304+
// Tests for &mut BufferMut<T> impl
305+
#[test]
306+
fn test_expand_mut_scattered() {
307+
let mut buf = buffer_mut![100u32, 200, 300];
308+
let mask = Mask::from_iter([true, false, true, false, true]);
309+
310+
(&mut buf).expand(&mask);
311+
assert_eq!(buf.len(), 5);
312+
assert_eq!(buf.as_slice()[0], 100);
313+
assert_eq!(buf.as_slice()[2], 200);
314+
assert_eq!(buf.as_slice()[4], 300);
315+
}
316+
317+
#[test]
318+
fn test_expand_mut_all_true() {
319+
let mut buf = buffer_mut![10u32, 20, 30];
320+
let mask = Mask::new_true(3);
321+
322+
(&mut buf).expand(&mask);
323+
assert_eq!(buf.as_slice(), &[10, 20, 30]);
324+
}
325+
326+
#[test]
327+
fn test_expand_mut_all_false() {
328+
let mut buf: BufferMut<u32> = BufferMut::with_capacity(0);
329+
let mask = Mask::new_false(0);
330+
331+
(&mut buf).expand(&mask);
332+
assert!(buf.is_empty());
333+
}
334+
335+
#[test]
336+
fn test_expand_mut_contiguous_start() {
337+
let mut buf = buffer_mut![10u32, 20, 30, 40];
338+
let mask = Mask::from_iter([true, true, true, true, false, false, false]);
339+
340+
(&mut buf).expand(&mask);
341+
assert_eq!(buf.len(), 7);
342+
assert_eq!(buf.as_slice()[0..4], [10u32, 20, 30, 40]);
343+
}
344+
345+
#[test]
346+
fn test_expand_mut_contiguous_end() {
347+
let mut buf = buffer_mut![100u32, 200, 300];
348+
let mask = Mask::from_iter([false, false, false, false, true, true, true]);
349+
350+
(&mut buf).expand(&mask);
351+
assert_eq!(buf.len(), 7);
352+
assert_eq!(buf.as_slice()[4..7], [100u32, 200, 300]);
353+
}
354+
355+
#[test]
356+
fn test_expand_mut_dense() {
357+
let mut buf = buffer_mut![1u32, 2, 3, 4, 5];
358+
let mask = Mask::from_iter([
359+
true, false, true, true, false, true, true, false, false, false,
360+
]);
361+
362+
(&mut buf).expand(&mask);
363+
assert_eq!(buf.len(), 10);
364+
assert_eq!(buf.as_slice()[0], 1);
365+
assert_eq!(buf.as_slice()[2], 2);
366+
assert_eq!(buf.as_slice()[3], 3);
367+
assert_eq!(buf.as_slice()[5], 4);
368+
assert_eq!(buf.as_slice()[6], 5);
369+
}
370+
371+
#[test]
372+
#[should_panic(expected = "Expand mask true count must equal the buffer length")]
373+
fn test_expand_mut_mismatch_true_count() {
374+
let mut buf = buffer_mut![10u32, 20];
375+
let mask = Mask::from_iter([true, true, true, false]);
376+
(&mut buf).expand(&mask);
377+
}
175378
}

vortex-compute/src/expand/mod.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@ pub trait Expand {
1616
///
1717
///
1818
/// The result will have length equal to the mask. All values of `self` are
19-
/// then scattered to the true positions of the mask.
19+
/// then scattered to the true positions of the mask. False positions can have
20+
/// any value that `Output` allows for. No assumption can be made that false
21+
/// positions are set the default value of `Output`.
22+
///
2023
///
2124
/// # Panics
2225
///

0 commit comments

Comments
 (0)