Skip to content

Commit ef1c44b

Browse files
committed
simd 256 fixes
1 parent a7d93ca commit ef1c44b

File tree

1 file changed

+233
-53
lines changed

1 file changed

+233
-53
lines changed

portable/src/implementation/simd.rs

Lines changed: 233 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use core::simd::{
33
num::{SimdInt, SimdUint},
44
simd_swizzle, u8x16, LaneCount, Simd, SupportedLaneCount,
55
};
6+
use std::simd::u8x32;
67

78
use crate::{basic, compat};
89

@@ -20,7 +21,7 @@ const HAS_FAST_REDUCE_MAX: bool = true;
2021
)))]
2122
const HAS_FAST_REDUCE_MAX: bool = false;
2223

23-
const HAS_FAST_MASKED_LOAD: bool = false; // FIXME avx512, avx2 (?)
24+
const HAS_FAST_MASKED_LOAD: bool = false; // FIXME avx512, avx2 (32-bit chunks only?)
2425

2526
#[repr(C, align(32))]
2627
#[allow(dead_code)] // only used if a 128-bit SIMD implementation is used
@@ -116,12 +117,50 @@ impl SimdInputTrait for SimdInput<16, 4> {
116117
}
117118
}
118119

120+
// `SimdInputTrait` for the 256-bit layout: one 64-byte block is held as two
// 32-lane `u8` vectors in `vals`.
impl SimdInputTrait for SimdInput<32, 2> {
121+
/// Builds the input from exactly 64 bytes: bytes `0..32` go into `vals[0]`
/// and bytes `32..64` into `vals[1]`.
///
/// # Panics
/// Panics if `s.len() != 64`.
#[inline]
122+
fn new(s: &[u8]) -> Self {
123+
assert!(s.len() == 64);
124+
Self {
125+
vals: [u8x32::from_slice(&s[..32]), u8x32::from_slice(&s[32..64])],
126+
}
127+
}
128+
129+
/// Builds the input from `slice` using `load_masked_opt` for each half.
///
/// Up to the first 32 bytes are loaded into `vals[0]`. If any bytes remain
/// after those, up to the next 32 are loaded into `vals[1]`; otherwise
/// `vals[1]` is `u8x32::default()`.
// NOTE(review): presumably callers only pass fewer than 64 bytes here; bytes
// beyond the first 64 would be ignored — confirm against the call sites.
#[inline]
130+
fn new_partial_masked_load(mut slice: &[u8]) -> Self {
131+
let val0 = load_masked_opt(slice);
132+
// Drop the bytes consumed by the first vector (at most 32, fewer if the
// slice is shorter).
slice = &slice[slice.len().min(32)..];
133+
if slice.is_empty() {
134+
// Nothing left for the second vector: skip the second load entirely.
return Self {
135+
vals: [val0, u8x32::default()],
136+
};
137+
}
138+
let val1 = load_masked_opt(slice);
139+
Self { vals: [val0, val1] }
140+
}
141+
142+
/// Builds the input by copying `slice` into a zero-filled 64-byte stack
/// buffer and then loading that buffer with [`Self::new`].
///
/// # Panics
/// Panics if `slice.len() > 64`, because the destination range
/// `buf[..slice.len()]` is then out of bounds.
#[inline]
143+
fn new_partial_copy(slice: &[u8]) -> Self {
144+
let mut buf = [0; 64];
145+
buf[..slice.len()].copy_from_slice(slice);
146+
Self::new(&buf)
147+
}
148+
149+
/// Reports whether the whole block is ASCII by OR-ing both vectors together
/// and running the per-vector `is_ascii` check once on the combined value.
// A set bit in either input lane stays set after the OR, so a single check
// covers both halves.
#[inline]
150+
fn is_ascii(&self) -> bool {
151+
(self.vals[0] | self.vals[1]).is_ascii()
152+
}
153+
}
154+
119155
#[inline]
120-
fn load_masked_opt(slice: &[u8]) -> Simd<u8, 16> {
121-
if slice.len() > 15 {
122-
u8x16::from_slice(&slice[..16])
156+
fn load_masked_opt<const N: usize>(slice: &[u8]) -> Simd<u8, N>
157+
where
158+
LaneCount<N>: SupportedLaneCount,
159+
{
160+
if slice.len() > N - 1 {
161+
Simd::<u8, N>::from_slice(&slice[..N])
123162
} else {
124-
u8x16::load_or_default(slice)
163+
Simd::<u8, N>::load_or_default(slice)
125164
}
126165
}
127166

@@ -134,10 +173,34 @@ where
134173
pub(crate) error: Simd<u8, N>, // FIXME: should be a mask?
135174
}
136175

176+
/// Lookup taking sixteen `u8` values, split out of `SimdU8Value` so that it
/// can be implemented separately from the lane-count-specific methods.
///
/// `lookup_16` consumes `self` together with the values `v0..=v15` and returns
/// a value of the same type.
// NOTE(review): the visible implementation ends in `src.swizzle_dyn(self)`, so
// `self` appears to supply per-lane indices into a table built from
// `v0..=v15`; the construction of `src` is not visible here — confirm.
trait Lookup16 {
177+
#[expect(clippy::too_many_arguments)]
178+
fn lookup_16(
179+
self,
180+
v0: u8,
181+
v1: u8,
182+
v2: u8,
183+
v3: u8,
184+
v4: u8,
185+
v5: u8,
186+
v6: u8,
187+
v7: u8,
188+
v8: u8,
189+
v9: u8,
190+
v10: u8,
191+
v11: u8,
192+
v12: u8,
193+
v13: u8,
194+
v14: u8,
195+
v15: u8,
196+
) -> Self;
197+
}
198+
137199
trait SimdU8Value<const N: usize>
138200
where
139201
LaneCount<N>: SupportedLaneCount,
140202
Self: Copy,
203+
Self: Lookup16,
141204
{
142205
#[expect(clippy::too_many_arguments)]
143206
fn from_32_cut_off_leading(
@@ -195,27 +258,6 @@ where
195258
v15: u8,
196259
) -> Self;
197260

198-
#[expect(clippy::too_many_arguments)]
199-
fn lookup_16(
200-
self,
201-
v0: u8,
202-
v1: u8,
203-
v2: u8,
204-
v3: u8,
205-
v4: u8,
206-
v5: u8,
207-
v6: u8,
208-
v7: u8,
209-
v8: u8,
210-
v9: u8,
211-
v10: u8,
212-
v11: u8,
213-
v12: u8,
214-
v13: u8,
215-
v14: u8,
216-
v15: u8,
217-
) -> Self;
218-
219261
// const generics would be more awkward and verbose with the current
220262
// portable SIMD swizzle implementation and compiler limitations.
221263
fn prev1(self, prev: Self) -> Self;
@@ -290,6 +332,48 @@ impl SimdU8Value<16> for u8x16 {
290332
])
291333
}
292334

