Commit c731f04

add AVX512 support for filtering in place
Signed-off-by: Connor Tsui <[email protected]>
1 parent 3df9296 commit c731f04

11 files changed, +796 -69 lines changed

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default.

vortex-compute/Cargo.toml

Lines changed: 6 additions & 0 deletions
@@ -37,6 +37,8 @@ arrow = ["dep:arrow-array", "dep:arrow-buffer", "dep:arrow-schema"]
 
 [dev-dependencies]
 divan = { workspace = true }
+itertools = { workspace = true }
+rand = { workspace = true }
 
 [[bench]]
 name = "filter_buffer_mut"
@@ -45,3 +47,7 @@ harness = false
 [[bench]]
 name = "expand_buffer"
 harness = false
+
+[[bench]]
+name = "avx512"
+harness = false
vortex-compute/benches/avx512.rs

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

#![expect(clippy::cast_possible_truncation)]

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
use filter_in_place::filter_in_place_avx512;
use itertools::Itertools;
use rand::Rng;
use vortex_compute::filter::slice::in_place::filter_in_place_scalar;

fn main() {
    divan::main();
}

// Create a random mask where each bit has `probability` chance of being set.
fn create_random_mask(size: usize, probability: f64) -> Vec<u8> {
    let mut rng = rand::rng();
    let num_bytes = size.div_ceil(8);
    let mut mask = Vec::with_capacity(num_bytes);

    for _ in 0..num_bytes {
        let mut byte = 0u8;
        for bit in 0..8 {
            if rng.random::<f64>() < probability {
                byte |= 1 << bit;
            }
        }
        mask.push(byte);
    }

    mask
}

// Benchmark different data sizes.
const SIZES: &[usize] = &[1 << 10, 1 << 14, 1 << 17];

// Different probability values to benchmark.
const PROBABILITIES: &[f64] = &[0.0, 0.1, 0.25, 0.5, 0.75, 0.9, 1.0];

#[divan::bench_group]
mod filter_scalar_i32 {
    use super::*;

    #[divan::bench(sample_size = 64, args = SIZES.iter().copied().cartesian_product(PROBABILITIES.iter().copied()))]
    fn random_probability(bencher: divan::Bencher, (size, probability): (usize, f64)) {
        let mask = create_random_mask(size, probability);
        bencher
            .with_inputs(|| (0..size as i32).collect::<Vec<_>>())
            .bench_values(|mut data| filter_in_place_scalar(&mut data, &mask))
    }
}

#[divan::bench_group]
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
mod filter_avx512_i32 {
    use super::*;

    #[divan::bench(sample_size = 64, args = SIZES.iter().copied().cartesian_product(PROBABILITIES.iter().copied()))]
    fn random_probability(bencher: divan::Bencher, (size, probability): (usize, f64)) {
        let mask = create_random_mask(size, probability);
        bencher
            .with_inputs(|| (0..size as i32).collect::<Vec<_>>())
            .bench_values(|mut data| unsafe { filter_in_place_avx512(&mut data, &mask) })
    }
}

// Throughput benchmark - measure GB/s
#[divan::bench_group]
mod throughput {
    use super::*;

    const LARGE_SIZE: usize = 1024 * 1024; // 4 MB

    #[divan::bench(sample_size = 16, args = PROBABILITIES)]
    fn scalar_throughput(bencher: divan::Bencher, probability: f64) {
        let mask = create_random_mask(LARGE_SIZE, probability);
        bencher
            .counter(divan::counter::BytesCount::new(LARGE_SIZE * 4))
            .with_inputs(|| (0..LARGE_SIZE as i32).collect::<Vec<_>>())
            .bench_values(|mut data| filter_in_place_scalar(&mut data, &mask))
    }

    #[divan::bench(sample_size = 16, args = PROBABILITIES)]
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn avx512_throughput(bencher: divan::Bencher, probability: f64) {
        let mask = create_random_mask(LARGE_SIZE, probability);
        bencher
            .counter(divan::counter::BytesCount::new(LARGE_SIZE * 4))
            .with_inputs(|| (0..LARGE_SIZE as i32).collect::<Vec<_>>())
            .bench_values(|mut data| unsafe { filter_in_place_avx512(&mut data, &mask) })
    }
}
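
The scalar baseline filter_in_place_scalar is referenced here but is not part of this diff. As a rough mental model only (a hypothetical sketch, not the crate's actual implementation), an in-place scalar filter walks the mask bit by bit and compacts the selected elements to the front, returning the number kept:

// Hypothetical sketch of a scalar in-place filter (not the crate's actual code).
fn scalar_filter_sketch<T: Copy>(data: &mut [T], mask: &[u8]) -> usize {
    assert_eq!(mask.len(), data.len().div_ceil(8));
    let mut write_pos = 0;
    for read_idx in 0..data.len() {
        // The LSB of each mask byte corresponds to the lowest-indexed element it covers.
        if (mask[read_idx / 8] >> (read_idx % 8)) & 1 == 1 {
            data[write_pos] = data[read_idx];
            write_pos += 1;
        }
    }
    write_pos
}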

vortex-compute/src/filter/mod.rs

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 mod bitbuffer;
 mod buffer;
 mod mask;
-mod slice_mut;
+pub mod slice;
 mod vector;
 
 /// Function for filtering based on a selection mask.

Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
use std::arch::x86_64::*;

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
use crate::filter::slice::SimdCompress;
use crate::filter::slice::in_place::filter_in_place_scalar;

/// Filter a mutable slice of elements in-place depending on the given mask.
///
/// The mask is represented as a slice of bytes (LSB is the first element).
///
/// Returns the true count of the mask.
///
/// This function automatically dispatches to the most efficient implementation based on the
/// CPU features detected at runtime.
///
/// # Panics
///
/// Panics if `mask.len() != data.len().div_ceil(8)`.
#[inline]
pub fn filter_in_place<T: SimdCompress>(data: &mut [T], mask: &[u8]) -> usize {
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        let use_simd = if T::WIDTH >= 32 {
            // 32-bit and 64-bit types only need AVX-512F.
            is_x86_feature_detected!("avx512f")
        } else {
            // 8-bit and 16-bit types need both AVX-512F and AVX-512VBMI2.
            is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512vbmi2")
        };

        if use_simd {
            return unsafe { filter_in_place_avx512(data, mask) };
        }
    }

    // Fall back to the scalar implementation on non-x86 targets or when AVX-512 is not available.
    filter_in_place_scalar(data, mask)
}

/// Filter a mutable slice of elements in-place depending on the given mask.
///
/// The mask is represented as a slice of bytes (LSB is the first element).
///
/// Returns the true count of the mask.
///
/// This function uses AVX-512 SIMD instructions for high-performance filtering.
///
/// # Panics
///
/// Panics if `mask.len() != data.len().div_ceil(8)`.
///
/// # Safety
///
/// The caller must ensure the required AVX-512 instruction sets are available:
/// AVX-512F for 32-bit and 64-bit element types, and additionally AVX-512VBMI2
/// for 8-bit and 16-bit element types.
#[inline]
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,avx512vbmi2,popcnt")]
pub unsafe fn filter_in_place_avx512<T: SimdCompress>(data: &mut [T], mask: &[u8]) -> usize {
    assert_eq!(
        mask.len(),
        data.len().div_ceil(8),
        "Mask length must be data.len().div_ceil(8)"
    );

    let data_len = data.len();
    let mut write_pos = 0;

    // Pre-calculate loop bounds to eliminate branch misprediction in the hot loop.
    let full_chunks = data_len / T::ELEMENTS_PER_VECTOR;
    let remainder = data_len % T::ELEMENTS_PER_VECTOR;

    // Process full chunks with no branches in the loop.
    for chunk_idx in 0..full_chunks {
        let read_pos = chunk_idx * T::ELEMENTS_PER_VECTOR;
        let mask_byte_offset = chunk_idx * T::MASK_BYTES;

        // Read the mask for this chunk.
        // SAFETY: `mask_byte_offset + T::MASK_BYTES <= mask.len()` for all full chunks.
        let mask_value = unsafe { T::read_mask(mask.as_ptr(), mask_byte_offset) };

        // Load elements into the SIMD register.
        // SAFETY: `read_pos + T::ELEMENTS_PER_VECTOR <= data.len()` for all full chunks.
        let vector = unsafe { _mm512_loadu_si512(data.as_ptr().add(read_pos) as *const __m512i) };

        // Move all elements whose bit is set in the mask value toward the front of the vector.
        let filtered = unsafe { T::compress_vector(mask_value, vector) };

        // Write the filtered result vector back to memory.
        // SAFETY: `write_pos <= read_pos`, so `write_pos + T::ELEMENTS_PER_VECTOR <= data.len()`
        // and the full 512-bit store stays in bounds.
        unsafe { _mm512_storeu_si512(data.as_mut_ptr().add(write_pos) as *mut __m512i, filtered) };

        // Uses the hardware `popcnt` instruction if available.
        let count = T::count_ones(mask_value);
        write_pos += count;
    }

    // Handle the final partial chunk with simple scalar processing.
    let read_pos = full_chunks * T::ELEMENTS_PER_VECTOR;
    for i in 0..remainder {
        let read_idx = read_pos + i;
        let bit_idx = read_idx % 8;
        let byte_idx = read_idx / 8;

        if (mask[byte_idx] >> bit_idx) & 1 == 1 {
            data[write_pos] = data[read_idx];
            write_pos += 1;
        }
    }

    write_pos
}
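
To make the mask convention concrete, here is a small usage sketch (assuming a type such as i32 implements SimdCompress, as the i32 benchmarks suggest; the import of filter_in_place is omitted since the exact module path may differ):

// Usage sketch, not part of the commit: one selection bit per element, LSB first.
let mut data: Vec<i32> = (0..10).collect();
//   byte 0 = 0b0000_0101 -> keep elements 0 and 2
//   byte 1 = 0b0000_0010 -> keep element 9
let mask = [0b0000_0101u8, 0b0000_0010u8];

let kept = filter_in_place(&mut data, &mask);
data.truncate(kept);
assert_eq!(data, vec![0, 2, 9]);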
