Skip to content

Commit 3e00e8f

Browse files
committed
add AVX512 support for filtering in place
Signed-off-by: Connor Tsui <[email protected]>
1 parent 3df9296 commit 3e00e8f

File tree

11 files changed

+791
-69
lines changed

11 files changed

+791
-69
lines changed

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

vortex-compute/Cargo.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ arrow = ["dep:arrow-array", "dep:arrow-buffer", "dep:arrow-schema"]
3737

3838
[dev-dependencies]
3939
divan = { workspace = true }
40+
rand = { workspace = true }
41+
itertools = { workspace = true }
4042

4143
[[bench]]
4244
name = "filter_buffer_mut"
@@ -45,3 +47,7 @@ harness = false
4547
[[bench]]
4648
name = "expand_buffer"
4749
harness = false
50+
51+
[[bench]]
52+
name = "avx512"
53+
harness = false

vortex-compute/benches/avx512.rs

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
#![expect(clippy::cast_possible_truncation)]
5+
6+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
7+
use filter_in_place::filter_in_place_avx512;
8+
use itertools::Itertools;
9+
use rand::Rng;
10+
use vortex_compute::filter::slice::in_place::filter_in_place_scalar;
11+
12+
fn main() {
13+
divan::main();
14+
}
15+
16+
// Create a random mask where each bit has `probability` chance of being set.
17+
fn create_random_mask(size: usize, probability: f64) -> Vec<u8> {
18+
let mut rng = rand::rng();
19+
let num_bytes = size.div_ceil(8);
20+
let mut mask = Vec::with_capacity(num_bytes);
21+
22+
for _ in 0..num_bytes {
23+
let mut byte = 0u8;
24+
for bit in 0..8 {
25+
if rng.random::<f64>() < probability {
26+
byte |= 1 << bit;
27+
}
28+
}
29+
mask.push(byte);
30+
}
31+
32+
mask
33+
}
34+
35+
// Benchmark different data sizes.
36+
const SIZES: &[usize] = &[1 << 10, 1 << 14, 1 << 17];
37+
38+
// Different probability values to benchmark.
39+
const PROBABILITIES: &[f64] = &[0.0, 0.1, 0.25, 0.5, 0.75, 0.9, 1.0];
40+
41+
#[divan::bench_group]
42+
mod filter_scalar_i32 {
43+
use super::*;
44+
45+
#[divan::bench(sample_size = 64, args = SIZES.iter().copied().cartesian_product(PROBABILITIES.iter().copied()))]
46+
fn random_probability(bencher: divan::Bencher, (size, probability): (usize, f64)) {
47+
let mask = create_random_mask(size, probability);
48+
bencher
49+
.with_inputs(|| (0..size as i32).collect::<Vec<_>>())
50+
.bench_values(|mut data| filter_in_place_scalar(&mut data, &mask))
51+
}
52+
}
53+
54+
#[divan::bench_group]
55+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
56+
mod filter_avx512_i32 {
57+
use super::*;
58+
59+
#[divan::bench(sample_size = 64, args = SIZES.iter().copied().cartesian_product(PROBABILITIES.iter().copied()))]
60+
fn random_probability(bencher: divan::Bencher, (size, probability): (usize, f64)) {
61+
let mask = create_random_mask(size, probability);
62+
bencher
63+
.with_inputs(|| (0..size as i32).collect::<Vec<_>>())
64+
.bench_values(|mut data| unsafe { filter_in_place_avx512(&mut data, &mask) })
65+
}
66+
}
67+
68+
// Throughput benchmark - measure GB/s
69+
#[divan::bench_group]
70+
mod throughput {
71+
use super::*;
72+
73+
const LARGE_SIZE: usize = 1024 * 1024; // 4 MB
74+
75+
#[divan::bench(sample_size = 16, args = PROBABILITIES)]
76+
fn scalar_throughput(bencher: divan::Bencher, probability: f64) {
77+
let mask = create_random_mask(LARGE_SIZE, probability);
78+
bencher
79+
.counter(divan::counter::BytesCount::new(LARGE_SIZE * 4))
80+
.with_inputs(|| (0..LARGE_SIZE as i32).collect::<Vec<_>>())
81+
.bench_values(|mut data| filter_in_place_scalar(&mut data, &mask))
82+
}
83+
84+
#[divan::bench(sample_size = 16, args = PROBABILITIES)]
85+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
86+
fn avx512_throughput(bencher: divan::Bencher, probability: f64) {
87+
let mask = create_random_mask(LARGE_SIZE, probability);
88+
bencher
89+
.counter(divan::counter::BytesCount::new(LARGE_SIZE * 4))
90+
.with_inputs(|| (0..LARGE_SIZE as i32).collect::<Vec<_>>())
91+
.bench_values(|mut data| unsafe { filter_in_place_avx512(&mut data, &mask) })
92+
}
93+
}

vortex-compute/src/filter/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
mod bitbuffer;
77
mod buffer;
88
mod mask;
9-
mod slice_mut;
9+
pub mod slice;
1010
mod vector;
1111

