
Commit 8bdaf7a

add AVX512 support for filtering in place
Signed-off-by: Connor Tsui <[email protected]>
1 parent cff218e commit 8bdaf7a

4 files changed: +607 −0 lines changed
Lines changed: 309 additions & 0 deletions
@@ -0,0 +1,309 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

#![cfg(any(target_arch = "x86", target_arch = "x86_64"))]

use std::arch::x86_64::*;

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
use crate::filter::avx512::SimdCompress;

/// Filter a mutable slice of elements in-place depending on the given mask.
///
/// The mask is represented as a slice of bytes (LSB is the first element).
///
/// Returns the true count of the mask.
///
/// This function automatically dispatches to the most efficient implementation based on the
/// CPU features detected at runtime.
///
/// # Panics
///
/// Panics if `mask.len() != data.len().div_ceil(8)`.
#[inline]
pub fn filter_in_place<T: SimdCompress>(data: &mut [T], mask: &[u8]) -> usize {
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        let use_simd = if T::WIDTH >= 32 {
            // 32-bit and 64-bit types only need AVX-512F.
            is_x86_feature_detected!("avx512f")
        } else {
            // 8-bit and 16-bit types need both AVX-512F and AVX-512VBMI2.
            is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512vbmi2")
        };

        if use_simd {
            return unsafe { filter_in_place_avx512(data, mask) };
        }
    }

    // Fall back to the scalar implementation for non-x86 targets or when AVX-512 is not available.
    filter_in_place_scalar(data, mask)
}

/// Filter a mutable slice of elements in-place depending on the given mask.
///
/// The mask is represented as a slice of bytes (LSB is the first element).
///
/// Returns the true count of the mask.
///
/// This scalar implementation walks the slice with separate read and write indices, copying
/// kept values toward the front of the slice.
///
/// # Panics
///
/// Panics if `mask.len() != data.len().div_ceil(8)`.
#[inline]
pub fn filter_in_place_scalar<T: Copy>(data: &mut [T], mask: &[u8]) -> usize {
    assert_eq!(
        mask.len(),
        data.len().div_ceil(8),
        "Mask length must be data.len().div_ceil(8)"
    );

    let mut write_pos = 0;
    let data_len = data.len();

    for read_pos in 0..data_len {
        let byte_idx = read_pos / 8;
        let bit_idx = read_pos % 8;

        if (mask[byte_idx] >> bit_idx) & 1 == 1 {
            data[write_pos] = data[read_pos];
            write_pos += 1;
        }
    }

    write_pos
}

/// Filter a mutable slice of elements in-place depending on the given mask.
///
/// The mask is represented as a slice of bytes (LSB is the first element).
///
/// Returns the true count of the mask.
///
/// This function uses AVX-512 SIMD instructions for high-performance filtering.
///
/// # Panics
///
/// Panics if `mask.len() != data.len().div_ceil(8)`.
///
/// # Safety
///
/// The required AVX-512 instruction sets must be available on the running CPU:
/// 32-bit and 64-bit element types require AVX-512F, while 8-bit and 16-bit element types
/// additionally require AVX-512VBMI2.
#[inline]
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,avx512vbmi2,popcnt")]
pub unsafe fn filter_in_place_avx512<T: SimdCompress>(data: &mut [T], mask: &[u8]) -> usize {
    assert_eq!(
        mask.len(),
        data.len().div_ceil(8),
        "Mask length must be data.len().div_ceil(8)"
    );

    let data_len = data.len();
    let mut write_pos = 0;

    // Pre-calculate loop bounds to eliminate branch misprediction in the hot loop.
    let full_chunks = data_len / T::ELEMENTS_PER_VECTOR;
    let remainder = data_len % T::ELEMENTS_PER_VECTOR;

    // Process full chunks with no branches in the loop.
    for chunk_idx in 0..full_chunks {
        let read_pos = chunk_idx * T::ELEMENTS_PER_VECTOR;
        let mask_byte_offset = chunk_idx * T::MASK_BYTES;

        // Read the mask for this chunk.
        // SAFETY: `mask_byte_offset + T::MASK_BYTES <= mask.len()` for all full chunks.
        let mask_value = unsafe { T::read_mask(mask.as_ptr(), mask_byte_offset) };

        // Load elements into the SIMD register.
        // SAFETY: `read_pos + T::ELEMENTS_PER_VECTOR <= data.len()` for all full chunks.
        let vector = unsafe { _mm512_loadu_si512(data.as_ptr().add(read_pos) as *const __m512i) };

        // Move all elements whose bit is set in the mask value toward the front of the vector.
        let filtered = unsafe { T::compress_vector(mask_value, vector) };

        // Write the filtered result vector back to memory.
        // SAFETY: the full-vector store stays in bounds because `write_pos <= read_pos` and
        // `read_pos + T::ELEMENTS_PER_VECTOR <= data.len()` for all full chunks.
        unsafe { _mm512_storeu_si512(data.as_mut_ptr().add(write_pos) as *mut __m512i, filtered) };

        // Uses the hardware `popcnt` instruction if available.
        let count = T::count_ones(mask_value);
        write_pos += count;
    }

    // Handle the final partial chunk with simple scalar processing.
    let read_pos = full_chunks * T::ELEMENTS_PER_VECTOR;
    for i in 0..remainder {
        let read_idx = read_pos + i;
        let bit_idx = read_idx % 8;
        let byte_idx = read_idx / 8;

        if (mask[byte_idx] >> bit_idx) & 1 == 1 {
            data[write_pos] = data[read_idx];
            write_pos += 1;
        }
    }

    write_pos
}

#[cfg(test)]
mod tests {
    use super::*;

    fn create_mask(bits: &[bool]) -> Vec<u8> {
        let mut mask = vec![0u8; bits.len().div_ceil(8)];
        for (i, &bit) in bits.iter().enumerate() {
            if bit {
                mask[i / 8] |= 1 << (i % 8);
            }
        }
        mask
    }

    fn test_implementation<F>(filter_fn: F)
    where
        F: Fn(&mut [i32], &[u8]) -> usize,
    {
        // Test 1: Small array - all elements pass
        let mut data = vec![0, 1, 2, 3, 4, 5, 6, 7];
        let mask = vec![0xFF]; // All 1s
        let count = filter_fn(&mut data, &mask);
        assert_eq!(count, 8);
        assert_eq!(&data[..8], &[0, 1, 2, 3, 4, 5, 6, 7]);

        // Test 2: Small array - no elements pass
        let mut data = vec![0, 1, 2, 3, 4, 5, 6, 7];
        let mask = vec![0x00]; // All 0s
        let count = filter_fn(&mut data, &mask);
        assert_eq!(count, 0);

        // Test 3: Small array - every other element
        let mut data = vec![0, 1, 2, 3, 4, 5, 6, 7];
        let mask = vec![0x55]; // 01010101
        let count = filter_fn(&mut data, &mask);
        assert_eq!(count, 4);
        assert_eq!(&data[..4], &[0, 2, 4, 6]);

        // Test 4: 16 elements - all pass
        let mut data: Vec<i32> = (0..16).collect();
        let mask = vec![0xFF, 0xFF];
        let count = filter_fn(&mut data, &mask);
        assert_eq!(count, 16);
        assert_eq!(
            &data[..16],
            &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
        );

        // Test 5: 16 elements - alternating pattern
        let mut data: Vec<i32> = (0..16).collect();
        let mask = vec![0xAA, 0xAA]; // 10101010 10101010
        let count = filter_fn(&mut data, &mask);
        assert_eq!(count, 8);
        assert_eq!(&data[..8], &[1, 3, 5, 7, 9, 11, 13, 15]);

        // Test 6: Larger array (32 elements)
        let mut data: Vec<i32> = (0..32).collect();
        let mask = vec![0xFF, 0x00, 0xFF, 0x00]; // First and third bytes
        let count = filter_fn(&mut data, &mask);
        assert_eq!(count, 16);
        assert_eq!(
            &data[..16],
            &[0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23]
        );

        // Test 7: Non-aligned size (23 elements)
        let mut data: Vec<i32> = (0..23).collect();
        let mask = create_mask(&[
            true, false, true, false, true, false, true, false, // byte 0
            false, true, false, true, false, true, false, true, // byte 1
            true, true, false, false, true, true, false, // byte 2 (partial)
        ]);
        let count = filter_fn(&mut data, &mask);
        assert_eq!(count, 12);
        assert_eq!(&data[..12], &[0, 2, 4, 6, 9, 11, 13, 15, 16, 17, 20, 21]);
    }

    #[test]
    fn test_scalar() {
        test_implementation(filter_in_place_scalar::<i32>);
    }

    #[test]
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn test_avx512() {
        test_implementation(|data, mask| unsafe { filter_in_place_avx512::<i32>(data, mask) });
    }

    #[test]
    fn test_runtime_dispatch() {
        test_implementation(filter_in_place::<i32>);
    }

    #[test]
    fn test_all_implementations_match() {
        // Test that all available implementations produce the same results.

        // Test various sizes and patterns.
        let test_cases = vec![
            (8, vec![0xAA]),                    // 8 elements, alternating
            (16, vec![0xFF, 0xFF]),             // 16 elements, all pass
            (16, vec![0x00, 0x00]),             // 16 elements, none pass
            (32, vec![0x55, 0x55, 0x55, 0x55]), // 32 elements, alternating
            (24, vec![0xFF, 0x00, 0xFF]),       // 24 elements, mixed
            (100, vec![0xFF; 13]),              // 100 elements (needs 13 bytes)
        ];

        for (size, mask) in test_cases {
            let mut data_scalar: Vec<i32> = (0..size).collect();

            let count_scalar = filter_in_place_scalar::<i32>(&mut data_scalar, &mask);

            // Test AVX-512 on x86/x86_64.
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            {
                let mut data_avx512: Vec<i32> = (0..size).collect();
                let count_avx512 =
                    unsafe { filter_in_place_avx512::<i32>(&mut data_avx512, &mask) };
                assert_eq!(
                    count_scalar, count_avx512,
                    "Count mismatch for size {}",
                    size
                );
                assert_eq!(
                    &data_scalar[..count_scalar],
                    &data_avx512[..count_avx512],
                    "Data mismatch for size {}",
                    size
                );
            }
        }
    }

    #[expect(clippy::cast_possible_truncation)]
    #[test]
    fn test_large_arrays() {
        // Test with very large arrays to ensure chunking works correctly.
        let sizes: Vec<usize> = vec![1024, 1000, 2048, 4096, 10000];

        for size in sizes {
            let mut data: Vec<i32> = (0..size as i32).collect();
            // Create an alternating mask.
            let mut mask = vec![0u8; size.div_ceil(8)];
            mask.fill(0x55); // 01010101

            let count = filter_in_place::<i32>(&mut data, &mask);
            assert_eq!(count, size / 2);

            // Verify the first few kept elements.
            (0..10.min(count)).for_each(|i| {
                assert_eq!(data[i], (i * 2) as i32);
            });
        }
    }
}
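
For reference, here is a minimal caller-side sketch of the new entry point. It is illustrative only and not part of this commit; `filter_in_place`, the `i32` support, and the LSB-first byte mask layout come from the diff above, while the surrounding function and the concrete values are made up.

// Illustrative usage, not part of this commit. Assumes this code lives inside
// the crate, so the re-export from `crate::filter::avx512` is in scope.
use crate::filter::avx512::filter_in_place;

fn keep_even_indices() {
    let mut data: Vec<i32> = (0..10).collect();
    // 10 elements need ceil(10 / 8) = 2 mask bytes; the LSB of each byte maps to
    // the first element that byte covers.
    let mask = [0b0101_0101u8, 0b0000_0001];
    let kept = filter_in_place(&mut data, &mask);
    // Elements at even indices survive and are packed to the front of the slice.
    assert_eq!(kept, 5);
    assert_eq!(&data[..kept], &[0, 2, 4, 6, 8]);
}

The call dispatches to the AVX-512 path at runtime when the CPU supports it and otherwise falls back to `filter_in_place_scalar`. With `i32` elements each AVX-512 chunk covers 512 / 32 = 16 values and consumes 16 / 8 = 2 mask bytes, so a 100-element input is handled as 6 full vector chunks plus a 4-element scalar tail.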
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

// TODO(connor): Refactor this module and add filter not in place + up front scalar fallback.

#![expect(unused)] // TODO(connor): Remove
#![cfg(any(target_arch = "x86", target_arch = "x86_64"))]

mod in_place;
pub use in_place::*;

mod simd_compress;
pub use simd_compress::SimdCompress;
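
The `simd_compress` module added by this commit is not shown on this page, so the actual trait definition is not visible here. Purely as a reading aid, below is a rough sketch of what `SimdCompress` could look like for `i32`, inferred from how `in_place.rs` uses it; the associated `Mask` type, the constant values, and the choice of `_mm512_mask_compress_epi32` are assumptions rather than the committed code.

// Hypothetical reconstruction for illustration only; the real trait lives in the
// simd_compress module of this commit and may differ in shape and detail.
use std::arch::x86_64::*;

pub trait SimdCompress: Copy {
    /// Element width in bits (8, 16, 32, or 64).
    const WIDTH: usize;
    /// Elements per 512-bit vector (512 / WIDTH).
    const ELEMENTS_PER_VECTOR: usize;
    /// Mask bytes consumed per vector (ELEMENTS_PER_VECTOR / 8).
    const MASK_BYTES: usize;
    /// Hardware mask type for this element width (an assumed associated type).
    type Mask: Copy;

    /// Read `MASK_BYTES` bytes of the bitmask starting at `offset`.
    unsafe fn read_mask(mask: *const u8, offset: usize) -> Self::Mask;
    /// Pack the lanes selected by `mask` toward the front of the vector.
    unsafe fn compress_vector(mask: Self::Mask, vector: __m512i) -> __m512i;
    /// Number of set bits in the mask.
    fn count_ones(mask: Self::Mask) -> usize;
}

impl SimdCompress for i32 {
    const WIDTH: usize = 32;
    const ELEMENTS_PER_VECTOR: usize = 16; // 512 / 32
    const MASK_BYTES: usize = 2; // 16 / 8
    type Mask = __mmask16;

    unsafe fn read_mask(mask: *const u8, offset: usize) -> __mmask16 {
        // Two little-endian mask bytes cover the 16 lanes of one vector.
        // SAFETY: the caller guarantees `offset + 2` bytes are in bounds of the mask buffer.
        unsafe { u16::from_le_bytes([*mask.add(offset), *mask.add(offset + 1)]) }
    }

    unsafe fn compress_vector(mask: __mmask16, vector: __m512i) -> __m512i {
        // Selected 32-bit lanes are packed toward lane 0; unselected lanes come from the zero source.
        // SAFETY: the caller guarantees AVX-512F is available.
        unsafe { _mm512_mask_compress_epi32(_mm512_setzero_si512(), mask, vector) }
    }

    fn count_ones(mask: __mmask16) -> usize {
        mask.count_ones() as usize
    }
}

For 8-bit and 16-bit elements the analogous compress intrinsics (`_mm512_mask_compress_epi8` and `_mm512_mask_compress_epi16`) belong to AVX-512VBMI2, which is why the dispatcher in `in_place.rs` additionally checks that feature for narrow types.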