335+
#[inline]
336+
fn prev1(self, prev: Self) -> Self {
337+
simd_swizzle!(
338+
self,
339+
prev,
340+
[31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,]
341+
)
342+
}
343+
344+
#[inline]
345+
fn prev2(self, prev: Self) -> Self {
346+
simd_swizzle!(
347+
self,
348+
prev,
349+
[30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,]
350+
)
351+
}
352+
353+
#[inline]
354+
fn prev3(self, prev: Self) -> Self {
355+
simd_swizzle!(
356+
self,
357+
prev,
358+
[29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,]
359+
)
360+
}
361+
362+
#[inline]
363+
fn is_ascii(self) -> bool {
364+
if HAS_FAST_REDUCE_MAX {
365+
self.reduce_max() < 0b1000_0000
366+
} else {
367+
(self & Self::splat(0b1000_0000)) == Self::splat(0)
368+
}
369+
}
370+
}
371+
372+
impl<const N: usize> Lookup16 for Simd<u8, N>
373+
where
374+
Self: SimdU8Value<N>,
375+
LaneCount<N>: SupportedLaneCount,
376+
{
293377
#[inline]
294378
fn lookup_16(
295379
self,
@@ -317,13 +401,84 @@ impl SimdU8Value<16> for u8x16 {
317401
);
318402
src.swizzle_dyn(self)
319403
}
404+
}
405+
406+
impl SimdU8Value<32> for u8x32 {
407+
#[inline]
408+
fn from_32_cut_off_leading(
409+
v0: u8,
410+
v1: u8,
411+
v2: u8,
412+
v3: u8,
413+
v4: u8,
414+
v5: u8,
415+
v6: u8,
416+
v7: u8,
417+
v8: u8,
418+
v9: u8,
419+
v10: u8,
420+
v11: u8,
421+
v12: u8,
422+
v13: u8,
423+
v14: u8,
424+
v15: u8,
425+
v16: u8,
426+
v17: u8,
427+
v18: u8,
428+
v19: u8,
429+
v20: u8,
430+
v21: u8,
431+
v22: u8,
432+
v23: u8,
433+
v24: u8,
434+
v25: u8,
435+
v26: u8,
436+
v27: u8,
437+
v28: u8,
438+
v29: u8,
439+
v30: u8,
440+
v31: u8,
441+
) -> Self {
442+
Self::from_array([
443+
v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18,
444+
v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31,
445+
])
446+
}
447+
448+
#[inline]
449+
fn repeat_16(
450+
v0: u8,
451+
v1: u8,
452+
v2: u8,
453+
v3: u8,
454+
v4: u8,
455+
v5: u8,
456+
v6: u8,
457+
v7: u8,
458+
v8: u8,
459+
v9: u8,
460+
v10: u8,
461+
v11: u8,
462+
v12: u8,
463+
v13: u8,
464+
v14: u8,
465+
v15: u8,
466+
) -> Self {
467+
Self::from_array([
468+
v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v0, v1, v2, v3,
469+
v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15,
470+
])
471+
}
320472

321473
#[inline]
322474
fn prev1(self, prev: Self) -> Self {
323475
simd_swizzle!(
324476
self,
325477
prev,
326-
[31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,]
478+
[
479+
63, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
480+
22, 23, 24, 25, 26, 27, 28, 29, 30
481+
]
327482
)
328483
}
329484

@@ -332,7 +487,10 @@ impl SimdU8Value<16> for u8x16 {
332487
simd_swizzle!(
333488
self,
334489
prev,
335-
[30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,]
490+
[
491+
62, 63, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
492+
21, 22, 23, 24, 25, 26, 27, 28, 29
493+
]
336494
)
337495
}
338496

@@ -341,7 +499,10 @@ impl SimdU8Value<16> for u8x16 {
341499
simd_swizzle!(
342500
self,
343501
prev,
344-
[29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,]
502+
[
503+
61, 62, 63, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
504+
20, 21, 22, 23, 24, 25, 26, 27, 28,
505+
]
345506
)
346507
}
347508

@@ -654,25 +815,6 @@ where
654815
}
655816
}
656817

