
Commit 906ceca

SIMD ASCII fast path for Lowercase normalizer
Adds `utils::simd::ascii_lower`, a runtime-dispatched in-place ASCII lowercaser with AVX2, SSE2, NEON, and scalar back-ends, and gates `NormalizedString::lowercase` on `is_ascii()` so all-ASCII inputs skip the per-`char` Unicode case-folding loop and the alignments rebuild.

For ASCII inputs the slow path produced byte-identical output and byte-identical alignments (each `char::to_lowercase()` of an ASCII char is a single ASCII char, so the `transform` rebuild was a no-op on alignments); the fast path therefore just flips the `0x20` bit on bytes in `A`..=`Z` in place.

Two new unit tests in `normalizer.rs` lock that equivalence in:

- `lowercase_ascii_fast_path_preserves_alignments` — checks that after a non-trivial NFKD transform the fast path leaves `alignments`, `original`, and `original_shift` unchanged.
- `lowercase_ascii_matches_unicode_path_byte_for_byte` — checks every printable ASCII byte against `char::to_lowercase`.

Microbench (`benches/ascii_lower_benchmark.rs`, Apple Silicon, NEON):

| size    | SIMD       | scalar (auto-vec) | unicode chars (old) |
|---------|------------|-------------------|---------------------|
| 64 B    | 30.2 GiB/s | 5.9 GiB/s         | 1.0 GiB/s           |
| 1 KiB   | 55.4 GiB/s | 6.3 GiB/s         | 1.2 GiB/s           |
| 16 KiB  | 58.4 GiB/s | 6.3 GiB/s         | 1.2 GiB/s           |
| 256 KiB | 57.9 GiB/s | 6.3 GiB/s         | 1.2 GiB/s           |

i.e. ~30-49x over the previous Unicode path on real-text-like ASCII.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
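The whole fast path rests on one ASCII-table fact: setting bit `0x20` on a byte in `A`..=`Z` yields the corresponding `a`..=`z`, and `char::to_lowercase` on an ASCII char never yields anything but that single ASCII char. A minimal standalone sketch of that equivalence (illustration only, not part of this commit's diff):

```rust
fn main() {
    for b in 0u8..=0x7F {
        // Byte-level trick the fast path uses.
        let fast = if b.is_ascii_uppercase() { b | 0x20 } else { b };
        // Per-`char` Unicode folding the old path used.
        let mut folded = (b as char).to_lowercase();
        let slow = folded.next().unwrap();
        // ASCII never folds to more than one char, which is why the
        // alignment rebuild the commit skips was a no-op here.
        assert!(folded.next().is_none());
        assert_eq!(fast as char, slow);
    }
}
```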
1 parent 3992692 commit 906ceca

5 files changed

Lines changed: 315 additions & 0 deletions


tokenizers/Cargo.toml

Lines changed: 4 additions & 0 deletions
@@ -66,6 +66,10 @@ harness = false
 name = "ci_benchmark"
 harness = false
 
+[[bench]]
+name = "ascii_lower_benchmark"
+harness = false
+
 [dependencies]
 rand = "0.9"
 onig = { version = "6.5.1", default-features = false, optional = true }
tokenizers/benches/ascii_lower_benchmark.rs

Lines changed: 77 additions & 0 deletions

@@ -0,0 +1,77 @@
//! Microbenchmark for the ASCII lowercase fast path used by the `Lowercase`
//! normalizer. Compares the SIMD-dispatched `utils::simd::ascii_lower` against
//! the scalar reference and against the previous Unicode-aware path
//! (`char::to_lowercase` per char) on representative buffer sizes.

#[macro_use]
extern crate criterion;

use criterion::{Criterion, Throughput};
use std::hint::black_box;
use tokenizers::utils::simd::ascii_lower;

fn make_buffer(len: usize) -> Vec<u8> {
    // Cycle through printable ASCII (mix of upper, lower, digits, punctuation)
    // so the upper-case branch fires on roughly 26/95 of bytes.
    (0..len).map(|i| 0x20u8 + (i as u8 % 0x5F)).collect()
}

fn scalar_lower(buf: &mut [u8]) {
    for b in buf {
        if b.is_ascii_uppercase() {
            *b |= 0x20;
        }
    }
}

fn unicode_lower(buf: &str) -> String {
    // Mirrors what `NormalizedString::lowercase` did before the fast path: per
    // `char` UTF-8 decode + Unicode case folding, ignoring alignment bookkeeping.
    let mut out = String::with_capacity(buf.len());
    for c in buf.chars() {
        for lc in c.to_lowercase() {
            out.push(lc);
        }
    }
    out
}

pub fn bench_ascii_lower(c: &mut Criterion) {
    for &len in &[64usize, 1024, 16 * 1024, 256 * 1024] {
        let mut group = c.benchmark_group(format!("ascii_lower/{len}B"));
        group.throughput(Throughput::Bytes(len as u64));

        let original = make_buffer(len);

        group.bench_function("simd", |b| {
            let mut buf = original.clone();
            b.iter(|| {
                ascii_lower(black_box(&mut buf));
            });
        });

        group.bench_function("scalar", |b| {
            let mut buf = original.clone();
            b.iter(|| {
                scalar_lower(black_box(&mut buf));
            });
        });

        // Stand-in for the pre-SIMD code path.
        let s = String::from_utf8(original.clone()).unwrap();
        group.bench_function("unicode_chars", |b| {
            b.iter(|| {
                black_box(unicode_lower(black_box(&s)));
            });
        });

        group.finish();
    }
}

criterion_group! {
    name = ascii_lower_bench;
    config = Criterion::default().sample_size(50);
    targets = bench_ascii_lower
}
criterion_main!(ascii_lower_bench);
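With the `[[bench]]` entry registered in `Cargo.toml` above, `cargo bench --bench ascii_lower_benchmark` (run from the `tokenizers` crate directory) executes just this harness; the `Throughput::Bytes` setting is what makes Criterion report bytes-per-second rates, which is where the GiB/s figures in the commit message come from.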

tokenizers/src/tokenizer/normalizer.rs

Lines changed: 49 additions & 0 deletions
@@ -544,6 +544,16 @@ impl NormalizedString
 
     /// Lowercase
     pub fn lowercase(&mut self) -> &mut Self {
+        // ASCII fast path: each `A`..=`Z` becomes a single-byte `a`..=`z`,
+        // so byte length and per-byte alignments are unchanged. We can mutate
+        // bytes in place and skip the Unicode-aware `transform` rebuild.
+        if self.normalized.is_ascii() {
+            // Safety: `ascii_lower` only flips `0x20` on bytes already in
+            // `b'A'..=b'Z'` (all < 0x80), so the result remains valid UTF-8.
+            let bytes = unsafe { self.normalized.as_bytes_mut() };
+            crate::utils::simd::ascii_lower(bytes);
+            return self;
+        }
         let mut new_chars: Vec<(char, isize)> = vec![];
         self.for_each(|c| {
             c.to_lowercase().enumerate().for_each(|(index, c)| {
@@ -2289,6 +2299,45 @@ mod tests {
         assert_eq!(s.get(), "a...");
     }
 
+    #[test]
+    fn lowercase_ascii_fast_path_preserves_alignments() {
+        // After a non-trivial transform (here NFKD on a ligature) the alignments
+        // map several normalized bytes back onto fewer original bytes. The ASCII
+        // fast path must leave that mapping byte-for-byte unchanged.
+        let mut n = NormalizedString::from("ABC\u{FB00}DEF"); // "ABCﬀDEF"; ﬀ -> "ff" via NFKD
+        n.nfkd();
+        // Sanity: result is now all ASCII so the fast path will trigger.
+        assert!(n.get().is_ascii());
+
+        let bytes_before = n.normalized.clone();
+        let alignments_before = n.alignments.clone();
+        let original_before = n.original.clone();
+        let shift_before = n.original_shift;
+
+        n.lowercase();
+
+        assert_eq!(
+            n.normalized,
+            bytes_before.to_lowercase(),
+            "bytes mismatch fast vs ASCII to_lowercase"
+        );
+        assert_eq!(n.alignments, alignments_before, "alignments mutated");
+        assert_eq!(n.original, original_before, "original mutated");
+        assert_eq!(n.original_shift, shift_before, "original_shift mutated");
+    }
+
+    #[test]
+    fn lowercase_ascii_matches_unicode_path_byte_for_byte() {
+        // Cross-check against char::to_lowercase on every printable ASCII byte:
+        // the fast path must produce exactly the same bytes the slow path would.
+        let input: String = (0x20u8..0x7F).map(|b| b as char).collect();
+        let mut fast = NormalizedString::from(input.as_str());
+        fast.lowercase();
+
+        let expected: String = input.chars().flat_map(|c| c.to_lowercase()).collect();
+        assert_eq!(fast.get(), expected);
+    }
+
     #[test]
     fn test_append_after_clear() {
         let mut n = NormalizedString::from("Hello");
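To see both paths from the caller's side, a hedged usage sketch (assuming `NormalizedString` is re-exported at the crate root, as in current `tokenizers`; the import path may differ by version):

```rust
use tokenizers::NormalizedString;

fn main() {
    // All-ASCII input: `is_ascii()` holds, so `lowercase` takes the new
    // in-place SIMD path and never touches the alignment machinery.
    let mut ascii = NormalizedString::from("MIXED Case 123!");
    ascii.lowercase();
    assert_eq!(ascii.get(), "mixed case 123!");

    // A single non-ASCII char sends the whole string down the original
    // per-`char` Unicode path, which handles folds like Ä -> ä that are
    // not expressible as a single-byte tweak.
    let mut unicode = NormalizedString::from("ÄRGER Straße");
    unicode.lowercase();
    assert_eq!(unicode.get(), "ärger straße");
}
```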

tokenizers/src/utils/mod.rs

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ pub mod iter;
 pub mod padding;
 pub mod parallelism;
 pub(crate) mod progress;
+pub mod simd;
 pub mod truncation;
 
 // Re-export ProgressFormat for public API

tokenizers/src/utils/simd.rs

Lines changed: 184 additions & 0 deletions
@@ -0,0 +1,184 @@
//! SIMD helpers for ASCII fast paths.
//!
//! Each public function dispatches at runtime (via `is_x86_feature_detected!`
//! on x86_64; aarch64 always has NEON under the stable target ABI) and falls
//! back to a scalar implementation on other architectures.

/// Lowercase ASCII letters (`A`..=`Z` → `a`..=`z`) in place. Bytes outside that
/// range are left untouched, so it is safe to call on any byte slice — but for
/// best speed it should be guarded by an `is_ascii()` check upstream so the
/// caller can also skip Unicode-aware logic.
#[inline]
pub fn ascii_lower(buf: &mut [u8]) {
    #[cfg(target_arch = "x86_64")]
    {
        if std::is_x86_feature_detected!("avx2") {
            unsafe { return ascii_lower_avx2(buf) };
        }
        // SSE2 is part of the x86_64 baseline; always available.
        unsafe { return ascii_lower_sse2(buf) };
    }
    #[cfg(target_arch = "aarch64")]
    {
        // NEON is mandatory on the stable aarch64 ABI; no runtime check needed.
        unsafe { return ascii_lower_neon(buf) };
    }
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        ascii_lower_scalar(buf);
    }
}

#[inline(always)]
fn ascii_lower_scalar(buf: &mut [u8]) {
    for b in buf {
        if b.is_ascii_uppercase() {
            *b |= 0x20;
        }
    }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn ascii_lower_avx2(buf: &mut [u8]) {
    use std::arch::x86_64::*;

    let a_minus_1 = _mm256_set1_epi8(b'A' as i8 - 1); // 0x40
    let z_plus_1 = _mm256_set1_epi8(b'Z' as i8 + 1); // 0x5B
    let case_bit = _mm256_set1_epi8(0x20);

    let len = buf.len();
    let mut i = 0;
    while i + 32 <= len {
        let p = buf.as_mut_ptr().add(i) as *mut __m256i;
        let v = _mm256_loadu_si256(p as *const __m256i);
        // Signed compares are correct here because all uppercase ASCII bytes
        // are < 0x80; bytes >= 0x80 appear negative and are excluded from the mask.
        let gt_a = _mm256_cmpgt_epi8(v, a_minus_1); // v > 0x40
        let lt_z = _mm256_cmpgt_epi8(z_plus_1, v); // 0x5B > v
        let mask = _mm256_and_si256(gt_a, lt_z);
        let flip = _mm256_and_si256(mask, case_bit);
        let out = _mm256_xor_si256(v, flip);
        _mm256_storeu_si256(p, out);
        i += 32;
    }
    // Scalar tail (also covers buffers shorter than 32 bytes).
    ascii_lower_scalar(&mut buf[i..]);
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
unsafe fn ascii_lower_sse2(buf: &mut [u8]) {
    use std::arch::x86_64::*;

    let a_minus_1 = _mm_set1_epi8(b'A' as i8 - 1);
    let z_plus_1 = _mm_set1_epi8(b'Z' as i8 + 1);
    let case_bit = _mm_set1_epi8(0x20);

    let len = buf.len();
    let mut i = 0;
    while i + 16 <= len {
        let p = buf.as_mut_ptr().add(i) as *mut __m128i;
        let v = _mm_loadu_si128(p as *const __m128i);
        let gt_a = _mm_cmpgt_epi8(v, a_minus_1);
        let lt_z = _mm_cmpgt_epi8(z_plus_1, v);
        let mask = _mm_and_si128(gt_a, lt_z);
        let flip = _mm_and_si128(mask, case_bit);
        let out = _mm_xor_si128(v, flip);
        _mm_storeu_si128(p, out);
        i += 16;
    }
    ascii_lower_scalar(&mut buf[i..]);
}

#[cfg(target_arch = "aarch64")]
unsafe fn ascii_lower_neon(buf: &mut [u8]) {
    use std::arch::aarch64::*;

    let a_minus_1 = vdupq_n_u8(b'A' - 1);
    let z_plus_1 = vdupq_n_u8(b'Z' + 1);
    let case_bit = vdupq_n_u8(0x20);

    let len = buf.len();
    let mut i = 0;
    while i + 16 <= len {
        let p = buf.as_mut_ptr().add(i);
        let v = vld1q_u8(p);
        // Unsigned compares on aarch64 — directly available.
        let gt_a = vcgtq_u8(v, a_minus_1); // v > A-1 → v >= A
        let lt_z = vcltq_u8(v, z_plus_1); // v < Z+1 → v <= Z
        let mask = vandq_u8(gt_a, lt_z);
        let flip = vandq_u8(mask, case_bit);
        let out = veorq_u8(v, flip);
        vst1q_u8(p, out);
        i += 16;
    }
    ascii_lower_scalar(&mut buf[i..]);
}

#[cfg(test)]
mod tests {
    use super::*;

    fn scalar_reference(input: &[u8]) -> Vec<u8> {
        let mut out = input.to_vec();
        ascii_lower_scalar(&mut out);
        out
    }

    #[test]
    fn empty() {
        let mut buf: [u8; 0] = [];
        ascii_lower(&mut buf);
    }

    #[test]
    fn matches_scalar_on_random_ascii() {
        // Mix of upper, lower, digits, symbols, long enough to cover
        // sub-block, exact-block, and post-block tails for both 16- and
        // 32-byte SIMD widths.
        let mut data: Vec<u8> = (0..200u32)
            .map(|i| {
                let c = i as u8;
                // Cycle through printable ASCII.
                0x20 + (c % 0x5F)
            })
            .collect();
        let expected = scalar_reference(&data);
        ascii_lower(&mut data);
        assert_eq!(data, expected);
    }

    #[test]
    fn matches_scalar_at_critical_lengths() {
        for len in [0, 1, 7, 15, 16, 17, 31, 32, 33, 47, 48, 63, 64, 65, 128, 129] {
            let mut data: Vec<u8> = (0..len as u8).map(|i| b'A' + (i % 26)).collect();
            let expected = scalar_reference(&data);
            ascii_lower(&mut data);
            assert_eq!(data, expected, "len={len}");
        }
    }

    #[test]
    fn leaves_high_bytes_untouched() {
        // Ensures SIMD masks correctly exclude bytes >= 0x80 (UTF-8 continuation
        // bytes) — defensive even though the gate is meant to filter these out.
        let seed: Vec<u8> = vec![
            b'A', b'b', 0xC3, 0xA9, b'Z', 0xE2, 0x82, 0xAC, b'Q', 0x80, 0xFF,
        ];
        // Repeat to cross SIMD block boundaries.
        let mut data = seed.repeat(8);
        let expected = scalar_reference(&data);
        ascii_lower(&mut data);
        assert_eq!(data, expected);
    }

    #[test]
    fn idempotent() {
        let mut data = b"Hello, World! THE QUICK BROWN FOX 1234 JUMPS OVER 0 LAZY DOGS.".to_vec();
        ascii_lower(&mut data);
        let once = data.clone();
        ascii_lower(&mut data);
        assert_eq!(data, once);
    }
}
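One design note on the kernels above: they apply the case bit with XOR under the computed mask, but OR would behave identically here, because the masked bytes are exactly the uppercase ones and those always have `0x20` clear. A scalar sketch (illustration only) checking that, together with the `A-1 < v < Z+1` mask construction the compares implement, over every possible byte value:

```rust
fn main() {
    for b in 0u8..=255 {
        // Mask built like the unsigned NEON compares: A-1 < b < Z+1.
        let mask = if b > b'A' - 1 && b < b'Z' + 1 { 0x20u8 } else { 0x00 };
        let via_xor = b ^ mask; // what the SIMD kernels compute
        let via_or = b | mask; // equivalent: uppercase bytes have 0x20 clear
        let reference = if b.is_ascii_uppercase() { b | 0x20 } else { b };
        assert_eq!(via_xor, reference);
        assert_eq!(via_or, reference);
    }
}
```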
