//! SIMD helpers for ASCII fast paths.
//!
//! Each public function dispatches at runtime (via `is_x86_feature_detected!`
//! on x86_64; aarch64 always has NEON under the stable target ABI) and falls
//! back to a scalar implementation on other architectures.

| 7 | +/// Lowercase ASCII letters (`A`..=`Z` → `a`..=`z`) in place. Bytes outside that |
| 8 | +/// range are left untouched, so it is safe to call on any byte slice — but for |
| 9 | +/// best speed it should be guarded by a `is_ascii()` check upstream so the |
| 10 | +/// caller can also skip Unicode-aware logic. |
| 11 | +#[inline] |
| 12 | +pub fn ascii_lower(buf: &mut [u8]) { |
| 13 | + #[cfg(target_arch = "x86_64")] |
| 14 | + { |
| 15 | + if std::is_x86_feature_detected!("avx2") { |
| 16 | + unsafe { return ascii_lower_avx2(buf) }; |
| 17 | + } |
| 18 | + // SSE2 is part of the x86_64 baseline; always available. |
| 19 | + unsafe { return ascii_lower_sse2(buf) }; |
| 20 | + } |
| 21 | + #[cfg(target_arch = "aarch64")] |
| 22 | + { |
| 23 | + // NEON is mandatory on the stable aarch64 ABI; no runtime check needed. |
| 24 | + unsafe { return ascii_lower_neon(buf) }; |
| 25 | + } |
| 26 | + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] |
| 27 | + { |
| 28 | + ascii_lower_scalar(buf); |
| 29 | + } |
| 30 | +} |
/// Byte-at-a-time fallback: used on non-SIMD targets and by the vector
/// kernels to finish sub-block tails.
#[inline(always)]
fn ascii_lower_scalar(buf: &mut [u8]) {
    // `u8::make_ascii_lowercase` performs the identical `A..=Z` -> `a..=z`
    // mapping and leaves every other byte untouched.
    buf.iter_mut().for_each(|b| b.make_ascii_lowercase());
}
| 41 | +#[cfg(target_arch = "x86_64")] |
| 42 | +#[target_feature(enable = "avx2")] |
| 43 | +unsafe fn ascii_lower_avx2(buf: &mut [u8]) { |
| 44 | + use std::arch::x86_64::*; |
| 45 | + |
| 46 | + let a_minus_1 = _mm256_set1_epi8(b'A' as i8 - 1); // 0x40 |
| 47 | + let z_plus_1 = _mm256_set1_epi8(b'Z' as i8 + 1); // 0x5B |
| 48 | + let case_bit = _mm256_set1_epi8(0x20); |
| 49 | + |
| 50 | + let len = buf.len(); |
| 51 | + let mut i = 0; |
| 52 | + while i + 32 <= len { |
| 53 | + let p = buf.as_mut_ptr().add(i) as *mut __m256i; |
| 54 | + let v = _mm256_loadu_si256(p as *const __m256i); |
| 55 | + // Signed compares are correct here because all uppercase ASCII bytes |
| 56 | + // are < 0x80; bytes >= 0x80 appear negative and are excluded from the mask. |
| 57 | + let gt_a = _mm256_cmpgt_epi8(v, a_minus_1); // v > 0x40 |
| 58 | + let lt_z = _mm256_cmpgt_epi8(z_plus_1, v); // 0x5B > v |
| 59 | + let mask = _mm256_and_si256(gt_a, lt_z); |
| 60 | + let flip = _mm256_and_si256(mask, case_bit); |
| 61 | + let out = _mm256_xor_si256(v, flip); |
| 62 | + _mm256_storeu_si256(p, out); |
| 63 | + i += 32; |
| 64 | + } |
| 65 | + // Scalar tail (also covers buffers shorter than 32 bytes). |
| 66 | + ascii_lower_scalar(&mut buf[i..]); |
| 67 | +} |
| 69 | +#[cfg(target_arch = "x86_64")] |
| 70 | +#[target_feature(enable = "sse2")] |
| 71 | +unsafe fn ascii_lower_sse2(buf: &mut [u8]) { |
| 72 | + use std::arch::x86_64::*; |
| 73 | + |
| 74 | + let a_minus_1 = _mm_set1_epi8(b'A' as i8 - 1); |
| 75 | + let z_plus_1 = _mm_set1_epi8(b'Z' as i8 + 1); |
| 76 | + let case_bit = _mm_set1_epi8(0x20); |
| 77 | + |
| 78 | + let len = buf.len(); |
| 79 | + let mut i = 0; |
| 80 | + while i + 16 <= len { |
| 81 | + let p = buf.as_mut_ptr().add(i) as *mut __m128i; |
| 82 | + let v = _mm_loadu_si128(p as *const __m128i); |
| 83 | + let gt_a = _mm_cmpgt_epi8(v, a_minus_1); |
| 84 | + let lt_z = _mm_cmpgt_epi8(z_plus_1, v); |
| 85 | + let mask = _mm_and_si128(gt_a, lt_z); |
| 86 | + let flip = _mm_and_si128(mask, case_bit); |
| 87 | + let out = _mm_xor_si128(v, flip); |
| 88 | + _mm_storeu_si128(p, out); |
| 89 | + i += 16; |
| 90 | + } |
| 91 | + ascii_lower_scalar(&mut buf[i..]); |
| 92 | +} |
/// NEON kernel: lowercases 16 bytes per iteration, handing any leftover tail
/// (and buffers shorter than one vector) to the scalar loop.
///
/// # Safety
///
/// Requires NEON, which the stable aarch64 target ABI guarantees, so every
/// aarch64 caller may invoke this unconditionally; the `unsafe` covers only
/// the raw loads/stores.
#[cfg(target_arch = "aarch64")]
unsafe fn ascii_lower_neon(buf: &mut [u8]) {
    use std::arch::aarch64::*;

    let below_a = vdupq_n_u8(b'A' - 1); // 0x40
    let above_z = vdupq_n_u8(b'Z' + 1); // 0x5B
    let case_bit = vdupq_n_u8(0x20);

    let mut chunks = buf.chunks_exact_mut(16);
    for chunk in &mut chunks {
        let ptr = chunk.as_mut_ptr();
        let v = vld1q_u8(ptr);
        // NEON provides unsigned byte compares directly, so no signedness
        // trick is needed to exclude bytes >= 0x80.
        let in_range = vandq_u8(
            vcgtq_u8(v, below_a), // v > A-1  ->  v >= A
            vcltq_u8(v, above_z), // v < Z+1  ->  v <= Z
        );
        vst1q_u8(ptr, veorq_u8(v, vandq_u8(in_range, case_bit)));
    }
    // Remaining < 16 bytes.
    ascii_lower_scalar(chunks.into_remainder());
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Oracle: apply the plain scalar loop to a copy of `input`.
    fn scalar_reference(input: &[u8]) -> Vec<u8> {
        let mut copy = input.to_vec();
        ascii_lower_scalar(&mut copy);
        copy
    }

    /// The runtime-dispatched path and the scalar oracle must agree.
    fn assert_matches_scalar(input: &[u8]) {
        let expected = scalar_reference(input);
        let mut actual = input.to_vec();
        ascii_lower(&mut actual);
        assert_eq!(actual, expected, "len={}", input.len());
    }

    #[test]
    fn empty() {
        let mut buf: [u8; 0] = [];
        ascii_lower(&mut buf);
    }

    #[test]
    fn matches_scalar_on_random_ascii() {
        // 200 bytes cycling through printable ASCII (0x20..=0x7E): a mix of
        // upper, lower, digits, and symbols, long enough to exercise
        // sub-block, exact-block, and tail paths for both 16- and 32-byte
        // SIMD widths.
        let data: Vec<u8> = (0..200u32).map(|i| 0x20 + (i as u8 % 0x5F)).collect();
        assert_matches_scalar(&data);
    }

    #[test]
    fn matches_scalar_at_critical_lengths() {
        let lengths = [0, 1, 7, 15, 16, 17, 31, 32, 33, 47, 48, 63, 64, 65, 128, 129];
        for len in lengths {
            let data: Vec<u8> = (0..len as u8).map(|i| b'A' + (i % 26)).collect();
            assert_matches_scalar(&data);
        }
    }

    #[test]
    fn leaves_high_bytes_untouched() {
        // SIMD masks must exclude bytes >= 0x80 (UTF-8 continuation bytes) —
        // defensive even though the gate is meant to filter these out.
        let seed = [
            b'A', b'b', 0xC3, 0xA9, b'Z', 0xE2, 0x82, 0xAC, b'Q', 0x80, 0xFF,
        ];
        // Repeat so the pattern crosses SIMD block boundaries.
        assert_matches_scalar(&seed.repeat(8));
    }

    #[test]
    fn idempotent() {
        let mut data = b"Hello, World! THE QUICK BROWN FOX 1234 JUMPS OVER 0 LAZY DOGS.".to_vec();
        ascii_lower(&mut data);
        let once = data.clone();
        ascii_lower(&mut data);
        assert_eq!(data, once);
    }
}