657-
/// Validation implementation for CPUs supporting the SIMD extension (see module).
658-
///
659-
/// # Errors
660-
/// Returns the zero-sized [`basic::Utf8Error`] on failure.
661-
#[inline]
662-
pub fn validate_utf8_basic(input: &[u8]) -> core::result::Result<(), basic::Utf8Error> {
663-
Utf8CheckAlgorithm::<16, 4>::validate_utf8_basic(input)
664-
}
665-
666-
/// Validation implementation for CPUs supporting the SIMD extension (see module).
667-
///
668-
/// # Errors
669-
/// Returns [`compat::Utf8Error`] with detailed error information on failure.
670-
#[inline]
671-
pub fn validate_utf8_compat(input: &[u8]) -> core::result::Result<(), compat::Utf8Error> {
672-
Utf8CheckAlgorithm::<16, 4>::validate_utf8_compat_simd0(input)
673-
.map_err(|idx| super::get_compat_error(input, idx))
674-
}
675-
676818
/// Low-level implementation of the [`basic::imp::Utf8Validator`] trait.
677819
///
678820
/// This implementation requires CPU SIMD features specified by the module it resides in.
@@ -819,20 +961,58 @@ impl basic::imp::ChunkedUtf8Validator for ChunkedUtf8ValidatorImp {
819961
}
820962
}
821963

822-
pub(crate) use v128 as auto; // FIXME: select based on target feature
964+
pub(crate) use v256 as auto; // FIXME: select based on target feature
823965

824966
pub(crate) mod v128 {
825-
pub use super::validate_utf8_basic;
826-
pub use super::validate_utf8_compat;
967+
/// Validation implementation for CPUs supporting the SIMD extension (see module).
968+
///
969+
/// # Errors
970+
/// Returns the zero-sized [`basic::Utf8Error`] on failure.
971+
#[inline]
972+
pub fn validate_utf8_basic(input: &[u8]) -> core::result::Result<(), crate::basic::Utf8Error> {
973+
super::Utf8CheckAlgorithm::<16, 4>::validate_utf8_basic(input)
974+
}
975+
976+
/// Validation implementation for CPUs supporting the SIMD extension (see module).
977+
///
978+
/// # Errors
979+
/// Returns [`compat::Utf8Error`] with detailed error information on failure.
980+
#[inline]
981+
pub fn validate_utf8_compat(
982+
input: &[u8],
983+
) -> core::result::Result<(), crate::compat::Utf8Error> {
984+
super::Utf8CheckAlgorithm::<16, 4>::validate_utf8_compat_simd0(input)
985+
.map_err(|idx| crate::implementation::get_compat_error(input, idx))
986+
}
987+
827988
#[cfg(feature = "public_imp")]
828989
pub use super::ChunkedUtf8ValidatorImp;
829990
#[cfg(feature = "public_imp")]
830991
pub use super::Utf8ValidatorImp;
831992
}
832993

833994
pub(crate) mod v256 {
834-
pub use super::validate_utf8_basic;
835-
pub use super::validate_utf8_compat;
995+
/// Validation implementation for CPUs supporting the SIMD extension (see module).
996+
///
997+
/// # Errors
998+
/// Returns the zero-sized [`basic::Utf8Error`] on failure.
999+
#[inline]
1000+
pub fn validate_utf8_basic(input: &[u8]) -> core::result::Result<(), crate::basic::Utf8Error> {
1001+
super::Utf8CheckAlgorithm::<32, 2>::validate_utf8_basic(input)
1002+
}
1003+
1004+
/// Validation implementation for CPUs supporting the SIMD extension (see module).
1005+
///
1006+
/// # Errors
1007+
/// Returns [`compat::Utf8Error`] with detailed error information on failure.
1008+
#[inline]
1009+
pub fn validate_utf8_compat(
1010+
input: &[u8],
1011+
) -> core::result::Result<(), crate::compat::Utf8Error> {
1012+
super::Utf8CheckAlgorithm::<32, 2>::validate_utf8_compat_simd0(input)
1013+
.map_err(|idx| crate::implementation::get_compat_error(input, idx))
1014+
}
1015+
8361016
#[cfg(feature = "public_imp")]
8371017
pub use super::ChunkedUtf8ValidatorImp;
8381018
#[cfg(feature = "public_imp")]

0 commit comments

Comments
 (0)