diff --git a/tokenizers/Cargo.toml b/tokenizers/Cargo.toml
index 40b273ac4..1dd0f7724 100644
--- a/tokenizers/Cargo.toml
+++ b/tokenizers/Cargo.toml
@@ -1,5 +1,8 @@
 [package]
-authors = ["Anthony MOI <m.anthony.moi@gmail.com>", "Nicolas Patry <patry.nicolas@protonmail.com>"]
+authors = [
+    "Anthony MOI <m.anthony.moi@gmail.com>",
+    "Nicolas Patry <patry.nicolas@protonmail.com>",
+]
 edition = "2018"
 name = "tokenizers"
 version = "0.22.3-dev.0"
@@ -13,7 +16,14 @@ description = """
 Provides an implementation of today's most used tokenizers,
 with a focus on performances and versatility.
 """
-exclude = [ "rust-toolchain", "target/*", "Cargo.lock", "benches/*.txt", "benches/*.json", "data/*" ]
+exclude = [
+    "rust-toolchain",
+    "target/*",
+    "Cargo.lock",
+    "benches/*.txt",
+    "benches/*.json",
+    "data/*",
+]
 
 [package.metadata.docs.rs]
 all-features = true
@@ -48,6 +58,10 @@ name = "added_vocab_deserialize"
 required-features = ["http"]
 harness = false
 
+[[bench]]
+name = "parallel_pretok_benchmark"
+harness = false
+
 [dependencies]
 rand = "0.9"
 onig = { version = "6.5.1", default-features = false, optional = true }
@@ -55,24 +69,26 @@ regex = "1.10"
 regex-syntax = "0.8"
 rayon = "1.10"
 rayon-cond = "0.4"
-serde = { version = "1.0", features = [ "derive" ] }
+serde = { version = "1.0", features = ["derive"] }
 serde_json = "1.0"
 unicode-normalization-alignments = "0.1"
 unicode_categories = "0.1"
 unicode-segmentation = "1.11"
-indicatif = {version = "0.18", optional = true}
+indicatif = { version = "0.18", optional = true }
 itertools = "0.14"
 log = "0.4"
 derive_builder = "0.20"
 spm_precompiled = "0.1.3"
-hf-hub = { version = "0.4.1", features = ["ureq"], default-features = false, optional = true }
+hf-hub = { version = "0.4.1", features = [
+    "ureq",
+], default-features = false, optional = true }
 aho-corasick = "1.1"
 paste = "1.0.14"
 macro_rules_attribute = "0.2.0"
 thiserror = "2"
-fancy-regex = { version = "0.17", optional = true}
+fancy-regex = { version = "0.17", optional = true }
 getrandom = { version = "0.3" }
-esaxx-rs = { version = "0.1.10", default-features = false, features=[]}
+esaxx-rs = { version = "0.1.10", default-features = false, features = [] }
 monostate = "0.1.12"
 ahash = { version = "0.8.11", features = ["serde"] }
 dary_heap = { version = "0.3.6", features = ["serde"] }
@@ -99,4 +115,3 @@ lto = "fat"
 [[example]]
 name = "encode_batch"
 required-features = ["http"]
-
diff --git a/tokenizers/benches/parallel_pretok_benchmark.rs b/tokenizers/benches/parallel_pretok_benchmark.rs
new file mode 100644
index 000000000..66448f1b1
--- /dev/null
+++ b/tokenizers/benches/parallel_pretok_benchmark.rs
@@ -0,0 +1,102 @@
+#[macro_use]
+extern crate criterion;
+
+use criterion::{BenchmarkId, Criterion, Throughput};
+use std::hint::black_box;
+use tokenizers::pattern::Pattern;
+use tokenizers::pre_tokenizers::byte_level::ByteLevel;
+use tokenizers::utils::SysRegex;
+use tokenizers::{PreTokenizedString, PreTokenizer};
+
+/// GPT-2 byte-level regex pattern, the most common pre-tokenization regex.
+fn gpt2_regex() -> SysRegex {
+    SysRegex::new(r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+")
+        .unwrap()
+}
+
+fn bench_parallel_pretok(c: &mut Criterion) {
+    let data = std::fs::read_to_string("data/big.txt").unwrap();
+
+    // --- Raw find_matches: sequential vs parallel ---
+    {
+        let mut group = c.benchmark_group("find_matches");
+        group.throughput(Throughput::Bytes(data.len() as u64));
+
+        let re = gpt2_regex();
+
+        group.bench_function("sequential", |b| {
+            tokenizers::parallelism::set_parallelism(false);
+            b.iter(|| (&re).find_matches(black_box(&data)).unwrap())
+        });
+
+        group.bench_function("parallel", |b| {
+            tokenizers::parallelism::set_parallelism(true);
+            b.iter(|| (&re).find_matches(black_box(&data)).unwrap())
+        });
+
+        // Restore the default parallelism setting
+        tokenizers::parallelism::set_parallelism(true);
+        group.finish();
+    }
+
+    // --- Full pre-tokenizer pipeline: sequential vs parallel ---
+    {
+        let mut group = c.benchmark_group("byte-level-pretok");
+        group.throughput(Throughput::Bytes(data.len() as u64));
+
+        let pretok = ByteLevel::default();
+
+        group.bench_function("sequential", |b| {
+            tokenizers::parallelism::set_parallelism(false);
+            b.iter(|| {
+                let mut pre = PreTokenizedString::from(black_box(data.as_str()));
+                pretok.pre_tokenize(&mut pre).unwrap();
+                pre
+            })
+        });
+
+        group.bench_function("parallel", |b| {
+            tokenizers::parallelism::set_parallelism(true);
+            b.iter(|| {
+                let mut pre = PreTokenizedString::from(black_box(data.as_str()));
+                pretok.pre_tokenize(&mut pre).unwrap();
+                pre
+            })
+        });
+
+        tokenizers::parallelism::set_parallelism(true);
+        group.finish();
+    }
+
+    // --- Scaling by input size ---
+    {
+        let mut group = c.benchmark_group("parallel-pretok-scaling");
+        let re = gpt2_regex();
+
+        for size in [1_000, 10_000, 100_000, 500_000] {
+            let input: String = data.chars().take(size).collect();
+            group.throughput(Throughput::Bytes(input.len() as u64));
+
+            group.bench_with_input(BenchmarkId::new("sequential", size), &input, |b, input| {
+                tokenizers::parallelism::set_parallelism(false);
+                b.iter(|| (&re).find_matches(black_box(input.as_str())).unwrap())
+            });
+
+            group.bench_with_input(BenchmarkId::new("parallel", size), &input, |b, input| {
+                tokenizers::parallelism::set_parallelism(true);
+                b.iter(|| (&re).find_matches(black_box(input.as_str())).unwrap())
+            });
+        }
+
+        tokenizers::parallelism::set_parallelism(true);
+        group.finish();
+    }
+}
+
+criterion_group! {
+    name = parallel_pretok;
+    config = Criterion::default().sample_size(20);
+    targets = bench_parallel_pretok
+}
+
+criterion_main!(parallel_pretok);
diff --git a/tokenizers/src/tokenizer/pattern.rs b/tokenizers/src/tokenizer/pattern.rs
index a2a2f1684..65fbedf2d 100644
--- a/tokenizers/src/tokenizer/pattern.rs
+++ b/tokenizers/src/tokenizer/pattern.rs
@@ -1,7 +1,13 @@
+use crate::parallelism::get_parallelism;
 use crate::utils::SysRegex;
 use crate::{Offsets, Result};
+use rayon::current_num_threads;
+use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
 use regex::Regex;
 
+const MIN_CHUNK_SIZE: usize = 8 * 1024; // 8KB
+const CHUNK_OVERLAP: usize = 1024; // 1KB
+
 /// Pattern used to split a NormalizedString
 pub trait Pattern {
     /// Slice the given string in a list of pattern match positions, with
@@ -40,6 +46,9 @@ impl Pattern for &String {
 
 impl Pattern for &Regex {
     fn find_matches(&self, inside: &str) -> Result<Vec<(Offsets, bool)>> {
+        if inside.len() > 2 * MIN_CHUNK_SIZE && get_parallelism() && current_num_threads() > 1 {
+            return parallel_find_matches_with_config(*self, inside, MIN_CHUNK_SIZE, CHUNK_OVERLAP);
+        }
         if inside.is_empty() {
             return Ok(vec![((0, 0), false)]);
         }
@@ -62,6 +71,9 @@ impl Pattern for &Regex {
 
 impl Pattern for &SysRegex {
     fn find_matches(&self, inside: &str) -> Result<Vec<(Offsets, bool)>> {
+        if inside.len() > 2 * MIN_CHUNK_SIZE && get_parallelism() && current_num_threads() > 1 {
+            return parallel_find_matches_with_config(*self, inside, MIN_CHUNK_SIZE, CHUNK_OVERLAP);
+        }
         if inside.is_empty() {
             return Ok(vec![((0, 0), false)]);
         }
@@ -137,6 +149,193 @@ impl<P: Pattern> Pattern for Invert<P> {
     }
 }
 
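+/// Iterator over overlapping chunks of `s`. Each chunk "owns" the byte range
+/// `authority_start..authority_end`; the extra `overlap` bytes after it exist
+/// only so a match straddling the chunk boundary can still be found whole.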
+struct OverlappingChunks<'a> {
+    s: &'a str,
+    chunk_size: usize,
+    overlap: usize,
+    pos: usize,
+}
+
+struct Chunk<'a> {
+    text: &'a str,
+    authority_start: usize,
+    authority_end: usize,
+}
+
+impl<'a> OverlappingChunks<'a> {
+    fn new(s: &'a str, chunk_size: usize, overlap: usize) -> Self {
+        Self {
+            s,
+            chunk_size,
+            overlap,
+            pos: 0,
+        }
+    }
+}
+
+impl<'a> Iterator for OverlappingChunks<'a> {
+    type Item = Chunk<'a>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.pos >= self.s.len() {
+            return None;
+        }
+
+        let authority_start = self.pos;
+
+        // Snap chunk boundaries forward to char boundaries so slicing cannot panic
+        let mut authority_end = (self.pos + self.chunk_size).min(self.s.len());
+        while authority_end < self.s.len() && !self.s.is_char_boundary(authority_end) {
+            authority_end += 1;
+        }
+
+        let mut chunk_end = (authority_end + self.overlap).min(self.s.len());
+        while chunk_end < self.s.len() && !self.s.is_char_boundary(chunk_end) {
+            chunk_end += 1;
+        }
+
+        self.pos = authority_end;
+
+        Some(Chunk {
+            text: &self.s[authority_start..chunk_end],
+            authority_start,
+            authority_end,
+        })
+    }
+}
+
+fn parallel_find_matches_with_config<P: Pattern + Sync>(
+    pattern: P,
+    inside: &str,
+    min_chunk_size: usize,
+    chunk_overlap: usize,
+) -> Result<Vec<(Offsets, bool)>> {
+    if inside.len() <= 2 * min_chunk_size {
+        return pattern.find_matches(inside);
+    }
+
+    let n_chunks = current_num_threads().min(inside.len() / min_chunk_size);
+
+    // Split the string into overlapping chunks, find matches in each chunk in parallel
+    let chunks: Vec<_> =
+        OverlappingChunks::new(inside, inside.len() / n_chunks, chunk_overlap).collect();
+    let matches: Vec<Vec<(Offsets, bool)>> = chunks
+        .par_iter()
+        .map(|chunk| -> Result<Vec<(Offsets, bool)>> {
+            let local_matches = pattern.find_matches(chunk.text)?;
+            Ok(local_matches
+                .into_iter()
+                .map(|((s, e), is_match)| {
+                    (
+                        (s + chunk.authority_start, e + chunk.authority_start),
+                        is_match,
+                    )
+                })
+                .filter(|((s, _e), is_match)| *is_match && *s < chunk.authority_end)
+                .collect())
+        })
+        .collect::<Result<Vec<_>>>()?;
+
+    // Merge per-chunk results, re-inserting the non-match gaps between matches
+    let matches: Vec<_> = matches.into_iter().flatten().collect();
+    let mut i = 0;
+    let mut merged = Vec::new();
+    let mut prev_end = 0;
+
+    while i < matches.len() {
+        let (s, e) = matches[i].0;
+
+        if s >= prev_end {
+            // Normal match: emit the non-match gap before it, then the match itself
+            if s > prev_end {
+                merged.push(((prev_end, s), false));
+            }
+            merged.push(matches[i]);
+            prev_end = e;
+            i += 1;
+        } else {
+            // Overlap region: skip duplicate "ghost" matches that start before prev_end
+            let mut max_ghost_end = 0;
+            while i < matches.len() && matches[i].0 .0 < prev_end {
+                max_ghost_end = max_ghost_end.max(matches[i].0 .1);
+                i += 1;
+            }
+            // A ghost extending past prev_end means the last match was truncated; re-match it
+            if max_ghost_end > prev_end {
+                if let Some(((trunc_start, trunc_end), _)) = merged.last_mut() {
+                    if let Some((_, new_end)) =
+                        find_one_from(&pattern, inside, *trunc_start, chunk_overlap)?
+                    {
+                        *trunc_end = new_end;
+                        prev_end = new_end;
+                    }
+                }
+            }
+
+            if i < matches.len() && matches[i].0 .0 > prev_end {
+                let mut pos = prev_end;
+                while pos < inside.len() {
+                    match find_one_from(&pattern, inside, pos, chunk_overlap)? {
+                        Some((ms, me)) => {
+                            if matches[i].0 == (ms, me) {
+                                break;
+                            }
+                            if prev_end < ms {
+                                merged.push(((prev_end, ms), false));
+                            }
+                            merged.push(((ms, me), true));
+                            prev_end = me;
+                            pos = me;
+                        }
+                        _ => break,
+                    }
+                }
+            }
+        }
+    }
+
+    if prev_end < inside.len() {
+        merged.push(((prev_end, inside.len()), false));
+    }
+
+    Ok(merged)
+}
+
+fn find_one_from<P: Pattern>(
+    pattern: &P,
+    inside: &str,
+    from: usize,
+    chunk_overlap: usize,
+) -> Result<Option<Offsets>> {
+    for n in 1..=8 {
+        let mut window_end = (from + chunk_overlap * n).min(inside.len());
+        while window_end < inside.len() && !inside.is_char_boundary(window_end) {
+            window_end += 1;
+        }
+        let window = &inside[from..window_end];
+        let result = pattern
+            .find_matches(window)?
+            .into_iter()
+            .find(|(_, is_match)| *is_match)
+            .map(|((s, e), _)| (s + from, e + from));
+
+        match result {
+            Some((s, e)) if e < window_end => return Ok(Some((s, e))), // safely inside the window
+            Some(_) if window_end < inside.len() => continue, // may be truncated: widen the window
+            Some((s, e)) => return Ok(Some((s, e))), // ends at the end of the input
+            None if window_end < inside.len() => continue, // nothing found yet: widen the window
+            None => return Ok(None),
+        }
+    }
+    Ok(pattern
+        .find_matches(&inside[from..])?
+        .into_iter()
+        .find(|(_, is_match)| *is_match)
+        .map(|((s, e), _)| (s + from, e + from)))
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -218,4 +417,44 @@ mod tests {
         );
         do_test!("aaa", &is_whitespace => vec![((0, 3), false)]);
     }
+
+    #[test]
+    fn parallel_correctness() {
+        let patterns = vec![
+            SysRegex::new(r"\s+").unwrap(),
+            SysRegex::new(r"\w+|[^\w\s]+").unwrap(),
+            SysRegex::new(
+                r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+",
+            )
+            .unwrap(),
+        ];
+
+        let long_input = "NoSpacesAtAllInThisVeryLongWord repeated "
+            .repeat(20)
+            .trim()
+            .to_string();
+
+        let inputs: Vec<&str> = vec![
+            "hello world foo bar baz",
+            "a b c d e f g h i j",
+            "Hello, world! This is a test. Numbers: 123, 456.",
+            "Unicode: café résumé naïve 日本語テスト",
+            "Short",
+            "",
+            &long_input,
+        ];
+
+        for pattern in &patterns {
+            for input in &inputs {
+                let sequential = pattern.find_matches(input).unwrap();
+                let parallel = parallel_find_matches_with_config(pattern, input, 5, 5).unwrap();
+                assert_eq!(
+                    sequential,
+                    parallel,
+                    "Mismatch for input: '{}'",
+                    &input[..input.len().min(50)]
+                );
+            }
+        }
+    }
 }