diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8d63685..06ae85e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -78,5 +78,6 @@ jobs: uses: ./.github/actions/test with: toolchain: nightly + target: "x86_64-unknown-linux-gnu" rustflags: "-Z sanitizer=address" rustdocflags: "-Z sanitizer=address" diff --git a/CHANGELOG.md b/CHANGELOG.md index 5ad8704..0d54ebe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,10 @@ to [Semantic Versioning][]. ## [Unreleased] +- Trait bounds removed from generic types, bounds are only required for impls +- Added SIMD implementation of `Bytes`/`AsciiChars` for aarch64 neon +- Added the the ability to efficiently iterate over the indexes of matching elements (`Bytes::iter`/`AsciiChars::iter`) + ## [0.5.3] - 2022-07-06 - Fix buffer overflows in find. (#55) diff --git a/Cargo.toml b/Cargo.toml index 4970861..c28fc0c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,12 +12,22 @@ documentation = "https://docs.rs/jetscii/" license = "MIT OR Apache-2.0" +edition = "2018" + [features] +# This feature is now a no-op, but we keep it around for backwards compatibility benchmarks = [] pattern = [] [dev-dependencies] +aho-corasick = "1.1.0" proptest = "1.0.0" lazy_static = "1.0.0" region = "3.0.0" -memmap = "0.7.0" +memmap2 = "0.9.0" +criterion = "0.5.0" +memchr = "2.0.0" + +[[bench]] +name = "benchmarks" +harness = false \ No newline at end of file diff --git a/README.md b/README.md index ea3397d..2f1861c 100644 --- a/README.md +++ b/README.md @@ -10,8 +10,7 @@ characters or byte slices for sets of bytes. ### Searching for a set of ASCII characters ```rust -#[macro_use] -extern crate jetscii; +use jetscii::ascii_chars; fn main() { let part_number = "86-J52:rev1"; @@ -23,8 +22,7 @@ fn main() { ### Searching for a set of bytes ```rust -#[macro_use] -extern crate jetscii; +use jetscii::bytes; fn main() { let raw_data = [0x00, 0x01, 0x10, 0xFF, 0x42]; diff --git a/benches/benchmarks.rs b/benches/benchmarks.rs new file mode 100644 index 0000000..5ced6fc --- /dev/null +++ b/benches/benchmarks.rs @@ -0,0 +1,322 @@ +use criterion::{criterion_group, criterion_main, BatchSize, Criterion, Throughput}; +use jetscii::{ascii_chars, AsciiCharsConst, SubstringConst}; +use std::hint::black_box; +use std::sync::OnceLock; + +static SPACE: OnceLock = OnceLock::new(); + +fn space() -> &'static AsciiCharsConst { + SPACE.get_or_init(|| ascii_chars!(' ')) +} + +static XML_DELIM_3: OnceLock = OnceLock::new(); + +fn xml_delim_3() -> &'static AsciiCharsConst { + XML_DELIM_3.get_or_init(|| ascii_chars!('<', '>', '&')) +} + +static XML_DELIM_5: OnceLock = OnceLock::new(); + +fn xml_delim_5() -> &'static AsciiCharsConst { + XML_DELIM_5.get_or_init(|| ascii_chars!('<', '>', '&', '\'', '"')) +} + +static BIG_16: OnceLock = OnceLock::new(); + +fn big_16() -> &'static AsciiCharsConst { + BIG_16.get_or_init(|| { + ascii_chars!('A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P') + }) +} + +static SUBSTRING: OnceLock = OnceLock::new(); + +fn substring() -> &'static SubstringConst { + SUBSTRING.get_or_init(|| SubstringConst::new("xyzzy")) +} + +fn prefix_string() -> String { + "a".repeat(5 * 1024 * 1024) +} + +fn spaces(c: &mut Criterion) { + let mut haystack = prefix_string(); + haystack.push(' '); + let haystack = black_box(haystack); + + let mut group = c.benchmark_group("find_last_space"); + group.throughput(Throughput::Bytes(haystack.len() as u64)); + + group.bench_function("ascii_chars", |b| { + let space = space(); + b.iter(|| space.find(&haystack)); + }); + group.bench_function("teddy", |b| { + let searcher = aho_corasick::packed::Searcher::new([" "]).unwrap(); + b.iter(|| searcher.find(&haystack).map(|m| m.start())); + }); + group.bench_function("memchr", |b| { + b.iter(|| memchr::memchr(b' ', haystack.as_bytes())); + }); +} + +fn xml3(c: &mut Criterion) { + let mut haystack = prefix_string(); + haystack.push('&'); + let haystack = black_box(haystack); + + let mut group = c.benchmark_group("find_xml_3"); + group.throughput(Throughput::Bytes(haystack.len() as u64)); + + group.bench_function("ascii_chars", |b| { + let xml_delim_3 = xml_delim_3(); + b.iter(|| xml_delim_3.find(&haystack)); + }); + group.bench_function("stdlib_iter_position", |b| { + b.iter(|| { + haystack + .bytes() + .position(|c| c == b'<' || c == b'>' || c == b'&') + }); + }); + group.bench_function("teddy", |b| { + let searcher = aho_corasick::packed::Searcher::new(["<", ">", "&"]).unwrap(); + b.iter(|| searcher.find(&haystack).map(|m| m.start())); + }); + group.bench_function("memchr", |b| { + b.iter(|| memchr::memchr3(b'<', b'>', b'&', haystack.as_bytes())); + }); +} + +fn xml5(c: &mut Criterion) { + let mut haystack = prefix_string(); + haystack.push('"'); + let haystack = black_box(haystack); + + let mut group = c.benchmark_group("find_xml_5"); + group.throughput(Throughput::Bytes(haystack.len() as u64)); + + group.bench_function("ascii_chars", |b| { + let xml_delim_5 = xml_delim_5(); + b.iter(|| xml_delim_5.find(&haystack)); + }); + group.bench_function("stdlib_iter_position", |b| { + b.iter(|| { + haystack + .bytes() + .position(|c| c == b'<' || c == b'>' || c == b'&' || c == b'\'' || c == b'"') + }); + }); + group.bench_function("teddy", |b| { + let searcher = aho_corasick::packed::Searcher::new(["<", ">", "&", "'", "\""]).unwrap(); + b.iter(|| searcher.find(&haystack).map(|m| m.start())); + }); +} + +fn big_16_benches(c: &mut Criterion) { + let mut haystack = prefix_string(); + haystack.push('P'); + let haystack = black_box(haystack); + + let mut group = c.benchmark_group("find_big_16"); + group.throughput(Throughput::Bytes(haystack.len() as u64)); + + group.bench_function("ascii_chars", |b| { + let big_16 = big_16(); + b.iter(|| big_16.find(&haystack)); + }); + group.bench_function("stdlib_iter_position", |b| { + b.iter(|| { + haystack.bytes().position(|c| { + c == b'A' + || c == b'B' + || c == b'C' + || c == b'D' + || c == b'E' + || c == b'F' + || c == b'G' + || c == b'H' + || c == b'I' + || c == b'J' + || c == b'K' + || c == b'L' + || c == b'M' + || c == b'N' + || c == b'O' + || c == b'P' + }) + }); + }); + group.bench_function("teddy", |b| { + let searcher = aho_corasick::packed::Searcher::new( + b"ABCDEFGHIJKLMNOP".iter().map(|b| std::array::from_ref(b)), + ) + .unwrap(); + b.iter(|| searcher.find(&haystack).map(|m| m.start())); + }); + + group.finish(); + + let mut haystack = prefix_string(); + haystack.insert(0, 'P'); + let haystack = black_box(haystack); + let mut group = c.benchmark_group("find_big_16_early_return"); + group.throughput(Throughput::Bytes(1)); + + group.bench_function("ascii_chars", |b| { + let big_16 = big_16(); + b.iter(|| big_16.find(&haystack)); + }); + group.bench_function("stdlib_iter_position", |b| { + b.iter(|| { + haystack.bytes().position(|c| { + c == b'A' + || c == b'B' + || c == b'C' + || c == b'D' + || c == b'E' + || c == b'F' + || c == b'G' + || c == b'H' + || c == b'I' + || c == b'J' + || c == b'K' + || c == b'L' + || c == b'M' + || c == b'N' + || c == b'O' + || c == b'P' + }) + }); + }); + group.bench_function("teddy", |b| { + let searcher = aho_corasick::packed::Searcher::new( + b"ABCDEFGHIJKLMNOP".iter().map(|b| std::array::from_ref(b)), + ) + .unwrap(); + b.iter(|| searcher.find(&haystack).map(|m| m.start())); + }); +} + +fn substr(c: &mut Criterion) { + let mut haystack = prefix_string(); + haystack.push_str("xyzzy"); + let haystack = black_box(haystack); + + let mut group = c.benchmark_group("find_substring"); + group.throughput(Throughput::Bytes(haystack.len() as u64)); + + group.bench_function("substring", |b| { + let substring = substring(); + b.iter(|| substring.find(&haystack)); + }); + group.bench_function("stdlib_find_string", |b| { + b.iter(|| haystack.find("xyzzy")); + }); + group.bench_function("memchr", |b| { + let finder = memchr::memmem::Finder::new(b"xyzzy"); + b.iter(|| finder.find(haystack.as_bytes())); + }); +} + +fn iterate_xml_many_match(c: &mut Criterion) { + let haystack = black_box(include_str!("plant_catalog.xml")); + let mut group = c.benchmark_group("iterate_xml_3"); + + group.throughput(Throughput::Bytes(haystack.len() as u64)); + group.bench_function("ascii_chars", |b| { + let xml_delim_3 = xml_delim_3(); + b.iter_batched( + || xml_delim_3.as_bytes().iter(haystack.as_bytes()), + |iter| { + for offset in iter { + black_box(offset); + } + }, + BatchSize::SmallInput, + ); + }); + group.bench_function("stdlib_iter_position", |b| { + b.iter(|| { + let mut haystack = &haystack[..]; + let mut offset = 0; + while let Some(pos) = haystack + .bytes() + .position(|c| c == b'<' || c == b'>' || c == b'&') + { + haystack = &haystack[pos + 1..]; + offset += pos; + black_box(offset); + } + }); + }); + group.bench_function("memchr", |b| { + b.iter_batched( + || memchr::memchr3_iter(b'<', b'>', b'&', haystack.as_bytes()), + |iter| { + for offset in iter { + black_box(offset); + } + }, + BatchSize::SmallInput, + ); + }); + group.finish(); +} + +fn iterate_few_match(c: &mut Criterion) { + let haystack = black_box(include_str!("plant_catalog.xml")); + let mut group = c.benchmark_group("iterate_few_matches"); + let chars: AsciiCharsConst = ascii_chars!(b'?', b'-', b'\0'); + + group.throughput(Throughput::Bytes(haystack.len() as u64)); + group.bench_function("ascii_chars", |b| { + b.iter(|| { + let mut haystack = &haystack[..]; + let mut offset = 0; + while let Some(pos) = chars.find(haystack) { + haystack = &haystack[pos + 1..]; + offset += pos; + black_box(offset); + } + }); + }); + group.bench_function("stdlib_iter_position", |b| { + b.iter(|| { + let mut haystack = &haystack[..]; + let mut offset = 0; + while let Some(pos) = haystack + .bytes() + .position(|c| c == b'?' || c == b'-' || c == b'\0') + { + haystack = &haystack[pos + 1..]; + offset += pos; + black_box(offset); + } + }); + }); + group.bench_function("memchr", |b| { + b.iter_batched( + || memchr::memchr3_iter(b'?', b'-', b'\0', haystack.as_bytes()), + |iter| { + for offset in iter { + black_box(offset); + } + }, + BatchSize::SmallInput, + ); + }); + group.finish(); +} + +criterion_group!( + benches, + spaces, + xml3, + xml5, + big_16_benches, + substr, + iterate_xml_many_match, + iterate_few_match, +); +criterion_main!(benches); diff --git a/benches/plant_catalog.xml b/benches/plant_catalog.xml new file mode 100644 index 0000000..4275db9 --- /dev/null +++ b/benches/plant_catalog.xml @@ -0,0 +1,291 @@ + + + + Bloodroot + Sanguinaria canadensis + 4 + Mostly Shady + $2.44 + 031599 + + + Columbine + Aquilegia canadensis + 3 + Mostly Shady + $9.37 + 030699 + + + Marsh Marigold + Caltha palustris + 4 + Mostly Sunny + $6.81 + 051799 + + + Cowslip + Caltha palustris + 4 + Mostly Shady + $9.90 + 030699 + + + Dutchman's-Breeches + Dicentra cucullaria + 3 + Mostly Shady + $6.44 + 012099 + + + Ginger, Wild + Asarum canadense + 3 + Mostly Shady + $9.03 + 041899 + + + Hepatica + Hepatica americana + 4 + Mostly Shady + $4.45 + 012699 + + + Liverleaf + Hepatica americana + 4 + Mostly Shady + $3.99 + 010299 + + + Jack-In-The-Pulpit + Arisaema triphyllum + 4 + Mostly Shady + $3.23 + 020199 + + + Mayapple + Podophyllum peltatum + 3 + Mostly Shady + $2.98 + 060599 + + + Phlox, Woodland + Phlox divaricata + 3 + Sun or Shade + $2.80 + 012299 + + + Phlox, Blue + Phlox divaricata + 3 + Sun or Shade + $5.59 + 021699 + + + Spring-Beauty + Claytonia Virginica + 7 + Mostly Shady + $6.59 + 020199 + + + Trillium + Trillium grandiflorum + 5 + Sun or Shade + $3.90 + 042999 + + + Wake Robin + Trillium grandiflorum + 5 + Sun or Shade + $3.20 + 022199 + + + Violet, Dog-Tooth + Erythronium americanum + 4 + Shade + $9.04 + 020199 + + + Trout Lily + Erythronium americanum + 4 + Shade + $6.94 + 032499 + + + Adder's-Tongue + Erythronium americanum + 4 + Shade + $9.58 + 041399 + + + Anemone + Anemone blanda + 6 + Mostly Shady + $8.86 + 122698 + + + Grecian Windflower + Anemone blanda + 6 + Mostly Shady + $9.16 + 071099 + + + Bee Balm + Monarda didyma + 4 + Shade + $4.59 + 050399 + + + Bergamot + Monarda didyma + 4 + Shade + $7.16 + 042799 + + + Black-Eyed Susan + Rudbeckia hirta + Annual + Sunny + $9.80 + 061899 + + + Buttercup + Ranunculus + 4 + Shade + $2.57 + 061099 + + + Crowfoot + Ranunculus + 4 + Shade + $9.34 + 040399 + + + Butterfly Weed + Asclepias tuberosa + Annual + Sunny + $2.78 + 063099 + + + Cinquefoil + Potentilla + Annual + Shade + $7.06 + 052599 + + + Primrose + Oenothera + 3 - 5 + Sunny + $6.56 + 013099 + + + Gentian + Gentiana + 4 + Sun or Shade + $7.81 + 051899 + + + Blue Gentian + Gentiana + 4 + Sun or Shade + $8.56 + 050299 + + + Jacob's Ladder + Polemonium caeruleum + Annual + Shade + $9.26 + 022199 + + + Greek Valerian + Polemonium caeruleum + Annual + Shade + $4.36 + 071499 + + + California Poppy + Eschscholzia californica + Annual + Sun + $7.89 + 032799 + + + Shooting Star + Dodecatheon + Annual + Mostly Shady + $8.60 + 051399 + + + Snakeroot + Cimicifuga + Annual + Shade + $5.63 + 071199 + + + Cardinal Flower + Lobelia cardinalis + 2 + Shade + $3.02 + 022299 + + diff --git a/build.rs b/build.rs index 20a8d92..197952b 100644 --- a/build.rs +++ b/build.rs @@ -4,28 +4,11 @@ use std::io::prelude::*; use std::path::{Path, PathBuf}; fn main() { - cfg(); macros(); simd_macros(); println!("cargo:rerun-if-changed=build.rs"); } -fn cfg() { - let target_arch = env::var("CARGO_CFG_TARGET_ARCH").unwrap_or_default(); - let target_feature = env::var("CARGO_CFG_TARGET_FEATURE").unwrap_or_default(); - - let ok_arch = matches!(&*target_arch, "x86" | "x86_64"); - let sse4_2_guaranteed = target_feature.split(',').any(|f| f == "sse4.2"); - - if sse4_2_guaranteed { - println!(r#"cargo:rustc-cfg=jetscii_sse4_2="yes""#); - } else if ok_arch { - println!(r#"cargo:rustc-cfg=jetscii_sse4_2="maybe""#); - } else { - println!(r#"cargo:rustc-cfg=jetscii_sse4_2="no""#); - } -} - fn macros() { let mut base: PathBuf = env::var_os("OUT_DIR").unwrap().into(); base.push("src"); @@ -134,7 +117,7 @@ fn simd_macros() { let array = array.join(", "); format!( - "({}) => ($crate::simd::Bytes::new([{}], {}));\n", + "({}) => ($crate::simd::x86::Bytes::new([{}], {}));\n", args, array, max ) }) diff --git a/src/fallback.rs b/src/fallback.rs index 0f7b21b..9201bbf 100644 --- a/src/fallback.rs +++ b/src/fallback.rs @@ -1,10 +1,7 @@ // TODO: Try boxing the closure to see if we can hide the type // TODO: Or maybe use a function pointer? -pub struct Bytes -where - F: Fn(u8) -> bool, -{ +pub struct Bytes { fallback: F, } @@ -12,12 +9,20 @@ impl Bytes where F: Fn(u8) -> bool, { - pub /* const */ fn new(fallback: F) -> Self { + pub fn new(fallback: F) -> Self { Bytes { fallback } } pub fn find(&self, haystack: &[u8]) -> Option { - haystack.iter().cloned().position(&self.fallback) + haystack.iter().copied().position(&self.fallback) + } + + pub fn iter<'a>(&'a self, haystack: &'a [u8]) -> BytesIter<'a, F> { + BytesIter { + bytes: self, + haystack, + offset: 0, + } } } @@ -26,7 +31,7 @@ pub struct ByteSubstring<'a> { } impl<'a> ByteSubstring<'a> { - pub /* const */ fn new(needle: &'a[u8]) -> Self { + pub fn new(needle: &'a[u8]) -> Self { ByteSubstring { needle } } @@ -41,3 +46,32 @@ impl<'a> ByteSubstring<'a> { .position(|window| window == self.needle) } } + +pub struct BytesIter<'a, F> { + bytes: &'a Bytes, + haystack: &'a [u8], + offset: usize, +} +impl<'a, F> Iterator for BytesIter<'a, F> +where + F: Fn(u8) -> bool, +{ + type Item = usize; + + fn next(&mut self) -> Option { + let idx = self.bytes.find(self.haystack); + if let Some(idx) = idx { + self.haystack = &self.haystack[idx + 1..]; + let result = self.offset + idx; + self.offset = result + 1; + Some(result) + } else { + self.haystack = &[]; + None + } + } + + fn size_hint(&self) -> (usize, Option) { + (0, Some(self.haystack.len())) + } +} diff --git a/src/lib.rs b/src/lib.rs index 7c30098..535dab4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,4 @@ #![cfg_attr(feature = "pattern", feature(pattern))] -#![cfg_attr(feature = "benchmarks", feature(test))] //! A tiny library to efficiently search strings for sets of ASCII //! characters or byte slices for sets of bytes. @@ -9,8 +8,7 @@ //! ### Searching for a set of ASCII characters //! //! ```rust -//! #[macro_use] -//! extern crate jetscii; +//! use jetscii::ascii_chars; //! //! fn main() { //! let part_number = "86-J52:rev1"; @@ -22,8 +20,7 @@ //! ### Searching for a set of bytes //! //! ```rust -//! #[macro_use] -//! extern crate jetscii; +//! use jetscii::bytes; //! //! fn main() { //! let raw_data = [0x00, 0x01, 0x10, 0xFF, 0x42]; @@ -61,8 +58,7 @@ //! //! ``` //! # #[cfg(feature = "pattern")] -//! #[macro_use] -//! extern crate jetscii; +//! use jetscii::ascii_chars; //! //! fn main() { //! # #[cfg(feature = "pattern")] { @@ -141,67 +137,86 @@ //! | **Substring::new("xyzzy").find(s)** | **11475 MB/s** | //! | s.find("xyzzy") | 5391 MB/s | -#[cfg(test)] -#[macro_use] -extern crate lazy_static; -#[cfg(test)] -extern crate memmap; -#[cfg(test)] -extern crate proptest; -#[cfg(test)] -extern crate region; - use std::marker::PhantomData; include!(concat!(env!("OUT_DIR"), "/src/macros.rs")); -#[cfg(any(jetscii_sse4_2 = "yes", jetscii_sse4_2 = "maybe"))] mod simd; -#[cfg(any(jetscii_sse4_2 = "maybe", jetscii_sse4_2 = "no"))] +// This module may not be used if e.g. we statically know we have the sse4.2 target feature +#[allow(unused)] mod fallback; #[cfg(feature = "pattern")] mod pattern; macro_rules! dispatch { - (simd: $simd:expr,fallback: $fallback:expr,) => { - // If we can tell at compile time that we have support, - // call the optimized code directly. - #[cfg(jetscii_sse4_2 = "yes")] + (x86: $x86:expr, aarch64: $aarch64:expr, fallback: $fallback:expr,) => ({ + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + { + dispatch!( + target_feature: "sse4.2", + is_feature_enabled: is_x86_feature_detected, + simd: $x86, + fallback: $fallback, + ) + } + + #[cfg(any(target_arch = "aarch64", target_arch = "arm64ec"))] { - $simd + dispatch!( + target_feature: "neon", + is_feature_enabled: is_aarch64_feature_detected, + simd: $aarch64, + fallback: $fallback, + ) } // If we can tell at compile time that we will *never* have // support, call the fallback directly. - #[cfg(jetscii_sse4_2 = "no")] + #[cfg(not(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64", target_arch = "arm64ec")))] { $fallback } - - // Otherwise, we will be run on a machine with or without - // support, so we perform runtime detection. - #[cfg(jetscii_sse4_2 = "maybe")] - { - if is_x86_feature_detected!("sse4.2") { + }); + (target_feature: $target_feature:tt, is_feature_enabled: $is_feature_enabled:ident, simd: $simd:expr, fallback: $fallback:expr,) => ({ + // If we can tell at compile time that we have support, + // call the optimized code directly. + #[cfg(target_feature = $target_feature)] + { $simd - } else { - $fallback } - } - }; + // Otherwise, we will be run on a machine with or without + // support, so we perform runtime detection. + #[cfg(not(target_feature = $target_feature))] + { + if std::arch::$is_feature_enabled!($target_feature) { + $simd + } else { + $fallback + } + } + }); } /// Searches a slice for a set of bytes. Up to 16 bytes may be used. -pub struct Bytes -where - F: Fn(u8) -> bool, -{ - #[cfg(any(jetscii_sse4_2 = "yes", jetscii_sse4_2 = "maybe"))] - simd: simd::Bytes, - - #[cfg(any(jetscii_sse4_2 = "maybe", jetscii_sse4_2 = "no"))] +pub struct Bytes { + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + x86: simd::x86::Bytes, + + #[cfg(any(target_arch = "aarch64", target_arch = "arm64ec"))] + aarch64: simd::aarch64::Bytes, + + #[cfg(not(any( + all( + any(target_arch = "x86", target_arch = "x86_64"), + target_feature = "sse4.2" + ), + all( + any(target_arch = "aarch64", target_arch = "arm64ec"), + target_feature = "neon" + ), + )))] fallback: fallback::Bytes, // Since we might not use the fallback implementation, we add this @@ -220,12 +235,25 @@ where /// intrinsics are not available. The closure **must** search for /// the same bytes as in the array. #[allow(unused_variables)] - pub /* const */ fn new(bytes: [u8; 16], len: i32, fallback: F) -> Self { + #[must_use] + pub fn new(bytes: [u8; 16], len: i32, fallback: F) -> Self { Bytes { - #[cfg(any(jetscii_sse4_2 = "yes", jetscii_sse4_2 = "maybe"))] - simd: simd::Bytes::new(bytes, len), - - #[cfg(any(jetscii_sse4_2 = "maybe", jetscii_sse4_2 = "no"))] + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + x86: simd::x86::Bytes::new(bytes, len), + + #[cfg(any(target_arch = "aarch64", target_arch = "arm64ec"))] + aarch64: simd::aarch64::Bytes::new(bytes, len), + + #[cfg(not(any( + all( + any(target_arch = "x86", target_arch = "x86_64"), + target_feature = "sse4.2" + ), + all( + any(target_arch = "aarch64", target_arch = "arm64ec"), + target_feature = "neon" + ), + )))] fallback: fallback::Bytes::new(fallback), _fallback: PhantomData, @@ -234,22 +262,142 @@ where /// Searches the slice for the first matching byte in the set. #[inline] + #[must_use] pub fn find(&self, haystack: &[u8]) -> Option { dispatch! { - simd: unsafe { self.simd.find(haystack) }, + x86: unsafe { self.x86.find(haystack) }, + aarch64: unsafe { self.aarch64.find(haystack) }, fallback: self.fallback.find(haystack), } } + + pub fn iter<'a>(&'a self, haystack: &'a [u8]) -> BytesIter<'a, F> { + dispatch! { + x86: BytesIter::X86(self.x86.iter(haystack)), + aarch64: BytesIter::Aarch64(self.aarch64.iter(haystack)), + fallback: BytesIter::Fallback(self.fallback.iter(haystack)), + } + } } /// A convenience type that can be used in a constant or static. pub type BytesConst = Bytes bool>; +pub enum BytesIter<'a, F> { + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + X86(simd::x86::BytesIter<'a>), + + #[cfg(any(target_arch = "aarch64", target_arch = "arm64ec"))] + Aarch64(simd::aarch64::BytesIter<'a>), + + #[cfg(not(any( + all( + any(target_arch = "x86", target_arch = "x86_64"), + target_feature = "sse4.2" + ), + all( + any(target_arch = "aarch64", target_arch = "arm64ec"), + target_feature = "neon" + ), + )))] + Fallback(fallback::BytesIter<'a, F>), + // Since we might not use the fallback implementation, we add this + // to avoid unused type parameters. + // This type is impossible to construct, but we still use the F parameter + _Impossible(std::convert::Infallible, PhantomData), +} + +impl<'a, F> Iterator for BytesIter<'a, F> +where + F: Fn(u8) -> bool, +{ + type Item = usize; + + fn next(&mut self) -> Option { + match self { + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + Self::X86(iter) => iter.next(), + + #[cfg(any(target_arch = "aarch64", target_arch = "arm64ec"))] + Self::Aarch64(iter) => iter.next(), + + #[cfg(not(any( + all( + any(target_arch = "x86", target_arch = "x86_64"), + target_feature = "sse4.2" + ), + all( + any(target_arch = "aarch64", target_arch = "arm64ec"), + target_feature = "neon" + ), + )))] + Self::Fallback(iter) => iter.next(), + // Since we might not use the fallback implementation, we add this + // to avoid unused type parameters. + // This type is impossible to construct, but we still use the F parameter + &mut Self::_Impossible(infailable, _) => match infailable {}, + } + } + + fn size_hint(&self) -> (usize, Option) { + match self { + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + Self::X86(iter) => iter.size_hint(), + + #[cfg(any(target_arch = "aarch64", target_arch = "arm64ec"))] + Self::Aarch64(iter) => iter.size_hint(), + + #[cfg(not(any( + all( + any(target_arch = "x86", target_arch = "x86_64"), + target_feature = "sse4.2" + ), + all( + any(target_arch = "aarch64", target_arch = "arm64ec"), + target_feature = "neon" + ), + )))] + Self::Fallback(iter) => iter.size_hint(), + // Since we might not use the fallback implementation, we add this + // to avoid unused type parameters. + // This type is impossible to construct, but we still use the F parameter + &Self::_Impossible(infailable, _) => match infailable {}, + } + } + + fn count(self) -> usize + where + Self: Sized, + { + match self { + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + Self::X86(iter) => iter.count(), + + #[cfg(any(target_arch = "aarch64", target_arch = "arm64ec"))] + Self::Aarch64(iter) => iter.count(), + + #[cfg(not(any( + all( + any(target_arch = "x86", target_arch = "x86_64"), + target_feature = "sse4.2" + ), + all( + any(target_arch = "aarch64", target_arch = "arm64ec"), + target_feature = "neon" + ), + )))] + Self::Fallback(iter) => iter.count(), + // Since we might not use the fallback implementation, we add this + // to avoid unused type parameters. + // This type is impossible to construct, but we still use the F parameter + Self::_Impossible(infailable, _) => match infailable {}, + } + } +} + /// Searches a string for a set of ASCII characters. Up to 16 /// characters may be used. -pub struct AsciiChars(Bytes) -where - F: Fn(u8) -> bool; +pub struct AsciiChars(Bytes); impl AsciiChars where @@ -265,56 +413,106 @@ where /// ### Panics /// /// - If you provide a non-ASCII byte. - pub /* const */ fn new(chars: [u8; 16], len: i32, fallback: F) -> Self { + #[must_use] + pub fn new(chars: [u8; 16], len: i32, fallback: F) -> Self { for &b in &chars { assert!(b < 128, "Cannot have non-ASCII bytes"); } - AsciiChars(Bytes::new(chars, len, fallback)) + Self(Bytes::new(chars, len, fallback)) } /// Searches the string for the first matching ASCII byte in the set. #[inline] + #[must_use] pub fn find(&self, haystack: &str) -> Option { self.0.find(haystack.as_bytes()) } + + /// Return an iterator over the indices of the specified characters in the haystack. + pub fn iter<'a>(&'a self, haystack: &'a str) -> AsciiCharsIter<'a, F> { + AsciiCharsIter(self.0.iter(haystack.as_bytes())) + } + + /// Get a [`Bytes`] reference with the same ascii characters + pub fn as_bytes(&self) -> &Bytes { + &self.0 + } } /// A convenience type that can be used in a constant or static. pub type AsciiCharsConst = AsciiChars bool>; +pub struct AsciiCharsIter<'a, F>(BytesIter<'a, F>); + +impl Iterator for AsciiCharsIter<'_, F> +where + F: Fn(u8) -> bool, +{ + type Item = usize; + + fn next(&mut self) -> Option { + self.0.next() + } + + fn size_hint(&self) -> (usize, Option) { + self.0.size_hint() + } + + fn count(self) -> usize { + self.0.count() + } +} + /// Searches a slice for the first occurence of the subslice. pub struct ByteSubstring<'a> { - #[cfg(any(jetscii_sse4_2 = "yes", jetscii_sse4_2 = "maybe"))] - simd: simd::ByteSubstring<'a>, - - #[cfg(any(jetscii_sse4_2 = "maybe", jetscii_sse4_2 = "no"))] + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + x86: simd::x86::ByteSubstring<'a>, + + #[cfg(not( + any( + all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "sse4.2"), + // No aarch64 simd implementation of substring search + // all(any(target_arch = "aarch64", target_arch = "arm64ec"), target_feature = "neon"), + ) + ))] fallback: fallback::ByteSubstring<'a>, } impl<'a> ByteSubstring<'a> { - pub /* const */ fn new(needle: &'a [u8]) -> Self { - ByteSubstring { - #[cfg(any(jetscii_sse4_2 = "yes", jetscii_sse4_2 = "maybe"))] - simd: simd::ByteSubstring::new(needle), - - #[cfg(any(jetscii_sse4_2 = "maybe", jetscii_sse4_2 = "no"))] + #[must_use] + pub fn new(needle: &'a [u8]) -> Self { + Self { + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + x86: simd::x86::ByteSubstring::new(needle), + + #[cfg(not( + any( + all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "sse4.2"), + // No aarch64 simd implementation of substring search + // all(any(target_arch = "aarch64", target_arch = "arm64ec"), target_feature = "neon"), + ) + ))] fallback: fallback::ByteSubstring::new(needle), } } #[cfg(feature = "pattern")] + #[must_use] fn needle_len(&self) -> usize { dispatch! { - simd: self.simd.needle_len(), + x86: self.x86.needle_len(), + aarch64: self.fallback.needle_len(), fallback: self.fallback.needle_len(), } } /// Searches the slice for the first occurence of the subslice. #[inline] + #[must_use] pub fn find(&self, haystack: &[u8]) -> Option { dispatch! { - simd: unsafe { self.simd.find(haystack) }, + x86: unsafe { self.x86.find(haystack) }, + aarch64: self.fallback.find(haystack), fallback: self.fallback.find(haystack), } } @@ -327,17 +525,20 @@ pub type ByteSubstringConst = ByteSubstring<'static>; pub struct Substring<'a>(ByteSubstring<'a>); impl<'a> Substring<'a> { - pub /* const */ fn new(needle: &'a str) -> Self { + #[must_use] + pub fn new(needle: &'a str) -> Self { Substring(ByteSubstring::new(needle.as_bytes())) } #[cfg(feature = "pattern")] + #[must_use] fn needle_len(&self) -> usize { self.0.needle_len() } /// Searches the string for the first occurence of the substring. #[inline] + #[must_use] pub fn find(&self, haystack: &str) -> Option { self.0.find(haystack.as_bytes()) } @@ -345,158 +546,3 @@ impl<'a> Substring<'a> { /// A convenience type that can be used in a constant or static. pub type SubstringConst = Substring<'static>; - -#[cfg(all(test, feature = "benchmarks"))] -mod bench { - extern crate test; - - use super::*; - - lazy_static! { - static ref SPACE: AsciiCharsConst = ascii_chars!(' '); - static ref XML_DELIM_3: AsciiCharsConst = ascii_chars!('<', '>', '&'); - static ref XML_DELIM_5: AsciiCharsConst = ascii_chars!('<', '>', '&', '\'', '"'); - } - - fn prefix_string() -> String { - "a".repeat(5 * 1024 * 1024) - } - - fn bench_space(b: &mut test::Bencher, f: F) - where - F: Fn(&str) -> Option, - { - let mut haystack = prefix_string(); - haystack.push(' '); - - b.iter(|| test::black_box(f(&haystack))); - b.bytes = haystack.len() as u64; - } - - #[bench] - fn space_ascii_chars(b: &mut test::Bencher) { - bench_space(b, |hs| SPACE.find(hs)) - } - - #[bench] - fn space_stdlib_find_string(b: &mut test::Bencher) { - bench_space(b, |hs| hs.find(" ")) - } - - #[bench] - fn space_stdlib_find_char(b: &mut test::Bencher) { - bench_space(b, |hs| hs.find(' ')) - } - - #[bench] - fn space_stdlib_find_char_set(b: &mut test::Bencher) { - bench_space(b, |hs| hs.find(&[' '][..])) - } - - #[bench] - fn space_stdlib_find_closure(b: &mut test::Bencher) { - bench_space(b, |hs| hs.find(|c| c == ' ')) - } - - #[bench] - fn space_stdlib_iterator_position(b: &mut test::Bencher) { - bench_space(b, |hs| hs.as_bytes().iter().position(|&v| v == b' ')) - } - - fn bench_xml_delim_3(b: &mut test::Bencher, f: F) - where - F: Fn(&str) -> Option, - { - let mut haystack = prefix_string(); - haystack.push('&'); - - b.iter(|| test::black_box(f(&haystack))); - b.bytes = haystack.len() as u64; - } - - #[bench] - fn xml_delim_3_ascii_chars(b: &mut test::Bencher) { - bench_xml_delim_3(b, |hs| XML_DELIM_3.find(hs)) - } - - #[bench] - fn xml_delim_3_stdlib_find_char_set(b: &mut test::Bencher) { - bench_xml_delim_3(b, |hs| hs.find(&['<', '>', '&'][..])) - } - - #[bench] - fn xml_delim_3_stdlib_find_char_closure(b: &mut test::Bencher) { - bench_xml_delim_3(b, |hs| hs.find(|c| c == '<' || c == '>' || c == '&')) - } - - #[bench] - fn xml_delim_3_stdlib_iterator_position(b: &mut test::Bencher) { - bench_xml_delim_3(b, |hs| { - hs.as_bytes() - .iter() - .position(|&c| c == b'<' || c == b'>' || c == b'&') - }) - } - - fn bench_xml_delim_5(b: &mut test::Bencher, f: F) - where - F: Fn(&str) -> Option, - { - let mut haystack = prefix_string(); - haystack.push('"'); - - b.iter(|| test::black_box(f(&haystack))); - b.bytes = haystack.len() as u64; - } - - #[bench] - fn xml_delim_5_ascii_chars(b: &mut test::Bencher) { - bench_xml_delim_5(b, |hs| XML_DELIM_5.find(hs)) - } - - #[bench] - fn xml_delim_5_stdlib_find_char_set(b: &mut test::Bencher) { - bench_xml_delim_5(b, |hs| hs.find(&['<', '>', '&', '\'', '"'][..])) - } - - #[bench] - fn xml_delim_5_stdlib_find_char_closure(b: &mut test::Bencher) { - bench_xml_delim_5(b, |hs| { - hs.find(|c| c == '<' || c == '>' || c == '&' || c == '\'' || c == '"') - }) - } - - #[bench] - fn xml_delim_5_stdlib_iterator_position(b: &mut test::Bencher) { - bench_xml_delim_3(b, |hs| { - hs.as_bytes() - .iter() - .position(|&c| c == b'<' || c == b'>' || c == b'&' || c == b'\'' || c == b'"') - }) - } - - lazy_static! { - static ref XYZZY: Substring<'static> = Substring::new("xyzzy"); - } - - fn bench_substring(b: &mut test::Bencher, f: F) - where - F: Fn(&str) -> Option, - { - let mut haystack = prefix_string(); - haystack.push_str("xyzzy"); - - b.iter(|| test::black_box(f(&haystack))); - b.bytes = haystack.len() as u64; - } - - #[bench] - fn substring_with_created_searcher(b: &mut test::Bencher) { - bench_substring(b, |hs| XYZZY.find(hs)) - } - - #[bench] - fn substring_stdlib_find(b: &mut test::Bencher) { - bench_substring(b, |hs| hs.find("xyzzy")) - } -} diff --git a/src/simd.rs b/src/simd.rs deleted file mode 100644 index b03a7c0..0000000 --- a/src/simd.rs +++ /dev/null @@ -1,750 +0,0 @@ -// # Warning -// -// Everything in this module assumes that the SSE 4.2 feature is available. - -use std::{cmp::min, slice}; - -#[cfg(target_arch = "x86")] -use std::arch::x86 as target_arch; -#[cfg(target_arch = "x86_64")] -use std::arch::x86_64 as target_arch; - -use self::target_arch::{ - __m128i, _mm_cmpestri, _mm_cmpestrm, _mm_extract_epi16, _mm_loadu_si128, - _SIDD_CMP_EQUAL_ORDERED, -}; - -include!(concat!(env!("OUT_DIR"), "/src/simd_macros.rs")); - -const BYTES_PER_OPERATION: usize = 16; - -union TransmuteToSimd { - simd: __m128i, - bytes: [u8; 16], -} - -trait PackedCompareControl { - fn needle(&self) -> __m128i; - fn needle_len(&self) -> i32; -} - -#[inline] -#[target_feature(enable = "sse4.2")] -unsafe fn find_small(packed: PackedCompare, haystack: &[u8]) -> Option -where - C: PackedCompareControl, -{ - let mut tail = [0u8; 16]; - core::ptr::copy_nonoverlapping(haystack.as_ptr(), tail.as_mut_ptr(), haystack.len()); - let haystack = &tail[..haystack.len()]; - debug_assert!(haystack.len() < ::std::i32::MAX as usize); - packed.cmpestri(haystack.as_ptr(), haystack.len() as i32) -} - -/// The PCMPxSTRx instructions always read 16 bytes worth of -/// data. Although the instructions handle unaligned memory access -/// just fine, they might attempt to read off the end of a page -/// and into a protected area. -/// -/// To handle this case, we read in 16-byte aligned chunks with -/// respect to the *end* of the byte slice. This makes the -/// complicated part in searching the leftover bytes at the -/// beginning of the byte slice. -#[inline] -#[target_feature(enable = "sse4.2")] -unsafe fn find(packed: PackedCompare, mut haystack: &[u8]) -> Option -where - C: PackedCompareControl, -{ - // FIXME: EXPLAIN SAFETY - - if haystack.is_empty() { - return None; - } - - if haystack.len() < 16 { - return find_small(packed, haystack); - } - - let mut offset = 0; - - if let Some(misaligned) = Misalignment::new(haystack) { - if let Some(location) = packed.cmpestrm(misaligned.leading, misaligned.leading_junk) { - // Since the masking operation covers an entire - // 16-byte chunk, we have to see if the match occurred - // somewhere *after* our data - if location < haystack.len() { - return Some(location); - } - } - - haystack = &haystack[misaligned.bytes_until_alignment..]; - offset += misaligned.bytes_until_alignment; - } - - // TODO: try removing the 16-byte loop and check the disasm - let n_complete_chunks = haystack.len() / BYTES_PER_OPERATION; - - // Getting the pointer once before the loop avoids the - // overhead of manipulating the length of the slice inside the - // loop. - let mut haystack_ptr = haystack.as_ptr(); - let mut chunk_offset = 0; - for _ in 0..n_complete_chunks { - if let Some(location) = packed.cmpestri(haystack_ptr, BYTES_PER_OPERATION as i32) { - return Some(offset + chunk_offset + location); - } - - haystack_ptr = haystack_ptr.offset(BYTES_PER_OPERATION as isize); - chunk_offset += BYTES_PER_OPERATION; - } - haystack = &haystack[chunk_offset..]; - offset += chunk_offset; - - // No data left to search - if haystack.is_empty() { - return None; - } - - find_small(packed, haystack).map(|loc| loc + offset) -} - -struct PackedCompare(T); -impl PackedCompare -where - T: PackedCompareControl, -{ - #[inline] - #[target_feature(enable = "sse4.2")] - unsafe fn cmpestrm(&self, haystack: &[u8], leading_junk: usize) -> Option { - // TODO: document why this is ok - let haystack = _mm_loadu_si128(haystack.as_ptr() as *const __m128i); - - let mask = _mm_cmpestrm( - self.0.needle(), - self.0.needle_len(), - haystack, - BYTES_PER_OPERATION as i32, - CONTROL_BYTE, - ); - let mask = _mm_extract_epi16(mask, 0) as u16; - - if mask.trailing_zeros() < 16 { - let mut mask = mask; - // Byte: 7 6 5 4 3 2 1 0 - // Str : &[0, 1, 2, 3, ...] - // - // Bit-0 corresponds to Str-0; shifting to the right - // removes the parts of the string that don't belong to - // us. - mask >>= leading_junk; - // The first 1, starting from Bit-0 and going to Bit-7, - // denotes the position of the first match. - if mask == 0 { - // All of our matches were before the slice started - None - } else { - let first_match = mask.trailing_zeros() as usize; - debug_assert!(first_match < 16); - Some(first_match) - } - } else { - None - } - } - - #[inline] - #[target_feature(enable = "sse4.2")] - unsafe fn cmpestri(&self, haystack: *const u8, haystack_len: i32) -> Option { - debug_assert!( - (1..=16).contains(&haystack_len), - "haystack_len was {}", - haystack_len, - ); - - // TODO: document why this is ok - let haystack = _mm_loadu_si128(haystack as *const __m128i); - - let location = _mm_cmpestri( - self.0.needle(), - self.0.needle_len(), - haystack, - haystack_len, - CONTROL_BYTE, - ); - - if location < 16 { - Some(location as usize) - } else { - None - } - } -} - -#[derive(Debug)] -struct Misalignment<'a> { - leading: &'a [u8], - leading_junk: usize, - bytes_until_alignment: usize, -} - -impl<'a> Misalignment<'a> { - /// # Cases - /// - /// 0123456789ABCDEF - /// |--| < 1. - /// |--| < 2. - /// |--| < 3. - /// |----| < 4. - /// - /// 1. The input slice is aligned. - /// 2. The input slice is unaligned and is completely within the 16-byte chunk. - /// 3. The input slice is unaligned and touches the boundary of the 16-byte chunk. - /// 4. The input slice is unaligned and crosses the boundary of the 16-byte chunk. - #[inline] - fn new(haystack: &[u8]) -> Option { - let aligned_start = ((haystack.as_ptr() as usize) & !0xF) as *const u8; - - // If we are already aligned, there's nothing to do - if aligned_start == haystack.as_ptr() { - return None; - } - - let aligned_end = unsafe { aligned_start.offset(BYTES_PER_OPERATION as isize) }; - - let leading_junk = haystack.as_ptr() as usize - aligned_start as usize; - let leading_len = min(haystack.len() + leading_junk, BYTES_PER_OPERATION); - - let leading = unsafe { slice::from_raw_parts(aligned_start, leading_len) }; - - let bytes_until_alignment = if leading_len == BYTES_PER_OPERATION { - aligned_end as usize - haystack.as_ptr() as usize - } else { - haystack.len() - }; - - Some(Misalignment { - leading, - leading_junk, - bytes_until_alignment, - }) - } -} - -pub struct Bytes { - needle: __m128i, - needle_len: i32, -} - -impl Bytes { - pub /* const */ fn new(bytes: [u8; 16], needle_len: i32) -> Self { - Bytes { - needle: unsafe { TransmuteToSimd { bytes }.simd }, - needle_len, - } - } - - #[inline] - #[target_feature(enable = "sse4.2")] - pub unsafe fn find(&self, haystack: &[u8]) -> Option { - find(PackedCompare::<_, 0>(self), haystack) - } -} - -impl<'b> PackedCompareControl for &'b Bytes { - fn needle(&self) -> __m128i { - self.needle - } - fn needle_len(&self) -> i32 { - self.needle_len - } -} - -pub struct ByteSubstring<'a> { - complete_needle: &'a [u8], - needle: __m128i, - needle_len: i32, -} - -impl<'a> ByteSubstring<'a> { - pub /* const */ fn new(needle: &'a[u8]) -> Self { - use std::cmp; - - let mut simd_needle = [0; 16]; - let len = cmp::min(simd_needle.len(), needle.len()); - simd_needle[..len].copy_from_slice(&needle[..len]); - ByteSubstring { - complete_needle: needle, - needle: unsafe { TransmuteToSimd { bytes: simd_needle }.simd }, - needle_len: len as i32, - } - } - - #[cfg(feature = "pattern")] - pub fn needle_len(&self) -> usize { - self.complete_needle.len() - } - - #[inline] - #[target_feature(enable = "sse4.2")] - pub unsafe fn find(&self, haystack: &[u8]) -> Option { - let mut offset = 0; - - while let Some(idx) = find(PackedCompare::<_, _SIDD_CMP_EQUAL_ORDERED>(self), &haystack[offset..]) { - let abs_offset = offset + idx; - // Found a match, but is it really? - if haystack[abs_offset..].starts_with(self.complete_needle) { - return Some(abs_offset); - } - - // Skip past this false positive - offset += idx + 1; - } - - None - } -} - -impl<'a, 'b> PackedCompareControl for &'b ByteSubstring<'a> { - fn needle(&self) -> __m128i { - self.needle - } - fn needle_len(&self) -> i32 { - self.needle_len - } -} - -#[cfg(test)] -mod test { - use proptest::prelude::*; - use std::{fmt, str}; - use memmap::MmapMut; - use region::Protection; - - use super::*; - - lazy_static! { - static ref SPACE: Bytes = simd_bytes!(b' '); - static ref XML_DELIM_3: Bytes = simd_bytes!(b'<', b'>', b'&'); - static ref XML_DELIM_5: Bytes = simd_bytes!(b'<', b'>', b'&', b'\'', b'"'); - } - - trait SliceFindPolyfill { - fn find_any(&self, needles: &[T]) -> Option; - fn find_seq(&self, needle: &[T]) -> Option; - } - - impl SliceFindPolyfill for [T] - where - T: PartialEq, - { - fn find_any(&self, needles: &[T]) -> Option { - self.iter().position(|c| needles.contains(c)) - } - - fn find_seq(&self, needle: &[T]) -> Option { - (0..self.len()).find(|&l| self[l..].starts_with(needle)) - } - } - - struct Haystack { - data: Vec, - start: usize, - } - - impl Haystack { - fn without_start(&self) -> &[u8] { - &self.data - } - - fn with_start(&self) -> &[u8] { - &self.data[self.start..] - } - } - - // Knowing the address of the data can be important - impl fmt::Debug for Haystack { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_struct("Haystack") - .field("data", &self.data) - .field("(addr)", &self.data.as_ptr()) - .field("start", &self.start) - .finish() - } - } - - /// Creates a set of bytes and an offset inside them. Allows - /// checking arbitrary memory offsets, not just where the - /// allocator placed a value. - fn haystack() -> BoxedStrategy { - any::>() - .prop_flat_map(|data| { - let len = 0..=data.len(); - (Just(data), len) - }) - .prop_map(|(data, start)| Haystack { data, start }) - .boxed() - } - - #[derive(Debug)] - struct Needle { - data: [u8; 16], - len: usize, - } - - impl Needle { - fn as_slice(&self) -> &[u8] { - &self.data[..self.len] - } - } - - /// Creates an array and the number of valid values - fn needle() -> BoxedStrategy { - (any::<[u8; 16]>(), 0..=16_usize) - .prop_map(|(data, len)| Needle { data, len }) - .boxed() - } - - proptest! { - #[test] - fn works_as_find_does_for_up_to_and_including_16_bytes( - (haystack, needle) in (haystack(), needle()) - ) { - let haystack = haystack.without_start(); - - let us = unsafe { Bytes::new(needle.data, needle.len as i32).find(haystack) }; - let them = haystack.find_any(needle.as_slice()); - assert_eq!(us, them); - } - - #[test] - fn works_as_find_does_for_various_memory_offsets( - (needle, haystack) in (needle(), haystack()) - ) { - let haystack = haystack.with_start(); - - let us = unsafe { Bytes::new(needle.data, needle.len as i32).find(haystack) }; - let them = haystack.find_any(needle.as_slice()); - assert_eq!(us, them); - } - } - - #[test] - fn can_search_for_null_bytes() { - unsafe { - let null = simd_bytes!(b'\0'); - assert_eq!(Some(1), null.find(b"a\0")); - assert_eq!(Some(0), null.find(b"\0")); - assert_eq!(None, null.find(b"")); - } - } - - #[test] - fn can_search_in_null_bytes() { - unsafe { - let a = simd_bytes!(b'a'); - assert_eq!(Some(1), a.find(b"\0a")); - assert_eq!(None, a.find(b"\0")); - } - } - - #[test] - fn space_is_found() { - unsafe { - // Since the algorithm operates on 16-byte chunks, it's - // important to cover tests around that boundary. Since 16 - // isn't that big of a number, we might as well do all of - // them. - - assert_eq!(Some(0), SPACE.find(b" ")); - assert_eq!(Some(1), SPACE.find(b"0 ")); - assert_eq!(Some(2), SPACE.find(b"01 ")); - assert_eq!(Some(3), SPACE.find(b"012 ")); - assert_eq!(Some(4), SPACE.find(b"0123 ")); - assert_eq!(Some(5), SPACE.find(b"01234 ")); - assert_eq!(Some(6), SPACE.find(b"012345 ")); - assert_eq!(Some(7), SPACE.find(b"0123456 ")); - assert_eq!(Some(8), SPACE.find(b"01234567 ")); - assert_eq!(Some(9), SPACE.find(b"012345678 ")); - assert_eq!(Some(10), SPACE.find(b"0123456789 ")); - assert_eq!(Some(11), SPACE.find(b"0123456789A ")); - assert_eq!(Some(12), SPACE.find(b"0123456789AB ")); - assert_eq!(Some(13), SPACE.find(b"0123456789ABC ")); - assert_eq!(Some(14), SPACE.find(b"0123456789ABCD ")); - assert_eq!(Some(15), SPACE.find(b"0123456789ABCDE ")); - assert_eq!(Some(16), SPACE.find(b"0123456789ABCDEF ")); - assert_eq!(Some(17), SPACE.find(b"0123456789ABCDEFG ")); - } - } - - #[test] - fn space_not_found() { - unsafe { - // Since the algorithm operates on 16-byte chunks, it's - // important to cover tests around that boundary. Since 16 - // isn't that big of a number, we might as well do all of - // them. - - assert_eq!(None, SPACE.find(b"")); - assert_eq!(None, SPACE.find(b"0")); - assert_eq!(None, SPACE.find(b"01")); - assert_eq!(None, SPACE.find(b"012")); - assert_eq!(None, SPACE.find(b"0123")); - assert_eq!(None, SPACE.find(b"01234")); - assert_eq!(None, SPACE.find(b"012345")); - assert_eq!(None, SPACE.find(b"0123456")); - assert_eq!(None, SPACE.find(b"01234567")); - assert_eq!(None, SPACE.find(b"012345678")); - assert_eq!(None, SPACE.find(b"0123456789")); - assert_eq!(None, SPACE.find(b"0123456789A")); - assert_eq!(None, SPACE.find(b"0123456789AB")); - assert_eq!(None, SPACE.find(b"0123456789ABC")); - assert_eq!(None, SPACE.find(b"0123456789ABCD")); - assert_eq!(None, SPACE.find(b"0123456789ABCDE")); - assert_eq!(None, SPACE.find(b"0123456789ABCDEF")); - assert_eq!(None, SPACE.find(b"0123456789ABCDEFG")); - } - } - - #[test] - fn works_on_nonaligned_beginnings() { - unsafe { - // We have special code for strings that don't lie on 16-byte - // boundaries. Since allocation seems to happen on these - // boundaries by default, let's walk around a bit. - - let s = b"0123456789ABCDEF ".to_vec(); - - assert_eq!(Some(16), SPACE.find(&s[0..])); - assert_eq!(Some(15), SPACE.find(&s[1..])); - assert_eq!(Some(14), SPACE.find(&s[2..])); - assert_eq!(Some(13), SPACE.find(&s[3..])); - assert_eq!(Some(12), SPACE.find(&s[4..])); - assert_eq!(Some(11), SPACE.find(&s[5..])); - assert_eq!(Some(10), SPACE.find(&s[6..])); - assert_eq!(Some(9), SPACE.find(&s[7..])); - assert_eq!(Some(8), SPACE.find(&s[8..])); - assert_eq!(Some(7), SPACE.find(&s[9..])); - assert_eq!(Some(6), SPACE.find(&s[10..])); - assert_eq!(Some(5), SPACE.find(&s[11..])); - assert_eq!(Some(4), SPACE.find(&s[12..])); - assert_eq!(Some(3), SPACE.find(&s[13..])); - assert_eq!(Some(2), SPACE.find(&s[14..])); - assert_eq!(Some(1), SPACE.find(&s[15..])); - assert_eq!(Some(0), SPACE.find(&s[16..])); - assert_eq!(None, SPACE.find(&s[17..])); - } - } - - #[test] - fn misalignment_does_not_cause_a_false_positive_before_start() { - const AAAA: u8 = 0x01; - - let needle = Needle { - data: [ - AAAA, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - ], - len: 1, - }; - let haystack = Haystack { - data: vec![ - AAAA, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, - ], - start: 1, - }; - - let haystack = haystack.with_start(); - - // Needs to trigger the misalignment code - assert_ne!(0, (haystack.as_ptr() as usize) % 16); - // There are 64 bits in the mask and we check to make sure the - // result is less than the haystack - assert!(haystack.len() > 64); - - let us = unsafe { Bytes::new(needle.data, needle.len as i32).find(haystack) }; - assert_eq!(None, us); - } - - #[test] - fn xml_delim_3_is_found() { - unsafe { - assert_eq!(Some(0), XML_DELIM_3.find(b"<")); - assert_eq!(Some(0), XML_DELIM_3.find(b">")); - assert_eq!(Some(0), XML_DELIM_3.find(b"&")); - assert_eq!(None, XML_DELIM_3.find(b"")); - } - } - - #[test] - fn xml_delim_5_is_found() { - unsafe { - assert_eq!(Some(0), XML_DELIM_5.find(b"<")); - assert_eq!(Some(0), XML_DELIM_5.find(b">")); - assert_eq!(Some(0), XML_DELIM_5.find(b"&")); - assert_eq!(Some(0), XML_DELIM_5.find(b"'")); - assert_eq!(Some(0), XML_DELIM_5.find(b"\"")); - assert_eq!(None, XML_DELIM_5.find(b"")); - } - } - - proptest! { - #[test] - fn works_as_find_does_for_byte_substrings( - (needle, haystack) in (any::>(), any::>()) - ) { - let us = unsafe { - let s = ByteSubstring::new(&needle); - s.find(&haystack) - }; - let them = haystack.find_seq(&needle); - assert_eq!(us, them); - } - } - - #[test] - fn byte_substring_is_found() { - unsafe { - let substr = ByteSubstring::new(b"zz"); - assert_eq!(Some(0), substr.find(b"zz")); - assert_eq!(Some(1), substr.find(b"0zz")); - assert_eq!(Some(2), substr.find(b"01zz")); - assert_eq!(Some(3), substr.find(b"012zz")); - assert_eq!(Some(4), substr.find(b"0123zz")); - assert_eq!(Some(5), substr.find(b"01234zz")); - assert_eq!(Some(6), substr.find(b"012345zz")); - assert_eq!(Some(7), substr.find(b"0123456zz")); - assert_eq!(Some(8), substr.find(b"01234567zz")); - assert_eq!(Some(9), substr.find(b"012345678zz")); - assert_eq!(Some(10), substr.find(b"0123456789zz")); - assert_eq!(Some(11), substr.find(b"0123456789Azz")); - assert_eq!(Some(12), substr.find(b"0123456789ABzz")); - assert_eq!(Some(13), substr.find(b"0123456789ABCzz")); - assert_eq!(Some(14), substr.find(b"0123456789ABCDzz")); - assert_eq!(Some(15), substr.find(b"0123456789ABCDEzz")); - assert_eq!(Some(16), substr.find(b"0123456789ABCDEFzz")); - assert_eq!(Some(17), substr.find(b"0123456789ABCDEFGzz")); - } - } - - #[test] - fn byte_substring_is_not_found() { - unsafe { - let substr = ByteSubstring::new(b"zz"); - assert_eq!(None, substr.find(b"")); - assert_eq!(None, substr.find(b"0")); - assert_eq!(None, substr.find(b"01")); - assert_eq!(None, substr.find(b"012")); - assert_eq!(None, substr.find(b"0123")); - assert_eq!(None, substr.find(b"01234")); - assert_eq!(None, substr.find(b"012345")); - assert_eq!(None, substr.find(b"0123456")); - assert_eq!(None, substr.find(b"01234567")); - assert_eq!(None, substr.find(b"012345678")); - assert_eq!(None, substr.find(b"0123456789")); - assert_eq!(None, substr.find(b"0123456789A")); - assert_eq!(None, substr.find(b"0123456789AB")); - assert_eq!(None, substr.find(b"0123456789ABC")); - assert_eq!(None, substr.find(b"0123456789ABCD")); - assert_eq!(None, substr.find(b"0123456789ABCDE")); - assert_eq!(None, substr.find(b"0123456789ABCDEF")); - assert_eq!(None, substr.find(b"0123456789ABCDEFG")); - } - } - - #[test] - fn byte_substring_has_false_positive() { - unsafe { - // The PCMPESTRI instruction will mark the "a" before "ab" as - // a match because it cannot look beyond the 16 byte window - // of the haystack. We need to double-check any match to - // ensure it completely matches. - - let substr = ByteSubstring::new(b"ab"); - assert_eq!(Some(16), substr.find(b"aaaaaaaaaaaaaaaaab")) - // this "a" is a false positive ~~~~~~~~~~~~~~~^ - }; - } - - #[test] - fn byte_substring_needle_is_longer_than_16_bytes() { - unsafe { - let needle = b"0123456789abcdefg"; - let haystack = b"0123456789abcdefgh"; - assert_eq!(Some(0), ByteSubstring::new(needle).find(haystack)); - } - } - - fn with_guarded_string(value: &str, f: impl FnOnce(&str)) { - // Allocate a string that ends directly before a - // read-protected page. - - let page_size = region::page::size(); - assert!(value.len() <= page_size); - - // Map two rw-accessible pages of anonymous memory - let mut mmap = MmapMut::map_anon(2 * page_size).unwrap(); - - let (first_page, second_page) = mmap.split_at_mut(page_size); - - // Prohibit any access to the second page, so that any attempt - // to read or write it would trigger a segfault - unsafe { - region::protect(second_page.as_ptr(), page_size, Protection::NONE).unwrap(); - } - - // Copy bytes to the end of the first page - let dest = &mut first_page[page_size - value.len()..]; - dest.copy_from_slice(value.as_bytes()); - f(unsafe { str::from_utf8_unchecked(dest) }); - } - - #[test] - fn works_at_page_boundary() { - // PCMPxSTRx instructions are known to read 16 bytes at a - // time. This behaviour may cause accidental segfaults by - // reading past the page boundary. - // - // For now, this test failing crashes the whole test - // suite. This could be fixed by setting a custom signal - // handler, though Rust lacks such facilities at the moment. - - // Allocate a 16-byte string at page boundary. To verify this - // test, set protect=false to prevent segfaults. - with_guarded_string("0123456789abcdef", |text| { - // Will search for the last char - let needle = simd_bytes!(b'f'); - - // Check all suffixes of our 16-byte string - for offset in 0..text.len() { - let tail = &text[offset..]; - unsafe { - assert_eq!(Some(tail.len() - 1), needle.find(tail.as_bytes())); - } - } - }); - } - - #[test] - fn does_not_access_memory_after_haystack_when_haystack_is_multiple_of_16_bytes_and_no_match() { - // For now, this test failing crashes the whole test - // suite. This could be fixed by setting a custom signal - // handler, though Rust lacks such facilities at the moment. - with_guarded_string("0123456789abcdef", |text| { - // Will search for a char not present - let needle = simd_bytes!(b'z'); - - unsafe { - assert_eq!(None, needle.find(text.as_bytes())); - } - }); - } -} diff --git a/src/simd/aarch64.rs b/src/simd/aarch64.rs new file mode 100644 index 0000000..7bdce74 --- /dev/null +++ b/src/simd/aarch64.rs @@ -0,0 +1,261 @@ +//! AArch64 NEON SIMD implementation +//! +//! Based on [this algorithm][1] simplified somewhat by aarch64 neon (by the ability to make a +//! table of 32 bytes by combining two vectors), combined with 64 byte movemask via interleaving +//! from [here][2], and the iteration idea (using a u64 bitset of known matches) from [here][3]. +//! +//! [1]: http://0x80.pl/notesen/2018-10-18-simd-byte-lookup.html +//! [2]: https://community.arm.com/arm-community-blogs/b/servers-and-cloud-computing-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon +//! [3]: https://lemire.me/blog/2024/07/20/scan-html-even-faster-with-simd-instructions-c-and-c/ + +use std::arch::aarch64::*; +use std::mem::transmute; + +#[derive(Copy, Clone)] +pub struct Bytes { + bitset: uint8x16x2_t, +} + +type Vector = uint8x16_t; +type Chunk = uint8x16x4_t; + +/// Mapping from a number `i` in 0..=7 to a bit mask with the `i`-th bit set. +const N_TO_N_BITS_TABLE: uint8x16_t = unsafe { + let mut bits = [0u8; 16]; + let mut i = 0u8; + while i < 8 { + bits[i as usize] = 1 << i; + i += 1; + } + transmute(bits) +}; + +impl Bytes { + pub fn new(bytes: [u8; 16], needle_len: i32) -> Self { + assert!((0..=16).contains(&needle_len)); + let needle_len = needle_len as u8; + + // Make a bitset from the bytes to search for + let mut bitset = [0u8; 256 / 8]; + for i in 0..needle_len { + let i = usize::from(i); + let value = usize::from(bytes[i]); + let byte = value / 8; + let bit = value % 8; + let mask = 1 << bit; + bitset[byte] |= mask; + } + let bitset = unsafe { transmute(bitset) }; + + Bytes { bitset } + } + + #[target_feature(enable = "neon")] + pub fn find(&self, haystack: &[u8]) -> Option { + let mut vectors = haystack.chunks_exact(size_of::()); + + let mut offset = 0; + for vector in vectors.by_ref() { + let vector = unsafe { vld1q_u8(vector.as_ptr()) }; + let result = self.locale_in_vector(vector); + if let Some(element) = first_element_set(result) { + return Some(offset + usize::from(element)); + } + offset += size_of_val(&vector); + } + let remaining = vectors.remainder(); + let mut fake_vector = [0; size_of::()]; + fake_vector[..remaining.len()].copy_from_slice(remaining); + let vector = unsafe { vld1q_u8(fake_vector.as_ptr()) }; + let result = self.locale_in_vector(vector); + if let Some(element) = first_element_set(result) { + if element < remaining.len() as u8 { + return Some(offset + usize::from(element)); + } + } + + None + } + + pub fn iter<'a>(self, haystack: &'a [u8]) -> BytesIter<'a> { + BytesIter::new(self, haystack) + } + + /// Given a vector of 16 bytes of input, return a vector of true/false values. + /// + /// Each element in the output will be 0/255 based on if the input byte at that position + /// is equal to one of the values being searched for. + #[inline] + #[target_feature(enable = "neon")] + fn locale_in_vector(&self, v: Vector) -> Vector { + let high_bits = vshrq_n_u8::<3>(v); + let low_bit_masks = vqtbl2q_u8(self.bitset, high_bits); + let low_bits = vandq_u8(v, vdupq_n_u8(0b0111)); + let low_bits = vqtbl1q_u8(N_TO_N_BITS_TABLE, low_bits); + vtstq_u8(low_bits, low_bit_masks) + } + + /// Returns a u64 where each set bit indicates a match in the chunk. + /// + /// Takes a "deinterleaved" chunk of 4 vectors, each with 16 bytes. + /// The first element of the input is the first element of the first vector, + /// the second element of the input is the first element of the second vector, + /// the fifth element of the input is the second element of the first vector, etc. + #[inline] + #[target_feature(enable = "neon")] + fn locate_in_chunk(&self, chunk: Chunk) -> u64 { + let chunk_values = [chunk.0, chunk.1, chunk.2, chunk.3]; + // Get 4 "bool" vectors indicating if each element in the chunk is a match + let matching_elements = chunk_values.map(|v| self.locale_in_vector(v)); + + // Pack bits from the 4 vectors into a single vector + + // shift the second vector right by one, insert the top bit from the first vector + // The top two bits each element of temp0 are from the first and second vector + let temp0 = vsriq_n_u8::<1>(matching_elements[1], matching_elements[0]); + + // shift the fourth vector right by one, insert the top bit from the third vector + // The top two bits each element of temp1 are from the third and fourth vector + let temp1 = vsriq_n_u8::<1>(matching_elements[3], matching_elements[2]); + + // shift temp1 (the top two bits of which are from the third and fourth vector) right by 2, + // insert the top two bits from temp0 (the top two bits of which are from the first and + // second vector) + // The top four bits of each element of temp2 are from the first, second, third, and fourth + // vector + let temp2 = vsriq_n_u8::<2>(temp1, temp0); + + // duplicate the top 4 bits into the bottom 4 bits of each element + let temp3 = vsriq_n_u8::<4>(temp2, temp2); + + // The top/bottom 4 bits of each element are the same, so converting to a 64 bit bitset + // takes those 4 bits from each element and places them next to each other + vector_to_bitset(temp3) + } +} + +pub struct BytesIter<'a> { + /// The bytes to search for + bytes: Bytes, + /// The remaining haystack (after the current bitset chunk) + haystack: &'a [u8], + /// The current offset (from the start of the original haystack) + offset: usize, + /// A reversed bitset of the the current chunk + /// + /// e.g. the most significant bit is set if the next byte matches one of the searched bytes + current_bitset: u64, +} + +impl<'a> BytesIter<'a> { + fn new(bytes: Bytes, haystack: &'a [u8]) -> Self { + Self { + bytes, + haystack, + offset: 0, + current_bitset: 0, + } + } + + #[target_feature(enable = "neon")] + fn fill_bitset(&mut self) { + while let Some((chunk, rest)) = self.haystack.split_at_checked(size_of::()) { + self.haystack = rest; + let chunk = unsafe { vld4q_u8(chunk.as_ptr()) }; + let bitset = self.bytes.locate_in_chunk(chunk); + if bitset != 0 { + // aarch64 doesn't have a count trailing zeros instruction, so + // reverse the bits so we use leading_zeros instead + self.current_bitset = bitset.reverse_bits(); + return; + } + self.offset += size_of_val(&chunk); + } + let mut fake_chunk = [0; size_of::()]; + fake_chunk[..self.haystack.len()].copy_from_slice(self.haystack); + let chunk = unsafe { vld4q_u8(fake_chunk.as_ptr()) }; + self.current_bitset = self.bytes.locate_in_chunk(chunk); + let mask = !(u64::MAX << self.haystack.len() as u64); + self.current_bitset &= mask; + // aarch64 doesn't have a count trailing zeros instruction, so + // reverse the bits so we use leading_zeros instead + self.current_bitset = self.current_bitset.reverse_bits(); + self.haystack = &[]; + } +} + +impl<'a> Iterator for BytesIter<'a> { + type Item = usize; + + fn next(&mut self) -> Option { + let mut first_bit = self.current_bitset.leading_zeros(); + if first_bit == 64 { + unsafe { + self.fill_bitset(); + } + first_bit = self.current_bitset.leading_zeros(); + if first_bit == 64 { + return None; + } + } + // toggle the highest bit + self.current_bitset ^= 1 << (63 - first_bit); + let result = self.offset + first_bit as usize; + if self.current_bitset == 0 { + self.offset += 64; + } + Some(result) + } + + fn size_hint(&self) -> (usize, Option) { + let min = self.current_bitset.count_ones() as usize; + let max = min.checked_add(self.haystack.len()); + (min, max) + } + + // We can be a little faster by avoiding iterating through the bits by counting bits directly + fn count(self) -> usize { + let mut count = self.current_bitset.count_ones() as usize; + + let mut chunks = self.haystack.chunks_exact(size_of::()); + for chunk in chunks.by_ref() { + let chunk = unsafe { vld4q_u8(chunk.as_ptr()) }; + let result = unsafe { self.bytes.locate_in_chunk(chunk) }; + count += result.count_ones() as usize; + } + let remaining = chunks.remainder(); + let mut fake_chunk = [0; size_of::()]; + fake_chunk[..remaining.len()].copy_from_slice(remaining); + let chunk = unsafe { vld4q_u8(fake_chunk.as_ptr()) }; + let result = unsafe { self.bytes.locate_in_chunk(chunk) }; + let mask = !(u64::MAX << self.haystack.len() as u64); + count += (result & mask).count_ones() as usize; + count + } +} + +/// Convert a vector into a u64 +/// +/// Returns a value where the first 4 bits are the low 4 bits of the first element, +/// the next 4 bits are the high 4 bits of the second element, and so on. +/// +/// For a "bool" vector (where every element is either 255 or 0), the resulting u64 +/// will have groups of 4 bits from each element. e.g. the number of trailing zero bits +/// is 4 times the number of trailing zero elements. +#[target_feature(enable = "neon")] +fn vector_to_bitset(vector: Vector) -> u64 { + let result_vector = vshrn_n_u16::<4>(vreinterpretq_u16_u8(vector)); + vget_lane_u64::<0>(vreinterpret_u64_u8(result_vector)) +} + +/// Find the first element in a "bool" vector that is set or None if all elements are zero +#[target_feature(enable = "neon")] +fn first_element_set(vector: Vector) -> Option { + let bitset = vector_to_bitset(vector); + let first_bit = bitset.trailing_zeros() / 4; + if first_bit < 16 { + Some(first_bit as u8) + } else { + None + } +} diff --git a/src/simd/mod.rs b/src/simd/mod.rs new file mode 100644 index 0000000..6359edd --- /dev/null +++ b/src/simd/mod.rs @@ -0,0 +1,4 @@ +#[cfg(any(target_arch = "aarch64", target_arch = "arm64ec"))] +pub mod aarch64; +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +pub mod x86; diff --git a/src/simd/x86.rs b/src/simd/x86.rs new file mode 100644 index 0000000..2ff4fc9 --- /dev/null +++ b/src/simd/x86.rs @@ -0,0 +1,371 @@ +// # Warning +// +// Everything in this module assumes that the SSE 4.2 feature is available. + +use std::{cmp::min, slice}; + +#[cfg(target_arch = "x86")] +use std::arch::x86 as target_arch; +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64 as target_arch; + +use self::target_arch::{ + __m128i, _mm_cmpestri, _mm_cmpestrm, _mm_extract_epi16, _mm_loadu_si128, + _SIDD_CMP_EQUAL_ORDERED, +}; + +include!(concat!(env!("OUT_DIR"), "/src/simd_macros.rs")); + +const BYTES_PER_OPERATION: usize = 16; + +union TransmuteToSimd { + simd: __m128i, + bytes: [u8; 16], +} + +trait PackedCompareControl { + fn needle(&self) -> __m128i; + fn needle_len(&self) -> i32; +} + +#[inline] +#[target_feature(enable = "sse4.2")] +unsafe fn find_small( + packed: PackedCompare, + haystack: &[u8], +) -> Option +where + C: PackedCompareControl, +{ + let mut tail = [0u8; 16]; + core::ptr::copy_nonoverlapping(haystack.as_ptr(), tail.as_mut_ptr(), haystack.len()); + let haystack = &tail[..haystack.len()]; + debug_assert!(haystack.len() < ::std::i32::MAX as usize); + packed.cmpestri(haystack.as_ptr(), haystack.len() as i32) +} + +/// The `PCMPxSTRx` instructions always read 16 bytes worth of +/// data. Although the instructions handle unaligned memory access +/// just fine, they might attempt to read off the end of a page +/// and into a protected area. +/// +/// To handle this case, we read in 16-byte aligned chunks with +/// respect to the *end* of the byte slice. This makes the +/// complicated part in searching the leftover bytes at the +/// beginning of the byte slice. +#[inline] +#[target_feature(enable = "sse4.2")] +unsafe fn find( + packed: PackedCompare, + mut haystack: &[u8], +) -> Option +where + C: PackedCompareControl, +{ + // FIXME: EXPLAIN SAFETY + + if haystack.is_empty() { + return None; + } + + if haystack.len() < 16 { + return find_small(packed, haystack); + } + + let mut offset = 0; + + if let Some(misaligned) = Misalignment::new(haystack) { + if let Some(location) = packed.cmpestrm(misaligned.leading, misaligned.leading_junk) { + // Since the masking operation covers an entire + // 16-byte chunk, we have to see if the match occurred + // somewhere *after* our data + if location < haystack.len() { + return Some(location); + } + } + + haystack = &haystack[misaligned.bytes_until_alignment..]; + offset += misaligned.bytes_until_alignment; + } + + // TODO: try removing the 16-byte loop and check the disasm + let n_complete_chunks = haystack.len() / BYTES_PER_OPERATION; + + // Getting the pointer once before the loop avoids the + // overhead of manipulating the length of the slice inside the + // loop. + let mut haystack_ptr = haystack.as_ptr(); + let mut chunk_offset = 0; + for _ in 0..n_complete_chunks { + if let Some(location) = packed.cmpestri(haystack_ptr, BYTES_PER_OPERATION as i32) { + return Some(offset + chunk_offset + location); + } + + haystack_ptr = haystack_ptr.offset(BYTES_PER_OPERATION as isize); + chunk_offset += BYTES_PER_OPERATION; + } + haystack = &haystack[chunk_offset..]; + offset += chunk_offset; + + // No data left to search + if haystack.is_empty() { + return None; + } + + find_small(packed, haystack).map(|loc| loc + offset) +} + +struct PackedCompare(T); +impl PackedCompare +where + T: PackedCompareControl, +{ + #[inline] + #[target_feature(enable = "sse4.2")] + unsafe fn cmpestrm(&self, haystack: &[u8], leading_junk: usize) -> Option { + // TODO: document why this is ok + let haystack = _mm_loadu_si128(haystack.as_ptr() as *const __m128i); + + let mask = _mm_cmpestrm( + self.0.needle(), + self.0.needle_len(), + haystack, + BYTES_PER_OPERATION as i32, + CONTROL_BYTE, + ); + let mask = _mm_extract_epi16(mask, 0) as u16; + + if mask.trailing_zeros() < 16 { + let mut mask = mask; + // Byte: 7 6 5 4 3 2 1 0 + // Str : &[0, 1, 2, 3, ...] + // + // Bit-0 corresponds to Str-0; shifting to the right + // removes the parts of the string that don't belong to + // us. + mask >>= leading_junk; + // The first 1, starting from Bit-0 and going to Bit-7, + // denotes the position of the first match. + if mask == 0 { + // All of our matches were before the slice started + None + } else { + let first_match = mask.trailing_zeros() as usize; + debug_assert!(first_match < 16); + Some(first_match) + } + } else { + None + } + } + + #[inline] + #[target_feature(enable = "sse4.2")] + unsafe fn cmpestri(&self, haystack: *const u8, haystack_len: i32) -> Option { + debug_assert!( + (1..=16).contains(&haystack_len), + "haystack_len was {}", + haystack_len, + ); + + // TODO: document why this is ok + let haystack = _mm_loadu_si128(haystack as *const __m128i); + + let location = _mm_cmpestri( + self.0.needle(), + self.0.needle_len(), + haystack, + haystack_len, + CONTROL_BYTE, + ); + + if location < 16 { + Some(location as usize) + } else { + None + } + } +} + +#[derive(Debug)] +struct Misalignment<'a> { + leading: &'a [u8], + leading_junk: usize, + bytes_until_alignment: usize, +} + +impl<'a> Misalignment<'a> { + /// # Cases + /// + /// 0123456789ABCDEF + /// |--| < 1. + /// |--| < 2. + /// |--| < 3. + /// |----| < 4. + /// + /// 1. The input slice is aligned. + /// 2. The input slice is unaligned and is completely within the 16-byte chunk. + /// 3. The input slice is unaligned and touches the boundary of the 16-byte chunk. + /// 4. The input slice is unaligned and crosses the boundary of the 16-byte chunk. + #[inline] + fn new(haystack: &[u8]) -> Option { + let aligned_start = ((haystack.as_ptr() as usize) & !0xF) as *const u8; + + // If we are already aligned, there's nothing to do + if aligned_start == haystack.as_ptr() { + return None; + } + + let aligned_end = unsafe { aligned_start.offset(BYTES_PER_OPERATION as isize) }; + + let leading_junk = haystack.as_ptr() as usize - aligned_start as usize; + let leading_len = min(haystack.len() + leading_junk, BYTES_PER_OPERATION); + + let leading = unsafe { slice::from_raw_parts(aligned_start, leading_len) }; + + let bytes_until_alignment = if leading_len == BYTES_PER_OPERATION { + aligned_end as usize - haystack.as_ptr() as usize + } else { + haystack.len() + }; + + Some(Misalignment { + leading, + leading_junk, + bytes_until_alignment, + }) + } +} + +#[derive(Copy, Clone)] +pub struct Bytes { + needle: __m128i, + needle_len: i32, +} + +impl Bytes { + pub fn new(bytes: [u8; 16], needle_len: i32) -> Self { + Bytes { + needle: unsafe { TransmuteToSimd { bytes }.simd }, + needle_len, + } + } + + #[inline] + #[target_feature(enable = "sse4.2")] + pub unsafe fn find(&self, haystack: &[u8]) -> Option { + find(PackedCompare::<_, 0>(self), haystack) + } + + pub fn iter(self, haystack: &[u8]) -> BytesIter<'_> { + BytesIter { + bytes: self, + haystack, + offset: 0, + } + } +} + +pub struct BytesIter<'a> { + bytes: Bytes, + haystack: &'a [u8], + offset: usize, +} + +impl Iterator for BytesIter<'_> { + type Item = usize; + + fn next(&mut self) -> Option { + let next_offset = unsafe { self.bytes.find(self.haystack) }; + match next_offset { + Some(i) => { + self.haystack = &self.haystack[i + 1..]; + let res = self.offset + i; + self.offset = res + 1; + Some(res) + } + None => { + self.haystack = &[]; + None + } + } + } +} + +impl<'b> PackedCompareControl for &'b Bytes { + #[inline] + fn needle(&self) -> __m128i { + self.needle + } + + #[inline] + fn needle_len(&self) -> i32 { + self.needle_len + } +} + +pub struct ByteSubstring<'a> { + complete_needle: &'a [u8], + needle: __m128i, + needle_len: i32, +} + +impl<'a> ByteSubstring<'a> { + pub fn new(needle: &'a [u8]) -> Self { + let mut simd_needle = [0; 16]; + let len = if simd_needle.len() < needle.len() { + simd_needle.len() + } else { + needle.len() + }; + let mut i = 0; + while i < len { + simd_needle[i] = needle[i]; + i += 1; + } + ByteSubstring { + complete_needle: needle, + needle: unsafe { TransmuteToSimd { bytes: simd_needle }.simd }, + needle_len: len as i32, + } + } + + #[cfg(feature = "pattern")] + pub fn needle_len(&self) -> usize { + self.complete_needle.len() + } + + #[inline] + #[target_feature(enable = "sse4.2")] + pub unsafe fn find(&self, haystack: &[u8]) -> Option { + let mut offset = 0; + + while let Some(idx) = find( + PackedCompare::<_, _SIDD_CMP_EQUAL_ORDERED>(self), + &haystack[offset..], + ) { + let abs_offset = offset + idx; + // Found a match, but is it really? + if haystack[abs_offset..].starts_with(self.complete_needle) { + return Some(abs_offset); + } + + // Skip past this false positive + offset += idx + 1; + } + + None + } +} + +impl<'a, 'b> PackedCompareControl for &'b ByteSubstring<'a> { + #[inline] + fn needle(&self) -> __m128i { + self.needle + } + + #[inline] + fn needle_len(&self) -> i32 { + self.needle_len + } +} diff --git a/tests/top_level.rs b/tests/top_level.rs new file mode 100644 index 0000000..e453592 --- /dev/null +++ b/tests/top_level.rs @@ -0,0 +1,442 @@ +use jetscii::{bytes, ByteSubstring, Bytes, BytesConst}; +use lazy_static::lazy_static; +use memmap2::MmapMut; +use proptest::prelude::*; +use region::Protection; +use std::{fmt, str}; + +lazy_static! { + static ref SPACE: BytesConst = bytes!(b' '); + static ref XML_DELIM_3: BytesConst = bytes!(b'<', b'>', b'&'); + static ref XML_DELIM_5: BytesConst = bytes!(b'<', b'>', b'&', b'\'', b'"'); +} + +trait SliceFindPolyfill { + fn find_any(&self, needles: &[T]) -> Option; + fn find_seq(&self, needle: &[T]) -> Option; +} + +impl SliceFindPolyfill for [T] +where + T: PartialEq, +{ + fn find_any(&self, needles: &[T]) -> Option { + self.iter().position(|c| needles.contains(c)) + } + + fn find_seq(&self, needle: &[T]) -> Option { + (0..self.len()).find(|&l| self[l..].starts_with(needle)) + } +} + +struct Haystack { + data: Vec, + start: usize, +} + +impl Haystack { + fn without_start(&self) -> &[u8] { + &self.data + } + + fn with_start(&self) -> &[u8] { + &self.data[self.start..] + } +} + +// Knowing the address of the data can be important +impl fmt::Debug for Haystack { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("Haystack") + .field("data", &self.data) + .field("(addr)", &self.data.as_ptr()) + .field("start", &self.start) + .finish() + } +} + +/// Creates a set of bytes and an offset inside them. Allows +/// checking arbitrary memory offsets, not just where the +/// allocator placed a value. +fn haystack() -> BoxedStrategy { + any::>() + .prop_flat_map(|data| { + let len = 0..=data.len(); + (Just(data), len) + }) + .prop_map(|(data, start)| Haystack { data, start }) + .boxed() +} + +#[derive(Debug)] +struct Needle { + data: [u8; 16], + len: usize, +} + +impl Needle { + fn as_slice(&self) -> &[u8] { + &self.data[..self.len] + } +} + +/// Creates an array and the number of valid values +fn needle() -> BoxedStrategy { + (any::<[u8; 16]>(), 0..=16_usize) + .prop_map(|(data, len)| Needle { data, len }) + .boxed() +} + +proptest! { + #[test] + fn works_as_find_does_for_up_to_and_including_16_bytes( + (haystack, needle) in (haystack(), needle()) + ) { + let haystack = haystack.without_start(); + + let us = Bytes::new(needle.data, needle.len as i32, |b| needle.as_slice().contains(&b)).find(haystack); + let them = haystack.find_any(needle.as_slice()); + assert_eq!(us, them); + } + + #[test] + fn iter_works( + (haystack, needle) in (haystack(), needle()) + ) { + let haystack = haystack.without_start(); + + let bytes = Bytes::new(needle.data, needle.len as i32, |b| needle.as_slice().contains(&b)); + let mut us = bytes.iter(haystack); + let mut them = haystack.iter().enumerate().filter_map(|(i, b)| { + if needle.as_slice().contains(b) { + Some(i) + } else { + None + } + }); + loop { + match (us.next(), them.next()) { + (Some(us), Some(them)) => { assert_eq!(us, them); } + (Some(_), None) => panic!("them iterator ended before us"), + (None, Some(_)) => panic!("us iterator ended before them"), + (None, None) => break, + } + } + } + + #[test] + fn works_as_find_does_for_various_memory_offsets( + (needle, haystack) in (needle(), haystack()) + ) { + let haystack = haystack.with_start(); + + let us = Bytes::new(needle.data, needle.len as i32, |b| needle.as_slice().contains(&b)).find(haystack); + let them = haystack.find_any(needle.as_slice()); + assert_eq!(us, them); + } +} + +#[test] +fn can_search_for_null_bytes() { + let null = bytes!(b'\0'); + assert_eq!(Some(1), null.find(b"a\0")); + assert_eq!(Some(0), null.find(b"\0")); + assert_eq!(None, null.find(b"")); +} + +#[test] +fn can_search_in_null_bytes() { + let a = bytes!(b'a'); + assert_eq!(Some(1), a.find(b"\0a")); + assert_eq!(None, a.find(b"\0")); +} + +#[test] +fn space_is_found() { + // Since the simd algorithm operates on 16-byte chunks, it's + // important to cover tests around that boundary. Since 16 + // isn't that big of a number, we might as well do all of + // them. + + assert_eq!(Some(0), SPACE.find(b" ")); + assert_eq!(Some(1), SPACE.find(b"0 ")); + assert_eq!(Some(2), SPACE.find(b"01 ")); + assert_eq!(Some(3), SPACE.find(b"012 ")); + assert_eq!(Some(4), SPACE.find(b"0123 ")); + assert_eq!(Some(5), SPACE.find(b"01234 ")); + assert_eq!(Some(6), SPACE.find(b"012345 ")); + assert_eq!(Some(7), SPACE.find(b"0123456 ")); + assert_eq!(Some(8), SPACE.find(b"01234567 ")); + assert_eq!(Some(9), SPACE.find(b"012345678 ")); + assert_eq!(Some(10), SPACE.find(b"0123456789 ")); + assert_eq!(Some(11), SPACE.find(b"0123456789A ")); + assert_eq!(Some(12), SPACE.find(b"0123456789AB ")); + assert_eq!(Some(13), SPACE.find(b"0123456789ABC ")); + assert_eq!(Some(14), SPACE.find(b"0123456789ABCD ")); + assert_eq!(Some(15), SPACE.find(b"0123456789ABCDE ")); + assert_eq!(Some(16), SPACE.find(b"0123456789ABCDEF ")); + assert_eq!(Some(17), SPACE.find(b"0123456789ABCDEFG ")); +} + +#[test] +fn space_not_found() { + // Since the simd algorithm operates on 16-byte chunks, it's + // important to cover tests around that boundary. Since 16 + // isn't that big of a number, we might as well do all of + // them. + + assert_eq!(None, SPACE.find(b"")); + assert_eq!(None, SPACE.find(b"0")); + assert_eq!(None, SPACE.find(b"01")); + assert_eq!(None, SPACE.find(b"012")); + assert_eq!(None, SPACE.find(b"0123")); + assert_eq!(None, SPACE.find(b"01234")); + assert_eq!(None, SPACE.find(b"012345")); + assert_eq!(None, SPACE.find(b"0123456")); + assert_eq!(None, SPACE.find(b"01234567")); + assert_eq!(None, SPACE.find(b"012345678")); + assert_eq!(None, SPACE.find(b"0123456789")); + assert_eq!(None, SPACE.find(b"0123456789A")); + assert_eq!(None, SPACE.find(b"0123456789AB")); + assert_eq!(None, SPACE.find(b"0123456789ABC")); + assert_eq!(None, SPACE.find(b"0123456789ABCD")); + assert_eq!(None, SPACE.find(b"0123456789ABCDE")); + assert_eq!(None, SPACE.find(b"0123456789ABCDEF")); + assert_eq!(None, SPACE.find(b"0123456789ABCDEFG")); +} + +#[test] +fn works_on_nonaligned_beginnings() { + // We have special code for strings that don't lie on 16-byte + // boundaries. Since allocation seems to happen on these + // boundaries by default, let's walk around a bit. + + let s = b"0123456789ABCDEF ".to_vec(); + + assert_eq!(Some(16), SPACE.find(&s[0..])); + assert_eq!(Some(15), SPACE.find(&s[1..])); + assert_eq!(Some(14), SPACE.find(&s[2..])); + assert_eq!(Some(13), SPACE.find(&s[3..])); + assert_eq!(Some(12), SPACE.find(&s[4..])); + assert_eq!(Some(11), SPACE.find(&s[5..])); + assert_eq!(Some(10), SPACE.find(&s[6..])); + assert_eq!(Some(9), SPACE.find(&s[7..])); + assert_eq!(Some(8), SPACE.find(&s[8..])); + assert_eq!(Some(7), SPACE.find(&s[9..])); + assert_eq!(Some(6), SPACE.find(&s[10..])); + assert_eq!(Some(5), SPACE.find(&s[11..])); + assert_eq!(Some(4), SPACE.find(&s[12..])); + assert_eq!(Some(3), SPACE.find(&s[13..])); + assert_eq!(Some(2), SPACE.find(&s[14..])); + assert_eq!(Some(1), SPACE.find(&s[15..])); + assert_eq!(Some(0), SPACE.find(&s[16..])); + assert_eq!(None, SPACE.find(&s[17..])); +} + +#[test] +fn misalignment_does_not_cause_a_false_positive_before_start() { + const AAAA: u8 = 0x01; + + let needle = Needle { + data: [ + AAAA, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, + ], + len: 1, + }; + let haystack = Haystack { + data: vec![ + AAAA, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + ], + start: 1, + }; + + let haystack = haystack.with_start(); + + // Needs to trigger the misalignment code + assert_ne!(0, (haystack.as_ptr() as usize) % 16); + // There are 64 bits in the mask and we check to make sure the + // result is less than the haystack + assert!(haystack.len() > 64); + + let us = Bytes::new(needle.data, needle.len as i32, |b| { + needle.as_slice().contains(&b) + }) + .find(haystack); + assert_eq!(None, us); +} + +#[test] +fn xml_delim_3_is_found() { + assert_eq!(Some(0), XML_DELIM_3.find(b"<")); + assert_eq!(Some(0), XML_DELIM_3.find(b">")); + assert_eq!(Some(0), XML_DELIM_3.find(b"&")); + assert_eq!(None, XML_DELIM_3.find(b"")); +} + +#[test] +fn xml_delim_5_is_found() { + assert_eq!(Some(0), XML_DELIM_5.find(b"<")); + assert_eq!(Some(0), XML_DELIM_5.find(b">")); + assert_eq!(Some(0), XML_DELIM_5.find(b"&")); + assert_eq!(Some(0), XML_DELIM_5.find(b"'")); + assert_eq!(Some(0), XML_DELIM_5.find(b"\"")); + assert_eq!(None, XML_DELIM_5.find(b"")); +} + +#[test] +fn do_not_find_zeros_after_end() { + // The simd algorithm will end up with a partial vector when the haystack isn't + // a multiple of 16 bytes. The rest of the vector will be filled with zeros. + // Ensure we don't find a match in the remainder of that vector, which would be considered + // after the end of the haystack. + let needle = bytes!(b'\0', b'*'); + let haystack = b"123456"; + assert_eq!(None, needle.find(haystack)); +} + +proptest! { + #[test] + fn works_as_find_does_for_byte_substrings( + (needle, haystack) in (any::>(), any::>()) + ) { + if !needle.is_empty() { + let us = { + let s = ByteSubstring::new(&needle); + s.find(&haystack) + }; + let them = haystack.find_seq(&needle); + assert_eq!(us, them); + } + } +} + +#[test] +fn byte_substring_is_found() { + let substr = ByteSubstring::new(b"zz"); + assert_eq!(Some(0), substr.find(b"zz")); + assert_eq!(Some(1), substr.find(b"0zz")); + assert_eq!(Some(2), substr.find(b"01zz")); + assert_eq!(Some(3), substr.find(b"012zz")); + assert_eq!(Some(4), substr.find(b"0123zz")); + assert_eq!(Some(5), substr.find(b"01234zz")); + assert_eq!(Some(6), substr.find(b"012345zz")); + assert_eq!(Some(7), substr.find(b"0123456zz")); + assert_eq!(Some(8), substr.find(b"01234567zz")); + assert_eq!(Some(9), substr.find(b"012345678zz")); + assert_eq!(Some(10), substr.find(b"0123456789zz")); + assert_eq!(Some(11), substr.find(b"0123456789Azz")); + assert_eq!(Some(12), substr.find(b"0123456789ABzz")); + assert_eq!(Some(13), substr.find(b"0123456789ABCzz")); + assert_eq!(Some(14), substr.find(b"0123456789ABCDzz")); + assert_eq!(Some(15), substr.find(b"0123456789ABCDEzz")); + assert_eq!(Some(16), substr.find(b"0123456789ABCDEFzz")); + assert_eq!(Some(17), substr.find(b"0123456789ABCDEFGzz")); +} + +#[test] +fn byte_substring_is_not_found() { + let substr = ByteSubstring::new(b"zz"); + assert_eq!(None, substr.find(b"")); + assert_eq!(None, substr.find(b"0")); + assert_eq!(None, substr.find(b"01")); + assert_eq!(None, substr.find(b"012")); + assert_eq!(None, substr.find(b"0123")); + assert_eq!(None, substr.find(b"01234")); + assert_eq!(None, substr.find(b"012345")); + assert_eq!(None, substr.find(b"0123456")); + assert_eq!(None, substr.find(b"01234567")); + assert_eq!(None, substr.find(b"012345678")); + assert_eq!(None, substr.find(b"0123456789")); + assert_eq!(None, substr.find(b"0123456789A")); + assert_eq!(None, substr.find(b"0123456789AB")); + assert_eq!(None, substr.find(b"0123456789ABC")); + assert_eq!(None, substr.find(b"0123456789ABCD")); + assert_eq!(None, substr.find(b"0123456789ABCDE")); + assert_eq!(None, substr.find(b"0123456789ABCDEF")); + assert_eq!(None, substr.find(b"0123456789ABCDEFG")); +} + +#[test] +fn byte_substring_has_false_positive() { + // The PCMPESTRI instruction will mark the "a" before "ab" as + // a match because it cannot look beyond the 16 byte window + // of the haystack. We need to double-check any match to + // ensure it completely matches. + + let substr = ByteSubstring::new(b"ab"); + assert_eq!(Some(16), substr.find(b"aaaaaaaaaaaaaaaaab")) + // this "a" is a false positive ~~~~~~~~~~~~~~~^ +} + +#[test] +fn byte_substring_needle_is_longer_than_16_bytes() { + let needle = b"0123456789abcdefg"; + let haystack = b"0123456789abcdefgh"; + assert_eq!(Some(0), ByteSubstring::new(needle).find(haystack)); +} + +fn with_guarded_string(value: &str, f: impl FnOnce(&str)) { + // Allocate a string that ends directly before a + // read-protected page. + + let page_size = region::page::size(); + assert!(value.len() <= page_size); + + // Map two rw-accessible pages of anonymous memory + let mut mmap = MmapMut::map_anon(2 * page_size).unwrap(); + + let (first_page, second_page) = mmap.split_at_mut(page_size); + + // Prohibit any access to the second page, so that any attempt + // to read or write it would trigger a segfault + unsafe { + region::protect(second_page.as_ptr(), page_size, Protection::NONE).unwrap(); + } + + // Copy bytes to the end of the first page + let dest = &mut first_page[page_size - value.len()..]; + dest.copy_from_slice(value.as_bytes()); + f(unsafe { str::from_utf8_unchecked(dest) }); +} + +#[test] +fn works_at_page_boundary() { + // PCMPxSTRx instructions are known to read 16 bytes at a + // time. This behaviour may cause accidental segfaults by + // reading past the page boundary. + // + // For now, this test failing crashes the whole test + // suite. This could be fixed by setting a custom signal + // handler, though Rust lacks such facilities at the moment. + + // Allocate a 16-byte string at page boundary. To verify this + // test, set protect=false to prevent segfaults. + with_guarded_string("0123456789abcdef", |text| { + // Will search for the last char + let needle = bytes!(b'f'); + + // Check all suffixes of our 16-byte string + for offset in 0..text.len() { + let tail = &text[offset..]; + assert_eq!(Some(tail.len() - 1), needle.find(tail.as_bytes())); + } + }); +} + +#[test] +fn does_not_access_memory_after_haystack_when_haystack_is_multiple_of_16_bytes_and_no_match() { + // For now, this test failing crashes the whole test + // suite. This could be fixed by setting a custom signal + // handler, though Rust lacks such facilities at the moment. + with_guarded_string("0123456789abcdef", |text| { + // Will search for a char not present + let needle = bytes!(b'z'); + + assert_eq!(None, needle.find(text.as_bytes())); + }); +}