1212
/// Function for filtering based on a selection mask.
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
5+
use std::arch::x86_64::*;
6+
7+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
8+
use crate::filter::slice::SimdCompress;
9+
10+
/// Filter a mutable slice of elements in-place depending on the given mask.
11+
///
12+
/// The mask is represented as a slice of bytes (LSB is the first element).
13+
///
14+
/// Returns the true count of the mask.
15+
///
16+
/// This function automatically dispatches to the most efficient implementation based on the
17+
/// available CPU features at compile time.
18+
///
19+
/// # Panics
20+
///
21+
/// Panics if `mask.len() != data.len().div_ceil(8)`.
22+
#[inline]
23+
pub fn filter_in_place<T: SimdCompress>(data: &mut [T], mask: &[u8]) -> usize {
24+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
25+
{
26+
let use_simd = if T::WIDTH >= 32 {
27+
// 32-bit and 64-bit types only need AVX-512F.
28+
is_x86_feature_detected!("avx512f")
29+
} else {
30+
// 8-bit and 16-bit types need both AVX-512F and AVX-512VBMI2.
31+
is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512vbmi2")
32+
};
33+
34+
if use_simd {
35+
return unsafe { filter_in_place_avx512(data, mask) };
36+
}
37+
}
38+
39+
// Fall back to scalar implementation for non-x86 or when SIMD not available.
40+
filter_in_place_scalar(data, mask)
41+
}
42+
43+
/// Filter a mutable slice of elements in-place depending on the given mask.
44+
///
45+
/// The mask is represented as a slice of bytes (LSB is the first element).
46+
///
47+
/// Returns the true count of the mask.
48+
///
49+
/// This function uses AVX-512 SIMD instructions for high-performance filtering.
50+
///
51+
/// # Panics
52+
///
53+
/// Panics if `mask.len() != data.len().div_ceil(8)`.
54+
///
55+
/// # Safety
56+
///
57+
/// This function requires the appropriate SIMD instruction set to be available.
58+
/// For AVX-512F types, the CPU must support AVX-512F.
59+
/// For AVX-512VBMI2 types, the CPU must support AVX-512VBMI2.
60+
#[inline]
61+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
62+
#[target_feature(enable = "avx512f,avx512vbmi2,popcnt")]
63+
pub unsafe fn filter_in_place_avx512<T: SimdCompress>(data: &mut [T], mask: &[u8]) -> usize {
64+
assert_eq!(
65+
mask.len(),
66+
data.len().div_ceil(8),
67+
"Mask length must be data.len().div_ceil(8)"
68+
);
69+
70+
let data_len = data.len();
71+
let mut write_pos = 0;
72+
73+
// Pre-calculate loop bounds to eliminate branch misprediction in the hot loop.
74+
let full_chunks = data_len / T::ELEMENTS_PER_VECTOR;
75+
let remainder = data_len % T::ELEMENTS_PER_VECTOR;
76+
77+
// Process full chunks with no branches in the loop.
78+
for chunk_idx in 0..full_chunks {
79+
let read_pos = chunk_idx * T::ELEMENTS_PER_VECTOR;
80+
let mask_byte_offset = chunk_idx * T::MASK_BYTES;
81+
82+
// Read the mask for this chunk.
83+
// SAFETY: `mask_byte_offset + T::MASK_BYTES <= mask.len()` for all full chunks.
84+
let mask_value = unsafe { T::read_mask(mask.as_ptr(), mask_byte_offset) };
85+
86+
// Load elements into the SIMD register.
87+
// SAFETY: `read_pos + T::ELEMENTS_PER_VECTOR <= data.len()` for all full chunks.
88+
let vector = unsafe { _mm512_loadu_si512(data.as_ptr().add(read_pos) as *const __m512i) };
89+
90+
// Moves all elements that have their bit set to 1 in the mask value to the left.
91+
let filtered = unsafe { T::compress_vector(mask_value, vector) };
92+
93+
// Write the filtered result vector back to memory.
94+
// SAFETY: `write_pos + count_ones(mask_value) <= data.len()` since we're compacting.
95+
unsafe { _mm512_storeu_si512(data.as_mut_ptr().add(write_pos) as *mut __m512i, filtered) };
96+
97+
// Uses the hardware `popcnt` instruction if available.
98+
let count = T::count_ones(mask_value);
99+
write_pos += count;
100+
}
101+
102+
// Handle the final partial chunk with simple scalar processing.
103+
let read_pos = full_chunks * T::ELEMENTS_PER_VECTOR;
104+
for i in 0..remainder {
105+
let read_idx = read_pos + i;
106+
let bit_idx = read_idx % 8;
107+
let byte_idx = read_idx / 8;
108+
109+
if (mask[byte_idx] >> bit_idx) & 1 == 1 {
110+
data[write_pos] = data[read_idx];
111+
write_pos += 1;
112+
}
113+
}
114+
115+
write_pos
116+
}

0 commit comments

Comments
 (0)