diff --git a/bindings/python/src/normalizers.rs b/bindings/python/src/normalizers.rs index 8b6404ee33..86824d8c98 100644 --- a/bindings/python/src/normalizers.rs +++ b/bindings/python/src/normalizers.rs @@ -176,9 +176,7 @@ impl PyNormalizer { /// :obj:`str`: A string after normalization #[pyo3(text_signature = "(self, sequence)")] fn normalize_str(&self, sequence: &str) -> PyResult<String> { - let mut normalized = NormalizedString::from(sequence); - ToPyResult(self.normalizer.normalize(&mut normalized)).into_py()?; - Ok(normalized.get().to_owned()) + ToPyResult(self.normalizer.normalize_str(sequence)).into() } fn __repr__(&self) -> PyResult<String> { diff --git a/tokenizers/src/normalizers/byte_level.rs b/tokenizers/src/normalizers/byte_level.rs index 41fd416156..650ec2e2fb 100644 --- a/tokenizers/src/normalizers/byte_level.rs +++ b/tokenizers/src/normalizers/byte_level.rs @@ -10,6 +10,42 @@ pub struct ByteLevel; static BYTES_CHAR: LazyLock<HashMap<u8, char>> = LazyLock::new(bytes_char); +/// Pre-encoded UTF-8 lookup table for [`ByteLevel::normalize_str`]. +/// +/// The byte-level normalizer maps every input byte (0x00–0xFF) to a specific +/// Unicode character. For example: +/// +/// | Input byte | Char | Code point | UTF-8 bytes | Length | +/// |------------|---------|------------|-------------|--------| +/// | `0x6B` (k) | 'k' | U+006B | `[6B]` | 1 | +/// | `0x20` ( ) | 'Ġ' | U+0120 | `[C4, A0]` | 2 | +/// | `0x0D` (CR)| 'č' | U+010D | `[C4, 8D]` | 2 | +/// +/// Without this table, each byte requires a `HashMap` lookup +/// followed by `char::encode_utf8` to write it into the output `String`. +/// +/// This table pre-computes the UTF-8 encoding once at startup, so the hot +/// path is just `out.extend_from_slice(&entry.bytes[..entry.len])` — a +/// direct memcpy with no hashing and no per-char encoding. +struct Utf8Entry { +    /// The UTF-8 encoding of the byte-level char (at most 2 bytes for this +    /// particular mapping, since all chars are ≤ U+0122). 
+ bytes: [u8; 2], + /// Number of valid bytes in `bytes` (1 or 2). + len: u8, +} + +static BYTES_CHAR_UTF8: LazyLock<[Utf8Entry; 256]> = LazyLock::new(|| { + let map = bytes_char(); + std::array::from_fn(|i| { + let c = map[&(i as u8)]; + let mut buf = [0u8; 2]; + let s = c.encode_utf8(&mut buf); + let len = s.len() as u8; + Utf8Entry { bytes: buf, len } + }) +}); + impl Default for ByteLevel { fn default() -> Self { Self::new() @@ -45,6 +81,19 @@ impl Normalizer for ByteLevel { } Ok(()) } + + /// Fast path: map each byte to its byte-level char without alignment tracking. + /// Uses a pre-encoded UTF-8 lookup table — no HashMap, no per-char encoding. + fn normalize_str(&self, s: &str) -> Result<String> { + let table = &*BYTES_CHAR_UTF8; + let mut out = Vec::with_capacity(s.len() * 2); + for &b in s.as_bytes() { + let entry = &table[b as usize]; + out.extend_from_slice(&entry.bytes[..entry.len as usize]); + } + // SAFETY: every entry in the table is valid UTF-8 (encoded from a char). + Ok(unsafe { String::from_utf8_unchecked(out) }) + } } #[cfg(test)] @@ -171,4 +220,15 @@ mod tests { ] ); } + + #[test] + fn normalize_str_matches_normalize() { + let bl = ByteLevel::new(); + for input in &["Hello", "Hello 我今天能为你做什么", "", "abc\x00\x01\x7f"] { + let mut ns = NormalizedString::from(*input); + bl.normalize(&mut ns).unwrap(); + let fast = bl.normalize_str(input).unwrap(); + assert_eq!(ns.get(), fast, "mismatch for input: {input:?}"); + } + } } diff --git a/tokenizers/src/normalizers/mod.rs b/tokenizers/src/normalizers/mod.rs index f400f13da9..678002eefc 100644 --- a/tokenizers/src/normalizers/mod.rs +++ b/tokenizers/src/normalizers/mod.rs @@ -200,6 +200,25 @@ impl Normalizer for NormalizerWrapper { Self::ByteLevel(lc) => lc.normalize(normalized), } } + + fn normalize_str(&self, s: &str) -> crate::Result<String> { + match self { + Self::BertNormalizer(bn) => bn.normalize_str(s), + Self::StripNormalizer(sn) => sn.normalize_str(s), + Self::StripAccents(sn) => sn.normalize_str(s), + 
Self::NFC(nfc) => nfc.normalize_str(s), + Self::NFD(nfd) => nfd.normalize_str(s), + Self::NFKC(nfkc) => nfkc.normalize_str(s), + Self::NFKD(nfkd) => nfkd.normalize_str(s), + Self::Sequence(sequence) => sequence.normalize_str(s), + Self::Lowercase(lc) => lc.normalize_str(s), + Self::Nmt(lc) => lc.normalize_str(s), + Self::Precompiled(lc) => lc.normalize_str(s), + Self::Replace(lc) => lc.normalize_str(s), + Self::Prepend(lc) => lc.normalize_str(s), + Self::ByteLevel(lc) => lc.normalize_str(s), + } + } } impl_enum_from!(BertNormalizer, NormalizerWrapper, BertNormalizer); diff --git a/tokenizers/src/normalizers/prepend.rs b/tokenizers/src/normalizers/prepend.rs index 4e318c2599..1f8a2087c7 100644 --- a/tokenizers/src/normalizers/prepend.rs +++ b/tokenizers/src/normalizers/prepend.rs @@ -21,6 +21,13 @@ impl Normalizer for Prepend { } Ok(()) } + fn normalize_str(&self, s: &str) -> Result<String> { + if s.is_empty() { + Ok(String::new()) + } else { + Ok(format!("{}{s}", self.prepend)) + } + } } #[cfg(test)] diff --git a/tokenizers/src/normalizers/strip.rs b/tokenizers/src/normalizers/strip.rs index 19f5ff314d..a511897834 100644 --- a/tokenizers/src/normalizers/strip.rs +++ b/tokenizers/src/normalizers/strip.rs @@ -38,6 +38,11 @@ impl Normalizer for Strip { Ok(()) } + fn normalize_str(&self, s: &str) -> Result<String> { + let s = if self.strip_left { s.trim_start() } else { s }; + let s = if self.strip_right { s.trim_end() } else { s }; + Ok(s.to_owned()) + } } // This normalizer removes combining marks from a normalized string @@ -53,6 +58,9 @@ impl Normalizer for StripAccents { normalized.filter(|c| !is_combining_mark(c)); Ok(()) } + fn normalize_str(&self, s: &str) -> Result<String> { + Ok(s.chars().filter(|c| !is_combining_mark(*c)).collect()) + } } #[cfg(test)] diff --git a/tokenizers/src/normalizers/unicode.rs b/tokenizers/src/normalizers/unicode.rs index 502b4239b4..a0a4854fec 100644 --- a/tokenizers/src/normalizers/unicode.rs +++ b/tokenizers/src/normalizers/unicode.rs @@ -1,5 +1,6 @@ use 
crate::tokenizer::{NormalizedString, Normalizer, Result}; use crate::utils::macro_rules_attribute; +use unicode_normalization_alignments::UnicodeNormalization; #[derive(Default, Copy, Clone, Debug)] #[macro_rules_attribute(impl_serde_type!)] @@ -9,6 +10,9 @@ impl Normalizer for NFD { normalized.nfd(); Ok(()) } + fn normalize_str(&self, s: &str) -> Result<String> { + Ok(s.nfd().map(|(c, _)| c).collect()) + } } #[derive(Default, Copy, Clone, Debug)] @@ -19,6 +23,9 @@ impl Normalizer for NFKD { normalized.nfkd(); Ok(()) } + fn normalize_str(&self, s: &str) -> Result<String> { + Ok(s.nfkd().map(|(c, _)| c).collect()) + } } #[derive(Default, Copy, Clone, Debug)] @@ -29,6 +36,9 @@ impl Normalizer for NFC { normalized.nfc(); Ok(()) } + fn normalize_str(&self, s: &str) -> Result<String> { + Ok(s.nfc().map(|(c, _)| c).collect()) + } } #[derive(Default, Copy, Clone, Debug)] @@ -39,6 +49,9 @@ impl Normalizer for NFKC { normalized.nfkc(); Ok(()) } + fn normalize_str(&self, s: &str) -> Result<String> { + Ok(s.nfkc().map(|(c, _)| c).collect()) + } } fn do_nmt(normalized: &mut NormalizedString) { @@ -80,6 +93,21 @@ impl Normalizer for Nmt { do_nmt(normalized); Ok(()) } + fn normalize_str(&self, s: &str) -> Result<String> { + Ok(s.chars() + .filter(|c| { + !matches!( + *c as u32, + 0x0001..=0x0008 | 0x000B | 0x000E..=0x001F | 0x007F | 0x008F | 0x009F + ) + }) + .map(|c| match c as u32 { + 0x0009 | 0x000A | 0x000C | 0x000D | 0x1680 | 0x200B..=0x200F | 0x2028 + | 0x2029 | 0x2581 | 0xFEFF | 0xFFFD => ' ', + _ => c, + }) + .collect()) + } } #[cfg(test)] diff --git a/tokenizers/src/normalizers/utils.rs b/tokenizers/src/normalizers/utils.rs index 1e33cc791a..7e0fa3f3ee 100644 --- a/tokenizers/src/normalizers/utils.rs +++ b/tokenizers/src/normalizers/utils.rs @@ -46,6 +46,14 @@ impl Normalizer for Sequence { } Ok(()) } + + fn normalize_str(&self, s: &str) -> Result<String> { + let mut result = s.to_owned(); + for normalizer in &self.normalizers { + result = normalizer.normalize_str(&result)?; + } + Ok(result) + } } /// Lowercases the 
input @@ -57,4 +65,8 @@ impl Normalizer for Lowercase { normalized.lowercase(); Ok(()) } + + fn normalize_str(&self, s: &str) -> Result<String> { + Ok(s.to_lowercase()) + } } diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs index ca7bae5580..586f84d963 100644 --- a/tokenizers/src/tokenizer/added_vocabulary.rs +++ b/tokenizers/src/tokenizer/added_vocabulary.rs @@ -315,9 +315,7 @@ impl AddedVocabulary { if token.normalized { if let Some(n) = normalizer { - let mut s = NormalizedString::from(token.content.as_ref()); - n.normalize(&mut s)?; - let normed = s.get().to_string(); + let normed = n.normalize_str(&token.content)?; if normed != token.content { self.normalized_cache.insert(new_id, normed); } @@ -360,9 +358,7 @@ for (id, token) in &self.added_tokens_map_r { if token.normalized { if let Some(n) = normalizer { - let mut s = NormalizedString::from(token.content.as_ref()); - n.normalize(&mut s)?; - let normed = s.get().to_string(); + let normed = n.normalize_str(&token.content)?; if normed != token.content { self.normalized_cache.insert(*id, normed); } @@ -562,6 +558,39 @@ impl AddedVocabulary { pretokenized } + + /// Like [`extract_and_normalize`] but uses [`Normalizer::normalize_str`] + /// instead of [`Normalizer::normalize`], skipping alignment tracking. + /// + /// This is used by `encode_fast` where offsets are not needed. The + /// normalization step avoids building per-byte alignment vectors, which + /// saves O(n) allocations per split. + pub fn extract_and_normalize_fast<N: Normalizer>( + &self, + normalizer: Option<&N>, + sequence: &str, + ) -> PreTokenizedString { + let mut pretokenized: PreTokenizedString = sequence.into(); + + // 1. Extract non-normalized tokens from the raw string + pretokenized + .split(|_, sequence| Ok(self.split_with_indices(sequence, &self.split_trie))) + .expect("AddedVocabulary bad split"); + + // 2. 
Normalize remaining pieces via normalize_str (no alignment tracking) + // and extract normalized tokens + pretokenized + .split(|_, mut sequence| { + if let Some(n) = normalizer { + let normed = n.normalize_str(sequence.get())?; + sequence.set_normalized(normed); + } + Ok(self.split_with_indices(sequence, &self.split_normalized_trie)) + }) + .expect("AddedVocabulary bad split"); + + pretokenized + } } impl Default for AddedVocabulary { diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index 8e282fba28..74d85d0706 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -55,6 +55,17 @@ pub type Offsets = (usize, usize); /// Takes care of pre-processing strings. pub trait Normalizer: Sync { fn normalize(&self, normalized: &mut NormalizedString) -> Result<()>; + + /// Normalize a plain string, returning the result without tracking offsets. + /// + /// The default implementation allocates a full [`NormalizedString`] (with + /// alignment vectors). Normalizers that can produce their output more + /// cheaply should override this to avoid the overhead. + fn normalize_str(&self, s: &str) -> Result<String> { + let mut n = NormalizedString::from(s); + self.normalize(&mut n)?; + Ok(n.get().to_owned()) + } } /// The `PreTokenizer` is in charge of doing the pre-segmentation step. 
It splits the given string @@ -731,10 +742,15 @@ where type_id: u32, offsets_type: OffsetType, ) -> Result<Encoding> { + let fast = matches!(offsets_type, OffsetType::None); let encode = |is_pre_tokenized, subseq_idx, subseq| -> Result<Encoding> { - let normalized = self - .added_vocabulary - .extract_and_normalize(self.normalizer.as_ref(), subseq); + let normalized = if fast { + self.added_vocabulary + .extract_and_normalize_fast(self.normalizer.as_ref(), subseq) + } else { + self.added_vocabulary + .extract_and_normalize(self.normalizer.as_ref(), subseq) + }; let pre_tokenized = self.do_pre_tokenize(normalized)?; let subseq_encoding = self.do_tokenize( pre_tokenized, diff --git a/tokenizers/src/tokenizer/normalizer.rs b/tokenizers/src/tokenizer/normalizer.rs index 5bebd5f7b4..fab640ae36 100644 --- a/tokenizers/src/tokenizer/normalizer.rs +++ b/tokenizers/src/tokenizer/normalizer.rs @@ -136,6 +136,18 @@ impl NormalizedString { &self.normalized } + /// Replace the normalized content without tracking alignments. + /// + /// This is significantly cheaper than going through `transform()` since it + /// skips the per-byte alignment bookkeeping. Use this when offset tracking + /// is not needed (e.g. `encode_fast`). + pub fn set_normalized(&mut self, new: String) { + // Build trivial 1:1 alignments so that slice() still works for + // splitting, but no real offset mapping is preserved. + self.alignments = new.as_bytes().iter().enumerate().map(|(i, _)| (i, i + 1)).collect(); + self.normalized = new; + } + /// Return the original string pub fn get_original(&self) -> &str { &self.original