// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

#![cfg(any(target_arch = "x86", target_arch = "x86_64"))]

#[cfg(target_arch = "x86")]
use std::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
use crate::filter::avx512::SimdCompress;

/// Filter a mutable slice of elements in-place depending on the given mask.
///
/// The mask is represented as a slice of bytes (LSB is the first element).
///
/// Returns the true count of the mask, i.e. the number of elements kept.
///
/// This function automatically dispatches to the most efficient available implementation based
/// on the CPU features detected at runtime.
///
/// # Panics
///
/// Panics if `mask.len() != data.len().div_ceil(8)`.
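///
/// # Example
///
/// A minimal usage sketch (illustrative; relies on `i32` implementing [`SimdCompress`], as the
/// tests below do):
///
/// ```ignore
/// let mut data = vec![10i32, 11, 12, 13, 14, 15, 16, 17];
/// let mask = vec![0b0010_0101u8]; // keep elements 0, 2 and 5 (LSB-first)
/// let kept = filter_in_place(&mut data, &mask);
/// assert_eq!(kept, 3);
/// assert_eq!(&data[..kept], &[10, 12, 15]);
/// ```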
#[inline]
pub fn filter_in_place<T: SimdCompress>(data: &mut [T], mask: &[u8]) -> usize {
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        let use_simd = if T::WIDTH >= 32 {
            // 32-bit and 64-bit types only need AVX-512F.
            is_x86_feature_detected!("avx512f")
        } else {
            // 8-bit and 16-bit types need both AVX-512F and AVX-512VBMI2.
            is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512vbmi2")
        };

        if use_simd {
            return unsafe { filter_in_place_avx512(data, mask) };
        }
    }

    // Fall back to the scalar implementation when the required AVX-512 features are not
    // available at runtime.
    filter_in_place_scalar(data, mask)
}

/// Filter a mutable slice of elements in-place depending on the given mask.
///
/// The mask is represented as a slice of bytes (LSB is the first element).
///
/// Returns the true count of the mask, i.e. the number of elements kept.
///
/// This scalar implementation walks the data with a read index and compacts the selected
/// elements down to a write index.
///
/// # Panics
///
/// Panics if `mask.len() != data.len().div_ceil(8)`.
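///
/// # Example
///
/// A minimal sketch; when `data.len()` is not a multiple of 8, the unused high bits of the final
/// mask byte are simply ignored:
///
/// ```ignore
/// let mut data = vec![1u64, 2, 3];
/// let mask = vec![0b1111_0110u8]; // only the low three bits matter for 3 elements
/// let kept = filter_in_place_scalar(&mut data, &mask);
/// assert_eq!(kept, 2);
/// assert_eq!(&data[..kept], &[2, 3]);
/// ```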
#[inline]
pub fn filter_in_place_scalar<T: Copy>(data: &mut [T], mask: &[u8]) -> usize {
    assert_eq!(
        mask.len(),
        data.len().div_ceil(8),
        "Mask length must be data.len().div_ceil(8)"
    );

    let mut write_pos = 0;
    let data_len = data.len();

    for read_pos in 0..data_len {
        let byte_idx = read_pos / 8;
        let bit_idx = read_pos % 8;

        if (mask[byte_idx] >> bit_idx) & 1 == 1 {
            data[write_pos] = data[read_pos];
            write_pos += 1;
        }
    }

    write_pos
}

/// Filter a mutable slice of elements in-place depending on the given mask.
///
/// The mask is represented as a slice of bytes (LSB is the first element).
///
/// Returns the true count of the mask, i.e. the number of elements kept.
///
/// This function uses AVX-512 SIMD instructions for high-performance filtering.
///
/// # Panics
///
/// Panics if `mask.len() != data.len().div_ceil(8)`.
///
/// # Safety
///
/// The caller must ensure the required instruction sets are available: AVX-512F for 32-bit and
/// 64-bit element types, and additionally AVX-512VBMI2 for 8-bit and 16-bit element types.
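///
/// # Example
///
/// A minimal sketch, guarded by runtime feature detection (illustrative; relies on `i32`
/// implementing [`SimdCompress`]):
///
/// ```ignore
/// if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512vbmi2") {
///     let mut data: Vec<i32> = (0..16).collect();
///     let mask = vec![0x0F, 0xF0]; // keep elements 0..=3 and 12..=15
///     let kept = unsafe { filter_in_place_avx512(&mut data, &mask) };
///     assert_eq!(kept, 8);
///     assert_eq!(&data[..kept], &[0, 1, 2, 3, 12, 13, 14, 15]);
/// }
/// ```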
#[inline]
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,avx512vbmi2,popcnt")]
pub unsafe fn filter_in_place_avx512<T: SimdCompress>(data: &mut [T], mask: &[u8]) -> usize {
    assert_eq!(
        mask.len(),
        data.len().div_ceil(8),
        "Mask length must be data.len().div_ceil(8)"
    );

    let data_len = data.len();
    let mut write_pos = 0;

    // Pre-calculate loop bounds to eliminate branch misprediction in the hot loop.
    let full_chunks = data_len / T::ELEMENTS_PER_VECTOR;
    let remainder = data_len % T::ELEMENTS_PER_VECTOR;

    // Process full chunks with no branches in the loop.
    for chunk_idx in 0..full_chunks {
        let read_pos = chunk_idx * T::ELEMENTS_PER_VECTOR;
        let mask_byte_offset = chunk_idx * T::MASK_BYTES;

        // Read the mask for this chunk.
        // SAFETY: `mask_byte_offset + T::MASK_BYTES <= mask.len()` for all full chunks.
        let mask_value = unsafe { T::read_mask(mask.as_ptr(), mask_byte_offset) };

        // Load elements into the SIMD register.
        // SAFETY: `read_pos + T::ELEMENTS_PER_VECTOR <= data.len()` for all full chunks.
        let vector = unsafe { _mm512_loadu_si512(data.as_ptr().add(read_pos) as *const __m512i) };

        // Moves all elements that have their bit set to 1 in the mask value to the left.
        let filtered = unsafe { T::compress_vector(mask_value, vector) };
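        // For illustration (hypothetical 32-bit lanes): with mask_value = 0b0101 and lanes
        // [a, b, c, d, ...], `filtered` starts with [a, c, ...]; whatever fills the trailing
        // lanes is irrelevant, since only `count` elements are accounted for below.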

        // Write the filtered result vector back to memory. The store writes a full vector, but
        // because `write_pos <= read_pos`, the extra lanes never clobber data that has not yet
        // been read; they are either overwritten by later stores or lie past the final length.
        // SAFETY: `write_pos <= read_pos`, so `write_pos + T::ELEMENTS_PER_VECTOR <= data.len()`
        // for all full chunks.
        unsafe { _mm512_storeu_si512(data.as_mut_ptr().add(write_pos) as *mut __m512i, filtered) };

        // Uses the hardware `popcnt` instruction (enabled via `target_feature` above).
        let count = T::count_ones(mask_value);
        write_pos += count;
    }

    // Handle the final partial chunk with simple scalar processing.
    let read_pos = full_chunks * T::ELEMENTS_PER_VECTOR;
    for i in 0..remainder {
        let read_idx = read_pos + i;
        let bit_idx = read_idx % 8;
        let byte_idx = read_idx / 8;

        if (mask[byte_idx] >> bit_idx) & 1 == 1 {
            data[write_pos] = data[read_idx];
            write_pos += 1;
        }
    }

    write_pos
}

#[cfg(test)]
mod tests {
    use super::*;

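    /// Builds an LSB-first byte mask from a slice of booleans, e.g.
    /// `create_mask(&[true, false, false, true])` yields `vec![0b0000_1001]`.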
    fn create_mask(bits: &[bool]) -> Vec<u8> {
        let mut mask = vec![0u8; bits.len().div_ceil(8)];
        for (i, &bit) in bits.iter().enumerate() {
            if bit {
                mask[i / 8] |= 1 << (i % 8);
            }
        }
        mask
    }

    fn test_implementation<F>(filter_fn: F)
    where
        F: Fn(&mut [i32], &[u8]) -> usize,
    {
        // Test 1: Small array - all elements pass
        let mut data = vec![0, 1, 2, 3, 4, 5, 6, 7];
        let mask = vec![0xFF]; // All 1s
        let count = filter_fn(&mut data, &mask);
        assert_eq!(count, 8);
        assert_eq!(&data[..8], &[0, 1, 2, 3, 4, 5, 6, 7]);

        // Test 2: Small array - no elements pass
        let mut data = vec![0, 1, 2, 3, 4, 5, 6, 7];
        let mask = vec![0x00]; // All 0s
        let count = filter_fn(&mut data, &mask);
        assert_eq!(count, 0);

        // Test 3: Small array - every other element
        let mut data = vec![0, 1, 2, 3, 4, 5, 6, 7];
        let mask = vec![0x55]; // 01010101
        let count = filter_fn(&mut data, &mask);
        assert_eq!(count, 4);
        assert_eq!(&data[..4], &[0, 2, 4, 6]);

        // Test 4: 16 elements - all pass
        let mut data: Vec<i32> = (0..16).collect();
        let mask = vec![0xFF, 0xFF];
        let count = filter_fn(&mut data, &mask);
        assert_eq!(count, 16);
        assert_eq!(
            &data[..16],
            &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
        );

        // Test 5: 16 elements - alternating pattern
        let mut data: Vec<i32> = (0..16).collect();
        let mask = vec![0xAA, 0xAA]; // 10101010 10101010
        let count = filter_fn(&mut data, &mask);
        assert_eq!(count, 8);
        assert_eq!(&data[..8], &[1, 3, 5, 7, 9, 11, 13, 15]);

        // Test 6: Larger array (32 elements)
        let mut data: Vec<i32> = (0..32).collect();
        let mask = vec![0xFF, 0x00, 0xFF, 0x00]; // First and third bytes
        let count = filter_fn(&mut data, &mask);
        assert_eq!(count, 16);
        assert_eq!(
            &data[..16],
            &[0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23]
        );

        // Test 7: Non-aligned size (23 elements)
        let mut data: Vec<i32> = (0..23).collect();
        let mask = create_mask(&[
            true, false, true, false, true, false, true, false, // byte 0
            false, true, false, true, false, true, false, true, // byte 1
            true, true, false, false, true, true, false, // byte 2 (partial)
        ]);
        let count = filter_fn(&mut data, &mask);
        assert_eq!(count, 12);
        assert_eq!(&data[..12], &[0, 2, 4, 6, 9, 11, 13, 15, 16, 17, 20, 21]);
    }

    #[test]
    fn test_scalar() {
        test_implementation(filter_in_place_scalar::<i32>);
    }

    #[test]
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn test_avx512() {
        // Skip when the CPU lacks the required AVX-512 features; calling the function there
        // would violate its safety contract.
        if !is_x86_feature_detected!("avx512f") || !is_x86_feature_detected!("avx512vbmi2") {
            return;
        }
        test_implementation(|data, mask| unsafe { filter_in_place_avx512::<i32>(data, mask) });
    }

    #[test]
    fn test_runtime_dispatch() {
        test_implementation(filter_in_place::<i32>);
    }

    #[test]
    fn test_all_implementations_match() {
        // Test that all available implementations produce the same results

        // Test various sizes and patterns
        let test_cases = vec![
            (8, vec![0xAA]),                    // 8 elements, alternating
            (16, vec![0xFF, 0xFF]),             // 16 elements, all pass
            (16, vec![0x00, 0x00]),             // 16 elements, none pass
            (32, vec![0x55, 0x55, 0x55, 0x55]), // 32 elements, alternating
            (24, vec![0xFF, 0x00, 0xFF]),       // 24 elements, mixed
            (100, vec![0xFF; 13]),              // 100 elements (needs 13 bytes)
        ];

        for (size, mask) in test_cases {
            let mut data_scalar: Vec<i32> = (0..size).collect();

            let count_scalar = filter_in_place_scalar::<i32>(&mut data_scalar, &mask);

            // Compare against the AVX-512 implementation on x86/x86_64 when the CPU supports it.
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            {
                if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512vbmi2") {
                    let mut data_avx512: Vec<i32> = (0..size).collect();
                    let count_avx512 =
                        unsafe { filter_in_place_avx512::<i32>(&mut data_avx512, &mask) };
                    assert_eq!(
                        count_scalar, count_avx512,
                        "Count mismatch for size {}",
                        size
                    );
                    assert_eq!(
                        &data_scalar[..count_scalar],
                        &data_avx512[..count_avx512],
                        "Data mismatch for size {}",
                        size
                    );
                }
            }
        }
    }

    #[expect(clippy::cast_possible_truncation)]
    #[test]
    fn test_large_arrays() {
        // Test with very large arrays to ensure chunking works correctly
        let sizes: Vec<usize> = vec![1024, 1000, 2048, 4096, 10000];

        for size in sizes {
            let mut data: Vec<i32> = (0..size as i32).collect();
            // Create alternating mask
            let mut mask = vec![0u8; size.div_ceil(8)];
            mask.fill(0x55); // 01010101

            let count = filter_in_place::<i32>(&mut data, &mask);
            assert_eq!(count, size / 2);

            // Verify the first few elements.
            (0..10.min(count)).for_each(|i| {
                assert_eq!(data[i], (i * 2) as i32);
            });
        }
    }
}