Skip to content

Commit d7a409a

Browse files
Add AVX-512BW path for ascii_lower, gated on avx512fp16
Adds a 64-byte-wide back-end (`_mm512_cmpgt_epi8_mask` + `_mm512_movm_epi8`) to `utils::simd::ascii_lower`. The dispatcher only routes to it when `avx512f`, `avx512bw`, and `avx512fp16` are all detected. `avx512fp16` is used as a proxy for "AVX-512 without meaningful license-mode downclock": - present on Intel Sapphire Rapids / Emerald Rapids / Granite Rapids - absent on AMD Zen 4 (Ryzen 7000 / EPYC Genoa) and Zen 5 (Turin), which implement AVX-512 without FP16 and therefore also stay on the AVX2 path under this gate - absent on Skylake-X, Cascade Lake, Cooper Lake, Ice Lake-SP, Rocket Lake — the generations where 512-bit ops cause measurable frequency throttling Older AVX-512-capable hardware therefore stays on the AVX2 path, where the 256-bit work is already memory-bandwidth-bound on long buffers. Verified with `cargo check --target x86_64-unknown-linux-gnu` plus the existing 207-test lib suite on aarch64. The AVX-512 path itself is exercised at runtime only on capable hosts. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 906ceca commit d7a409a

1 file changed

Lines changed: 41 additions & 0 deletions

File tree

tokenizers/src/utils/simd.rs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,18 @@
1212
pub fn ascii_lower(buf: &mut [u8]) {
1313
#[cfg(target_arch = "x86_64")]
1414
{
15+
// Only enable the AVX-512 path on CPUs where the AVX-512 license-mode
16+
// downclock is negligible. `avx512fp16` is used as a proxy: it is
17+
// present on Intel Sapphire Rapids+ (AMD Zen 4/Zen 5 lack it), and is
18+
// absent on Skylake-X / Cascade Lake / Cooper Lake / Ice Lake-SP /
19+
// Rocket Lake — i.e. exactly the generations where 512-bit ops trigger
20+
// significant frequency throttling.
21+
if std::is_x86_feature_detected!("avx512f")
22+
&& std::is_x86_feature_detected!("avx512bw")
23+
&& std::is_x86_feature_detected!("avx512fp16")
24+
{
25+
unsafe { return ascii_lower_avx512(buf) };
26+
}
1527
if std::is_x86_feature_detected!("avx2") {
1628
unsafe { return ascii_lower_avx2(buf) };
1729
}
@@ -38,6 +50,35 @@ fn ascii_lower_scalar(buf: &mut [u8]) {
3850
}
3951
}
4052

53+
/// Lowercases ASCII `A`–`Z` in `buf` in place, 64 bytes per iteration,
/// using AVX-512BW byte compares that produce a 64-bit lane mask.
///
/// Bytes outside `A..=Z` are left unchanged by the vector loop. The final
/// `buf.len() % 64` bytes are handed to `ascii_lower_scalar`.
///
/// # Safety
///
/// This function is compiled with `avx512f` and `avx512bw` enabled; the
/// caller must ensure the running CPU supports both features.
#[cfg(target_arch = "x86_64")]
54+
#[target_feature(enable = "avx512f,avx512bw")]
55+
unsafe fn ascii_lower_avx512(buf: &mut [u8]) {
56+
use std::arch::x86_64::*;
57+
58+
let a_minus_1 = _mm512_set1_epi8(b'A' as i8 - 1);
59+
let z_plus_1 = _mm512_set1_epi8(b'Z' as i8 + 1);
60+
let case_bit = _mm512_set1_epi8(0x20);
61+
62+
let len = buf.len();
63+
let mut i = 0;
64+
while i + 64 <= len {
65+
let p = buf.as_mut_ptr().add(i) as *mut __m512i;
66+
let v = _mm512_loadu_si512(p as *const __m512i);
67+
// `__mmask64` is `u64` in Rust; the two range checks AND together as a
68+
// plain bitwise op. Signed compares are safe here: bytes >= 0x80 read
69+
// as negative, fail the `> 'A' - 1` test, and are left untouched.
70+
let gt_a = _mm512_cmpgt_epi8_mask(v, a_minus_1);
71+
let lt_z = _mm512_cmpgt_epi8_mask(z_plus_1, v);
72+
let mask = gt_a & lt_z;
73+
// Expand the bit mask to one all-ones/all-zeros byte per lane.
let mask_vec = _mm512_movm_epi8(mask);
74+
let flip = _mm512_and_si512(mask_vec, case_bit);
75+
// `A`..`Z` (0x41..=0x5A) have bit 0x20 clear, so XOR sets it.
let out = _mm512_xor_si512(v, flip);
76+
_mm512_storeu_si512(p, out);
77+
i += 64;
78+
}
79+
ascii_lower_scalar(&mut buf[i..]);
80+
}
81+
4182
#[cfg(target_arch = "x86_64")]
4283
#[target_feature(enable = "avx2")]
4384
unsafe fn ascii_lower_avx2(buf: &mut [u8]) {

0 commit comments

Comments
 (0)