Skip to content

Commit c5162cc

Browse files
committed
add AVX512 support for filtering out of place
Signed-off-by: Connor Tsui <[email protected]>
1 parent 9229453 commit c5162cc

File tree

5 files changed

+384
-22
lines changed

5 files changed

+384
-22
lines changed

vortex-compute/benches/avx512.rs

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ use rand::Rng;
88
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
99
use vortex_compute::filter::slice::in_place::avx512::filter_in_place_avx512;
1010
use vortex_compute::filter::slice::in_place::filter_in_place_scalar;
11+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
12+
use vortex_compute::filter::slice::out::avx512::filter_into_avx512;
13+
use vortex_compute::filter::slice::out::filter_into_scalar;
1114

1215
fn main() {
1316
divan::main();
@@ -32,46 +35,53 @@ fn create_random_mask(size: usize, probability: f64) -> Vec<u8> {
3235
mask
3336
}
3437

35-
// Benchmark different data sizes.
36-
const SIZES: &[usize] = &[1 << 10, 1 << 14, 1 << 17];
38+
/// Benchmark different data sizes.
39+
const SIZES: &[usize] = &[1 << 10, 1 << 11, 1 << 14, 1 << 17];
3740

38-
// Different probability values to benchmark.
41+
/// Different probability values to benchmark.
3942
const PROBABILITIES: &[f64] = &[0.0, 0.1, 0.25, 0.5, 0.75, 0.9, 1.0];
4043

41-
#[divan::bench(sample_size = 64, args = SIZES.iter().copied().cartesian_product(PROBABILITIES.iter().copied()))]
42-
fn random_probability_scalar(bencher: divan::Bencher, (size, probability): (usize, f64)) {
44+
/// The number of samples per benchmark.
45+
const SAMPLE_SIZE: u32 = 64;
46+
47+
#[divan::bench(sample_size = SAMPLE_SIZE, args = SIZES.iter().copied().cartesian_product(PROBABILITIES.iter().copied()))]
48+
fn in_place_scalar(bencher: divan::Bencher, (size, probability): (usize, f64)) {
4349
let mask = create_random_mask(size, probability);
4450
bencher
4551
.with_inputs(|| (0..size as i32).collect::<Vec<_>>())
4652
.bench_values(|mut data| filter_in_place_scalar(&mut data, &mask))
4753
}
4854

4955
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
50-
#[divan::bench(sample_size = 64, args = SIZES.iter().copied().cartesian_product(PROBABILITIES.iter().copied()))]
51-
fn random_probability_avx512(bencher: divan::Bencher, (size, probability): (usize, f64)) {
56+
#[divan::bench(sample_size = SAMPLE_SIZE, args = SIZES.iter().copied().cartesian_product(PROBABILITIES.iter().copied()))]
57+
fn in_place_avx512(bencher: divan::Bencher, (size, probability): (usize, f64)) {
5258
let mask = create_random_mask(size, probability);
5359
bencher
5460
.with_inputs(|| (0..size as i32).collect::<Vec<_>>())
5561
.bench_values(|mut data| unsafe { filter_in_place_avx512(&mut data, &mask) })
5662
}
5763

58-
const LARGE_SIZE: usize = 1024 * 1024; // 4 MB
59-
60-
#[divan::bench(sample_size = 16, args = PROBABILITIES)]
61-
fn scalar_throughput(bencher: divan::Bencher, probability: f64) {
62-
let mask = create_random_mask(LARGE_SIZE, probability);
64+
#[divan::bench(sample_size = SAMPLE_SIZE, args = SIZES.iter().copied().cartesian_product(PROBABILITIES.iter().copied()))]
65+
fn out_scalar(bencher: divan::Bencher, (size, probability): (usize, f64)) {
66+
let mask = create_random_mask(size, probability);
6367
bencher
64-
.counter(divan::counter::BytesCount::new(LARGE_SIZE * 4))
65-
.with_inputs(|| (0..LARGE_SIZE as i32).collect::<Vec<_>>())
66-
.bench_values(|mut data| filter_in_place_scalar(&mut data, &mask))
68+
.with_inputs(|| {
69+
let src = (0..size as i32).collect::<Vec<_>>();
70+
let dest = vec![0i32; size];
71+
(src, dest)
72+
})
73+
.bench_values(|(src, mut dest)| filter_into_scalar(&src, &mut dest, &mask))
6774
}
6875

69-
#[divan::bench(sample_size = 16, args = PROBABILITIES)]
7076
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
71-
fn avx512_throughput(bencher: divan::Bencher, probability: f64) {
72-
let mask = create_random_mask(LARGE_SIZE, probability);
77+
#[divan::bench(sample_size = SAMPLE_SIZE, args = SIZES.iter().copied().cartesian_product(PROBABILITIES.iter().copied()))]
78+
fn out_avx512(bencher: divan::Bencher, (size, probability): (usize, f64)) {
79+
let mask = create_random_mask(size, probability);
7380
bencher
74-
.counter(divan::counter::BytesCount::new(LARGE_SIZE * 4))
75-
.with_inputs(|| (0..LARGE_SIZE as i32).collect::<Vec<_>>())
76-
.bench_values(|mut data| unsafe { filter_in_place_avx512(&mut data, &mask) })
81+
.with_inputs(|| {
82+
let src = (0..size as i32).collect::<Vec<_>>();
83+
let dest = vec![0i32; size];
84+
(src, dest)
85+
})
86+
.bench_values(|(src, mut dest)| unsafe { filter_into_avx512(&src, &mut dest, &mask) })
7787
}

vortex-compute/src/filter/slice/in_place/avx512.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use crate::filter::slice::in_place::filter_in_place_scalar;
1414
///
1515
/// The mask is represented as a slice of bytes (LSB is the first element).
1616
///
17-
/// Returns the true count of the mask.
17+
/// Returns the true count of the mask (number of elements remaining).
1818
///
1919
/// This function automatically dispatches to the most efficient implementation based on the
2020
/// available CPU features at compile time.

vortex-compute/src/filter/slice/in_place/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
77
pub mod avx512;
88

9+
// TODO(connor): This is super inefficient.
910
/// Filter a mutable slice of elements in-place depending on the given mask.
1011
///
1112
/// The mask is represented as a slice of bytes (LSB is the first element).
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
//! Implementations of a specialized out-of-place filter for buffers using AVX512.
5+
6+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
7+
use std::arch::x86_64::*;
8+
9+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
10+
use crate::filter::slice::SimdCompress;
11+
use crate::filter::slice::out::filter_into_scalar;
12+
13+
/// Filter elements from a source slice into a destination slice based on the given mask.
14+
///
15+
/// The mask is represented as a slice of bytes (LSB is the first element).
16+
///
17+
/// Returns the true count of the mask (number of elements written to destination).
18+
///
19+
/// This function automatically dispatches to the most efficient implementation based on the
20+
/// available CPU features at compile time.
21+
///
22+
/// # Panics
23+
///
24+
/// Panics if:
25+
///
26+
/// - `mask.len() != src.len().div_ceil(8)`
27+
/// - `dest.len() < src.len()`
28+
#[inline]
29+
pub fn filter_into<T: SimdCompress>(src: &[T], dest: &mut [T], mask: &[u8]) -> usize {
30+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
31+
{
32+
let use_simd = if T::WIDTH >= 32 {
33+
// 32-bit and 64-bit types only need AVX-512F.
34+
is_x86_feature_detected!("avx512f")
35+
} else {
36+
// 8-bit and 16-bit types need both AVX-512F and AVX-512VBMI2.
37+
is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512vbmi2")
38+
};
39+
40+
if use_simd {
41+
return unsafe { filter_into_avx512(src, dest, mask) };
42+
}
43+
}
44+
45+
// Fall back to scalar implementation for non-x86 or when SIMD not available.
46+
filter_into_scalar(src, dest, mask)
47+
}
/// Filter elements from a source slice into a destination slice based on the given mask.
///
/// The mask is represented as a slice of bytes (LSB is the first element).
///
/// Returns the true count of the mask (number of elements written to destination).
///
/// This function uses AVX-512 SIMD instructions for high-performance filtering. Note that
/// `dest` positions at and beyond the returned count may be overwritten with garbage lanes
/// from the final full-vector store; only the first `return`ed elements are meaningful.
///
/// # Panics
///
/// Panics if:
///
/// - `mask.len() != src.len().div_ceil(8)`
/// - `dest.len() < src.len()`
///
/// # Safety
///
/// This function requires the appropriate SIMD instruction set to be available.
/// For AVX-512F types, the CPU must support AVX-512F.
/// For AVX-512VBMI2 types, the CPU must support AVX-512VBMI2.
/// Because this function is compiled with `target_feature(enable = "avx512f,avx512vbmi2,popcnt")`,
/// the caller must ensure ALL of those features (including `popcnt`) are available, independent
/// of the element width.
#[inline]
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,avx512vbmi2,popcnt")]
pub unsafe fn filter_into_avx512<T: SimdCompress>(src: &[T], dest: &mut [T], mask: &[u8]) -> usize {
    // Precondition: exactly one mask bit per source element, packed LSB-first.
    assert_eq!(
        mask.len(),
        src.len().div_ceil(8),
        "Mask length must be src.len().div_ceil(8)"
    );
    // Precondition: `dest` can hold every element even when the mask is all-ones; this
    // bound is also what makes the full-width vector stores below in-bounds.
    assert!(
        dest.len() >= src.len(),
        "Destination buffer must be at least as large as source"
    );

    let src_len = src.len();
    // Next free index in `dest`; also the running true count of the mask bits consumed so far.
    let mut write_pos = 0;

    // Pre-calculate loop bounds to eliminate branch misprediction in the hot loop.
    let full_chunks = src_len / T::ELEMENTS_PER_VECTOR;
    let remainder = src_len % T::ELEMENTS_PER_VECTOR;

    // Process full chunks with no branches in the loop.
    for chunk_idx in 0..full_chunks {
        let read_pos = chunk_idx * T::ELEMENTS_PER_VECTOR;
        let mask_byte_offset = chunk_idx * T::MASK_BYTES;

        // Read the mask for this chunk.
        // SAFETY: `mask_byte_offset + T::MASK_BYTES <= mask.len()` for all full chunks
        // (assumes `T::MASK_BYTES * 8 == T::ELEMENTS_PER_VECTOR` — TODO(review): confirm
        // this invariant is documented on `SimdCompress`).
        let mask_value = unsafe { T::read_mask(mask.as_ptr(), mask_byte_offset) };

        // Load elements from source into the SIMD register.
        // SAFETY: `read_pos + T::ELEMENTS_PER_VECTOR <= src.len()` for all full chunks.
        let vector = unsafe { _mm512_loadu_si512(src.as_ptr().add(read_pos) as *const __m512i) };

        // Moves all elements that have their bit set to 1 in the mask value to the left.
        let filtered = unsafe { T::compress_vector(mask_value, vector) };

        // Write the filtered result vector to the destination buffer. This store always
        // writes the FULL vector width (`T::ELEMENTS_PER_VECTOR` lanes), not just the
        // selected elements; the unselected tail lanes are overwritten by later chunks or
        // left as garbage past the returned count.
        // SAFETY: at most one element is written per element read, so at this point
        // `write_pos <= chunk_idx * T::ELEMENTS_PER_VECTOR`, and therefore
        // `write_pos + T::ELEMENTS_PER_VECTOR <= (chunk_idx + 1) * T::ELEMENTS_PER_VECTOR
        //  <= src.len() <= dest.len()` (checked by the assert above).
        unsafe { _mm512_storeu_si512(dest.as_mut_ptr().add(write_pos) as *mut __m512i, filtered) };

        // Uses the hardware `popcnt` instruction if available.
        let count = T::count_ones(mask_value);
        write_pos += count;
    }

    // Handle the final partial chunk with simple scalar processing.
    let read_pos = full_chunks * T::ELEMENTS_PER_VECTOR;
    for i in 0..remainder {
        let read_idx = read_pos + i;
        // LSB-first bit addressing: element `read_idx` lives at bit `read_idx % 8` of
        // byte `read_idx / 8`.
        let bit_idx = read_idx % 8;
        let byte_idx = read_idx / 8;

        if (mask[byte_idx] >> bit_idx) & 1 == 1 {
            dest[write_pos] = src[read_idx];
            write_pos += 1;
        }
    }

    // Total number of selected elements written to `dest`.
    write_pos
}

0 commit comments

Comments
 (0)