diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs
index 2f560b7e3f..32a3981414 100644
--- a/tokenizers/src/models/bpe/model.rs
+++ b/tokenizers/src/models/bpe/model.rs
@@ -1,6 +1,7 @@
-use super::{super::OrderedVocabIter, trainer::BpeTrainer, Error, Pair, Word};
+use super::{trainer::BpeTrainer, Error, Pair, Word};
 use crate::tokenizer::{Model, Result, Token};
 use crate::utils::cache::{Cache, DEFAULT_CACHE_CAPACITY, MAX_LENGTH};
+use crate::utils::compact_vocab::CompactVocab;
 use crate::utils::iter::ResultShunt;
 use ahash::AHashMap;
 use serde_json::Value;
@@ -15,7 +16,6 @@ use std::{
 };
 
 pub type Vocab = AHashMap<String, u32>;
-type VocabR = AHashMap<u32, String>;
 pub type MergeMap = AHashMap<Pair, (u32, u32)>;
 pub type Merges = Vec<(String, String)>;
 
@@ -154,12 +154,9 @@ impl BpeBuilder {
             self.config.merges = m;
         }
 
-        let vocab_r = self
-            .config
-            .vocab
-            .iter()
-            .map(|(key, val)| (*val, key.to_owned()))
-            .collect();
+        let vocab_r = CompactVocab::from_vocab(
+            self.config.vocab.iter().map(|(key, val)| (key.as_str(), *val)),
+        );
         let cache = match self.config.cache_capacity {
             0 => None,
             capacity => Some(Cache::new(capacity)),
@@ -214,8 +211,8 @@
 pub struct BPE {
     /// The vocabulary assigns a number to each token.
     pub(crate) vocab: Vocab,
-    /// Reversed vocabulary, to rebuild sentences.
-    pub(crate) vocab_r: VocabR,
+    /// Compact id→token store: all strings in one allocation, O(1) indexed lookup.
+    pub(crate) vocab_r: CompactVocab,
     /// Contains the mapping between Pairs and their (rank, new_id).
     pub(crate) merges: MergeMap,
     /// Contains the cache for optimizing the encoding step.
@@ -469,7 +466,10 @@ impl BPE {
     fn word_to_tokens<'a>(&'a self, word: &'a Word) -> impl Iterator<Item = Token> + 'a {
         word.get_chars_iter()
             .zip(word.get_offsets_iter())
-            .map(move |(id, offsets)| Token::new(id, self.vocab_r[&id].clone(), offsets))
+            .map(move |(id, offsets)| {
+                let s = self.vocab_r.get(id).expect("id missing from vocab_r").to_owned();
+                Token::new(id, s, offsets)
+            })
     }
 
     fn tokenize_with_cache(&self, sequence: &str) -> Result<Vec<Token>> {
@@ -525,7 +525,7 @@ impl Model for BPE {
     }
 
     fn id_to_token(&self, id: u32) -> Option<String> {
-        self.vocab_r.get(&id).cloned()
+        self.vocab_r.get(id).map(str::to_owned)
     }
 
     fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>> {
@@ -539,8 +539,8 @@
             .iter()
             .collect();
         let mut vocab_file = File::create(&vocab_path)?;
-        let order_vocab_iter = OrderedVocabIter::new(&self.vocab_r);
-        let serialized = serde_json::to_string(&order_vocab_iter)?;
+        // CompactVocab serializes as the same {"token": id} JSON object.
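+        // For example `{"hello":0,"world":1}`, the same shape OrderedVocabIter
+        // produced, so previously saved vocab files load unchanged.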
+        let serialized = serde_json::to_string(&self.vocab_r)?;
         vocab_file.write_all(serialized.as_bytes())?;
 
         // Write merges.txt
@@ -564,7 +564,9 @@
             &merges
                 .into_iter()
                 .flat_map(|(pair, _)| {
-                    format!("{} {}\n", self.vocab_r[&pair.0], self.vocab_r[&pair.1]).into_bytes()
+                    let a = self.vocab_r.get(pair.0).expect("merge token missing from vocab");
+                    let b = self.vocab_r.get(pair.1).expect("merge token missing from vocab");
+                    format!("{a} {b}\n").into_bytes()
                 })
                 .collect::<Vec<u8>>()[..],
         )?;
@@ -583,18 +585,14 @@ mod tests {
     use tempfile::NamedTempFile;
 
     #[test]
-    fn test_ordered_vocab_iter() {
-        let vocab_r: VocabR = [
-            (0, "a".into()),
-            (1, "b".into()),
-            (2, "c".into()),
-            (3, "ab".into()),
-        ]
-        .iter()
-        .cloned()
-        .collect();
-        let order_vocab_iter = OrderedVocabIter::new(&vocab_r);
-        let serialized = serde_json::to_string(&order_vocab_iter).unwrap();
+    fn test_compact_vocab_serialization() {
+        let vocab_r = CompactVocab::from_vocab([
+            ("a", 0u32),
+            ("b", 1),
+            ("c", 2),
+            ("ab", 3),
+        ]);
+        let serialized = serde_json::to_string(&vocab_r).unwrap();
         assert_eq!(serialized, "{\"a\":0,\"b\":1,\"c\":2,\"ab\":3}");
     }
diff --git a/tokenizers/src/models/bpe/serialization.rs b/tokenizers/src/models/bpe/serialization.rs
index 98cf549445..ee811d4d72 100644
--- a/tokenizers/src/models/bpe/serialization.rs
+++ b/tokenizers/src/models/bpe/serialization.rs
@@ -1,4 +1,4 @@
-use super::{super::OrderedVocabIter, convert_merges_to_hashmap, BpeBuilder, Pair, BPE};
+use super::{convert_merges_to_hashmap, BpeBuilder, Pair, BPE};
 use ahash::AHashMap;
 use serde::{
     de::{Error, MapAccess, Visitor},
@@ -32,11 +32,15 @@ impl Serialize for BPE {
         merges.sort_unstable_by_key(|k| *k.1);
         let merges = merges
             .into_iter()
-            .map(|(pair, _)| (self.vocab_r[&pair.0].clone(), self.vocab_r[&pair.1].clone()))
+            .map(|(pair, _)| {
+                let a = self.vocab_r.get(pair.0).expect("merge token missing").to_owned();
+                let b = self.vocab_r.get(pair.1).expect("merge token missing").to_owned();
+                (a, b)
+            })
             .collect::<Vec<_>>();
 
-        let ordered_vocab = OrderedVocabIter::new(&self.vocab_r);
-        model.serialize_field("vocab", &ordered_vocab)?;
+        // CompactVocab serializes as {"token": id} in ascending id order.
+        model.serialize_field("vocab", &self.vocab_r)?;
         model.serialize_field("merges", &merges)?;
 
         model.end()
diff --git a/tokenizers/src/models/bpe/trainer.rs b/tokenizers/src/models/bpe/trainer.rs
index df68c655e9..9793071120 100644
--- a/tokenizers/src/models/bpe/trainer.rs
+++ b/tokenizers/src/models/bpe/trainer.rs
@@ -1,6 +1,7 @@
 #![allow(clippy::map_entry)]
 
 use super::{Pair, WithFirstLastIterator, Word, BPE};
+use crate::utils::compact_vocab::CompactVocab;
 use crate::parallelism::*;
 use crate::tokenizer::{AddedToken, Result, Trainer};
 use crate::utils::progress::{ProgressBar, ProgressFormat, ProgressStyle};
@@ -612,11 +613,9 @@
             // we have to look up the string in id_to_word because the key in word_to_id is a hash
             .map(|(_key, val)| (id_to_word[val as usize].to_string(), val))
             .collect();
-        model.vocab_r = model
-            .vocab
-            .iter()
-            .map(|(key, val)| (*val, key.to_owned()))
-            .collect();
+        model.vocab_r = CompactVocab::from_vocab(
+            model.vocab.iter().map(|(key, val)| (key.as_str(), *val)),
+        );
         model.merges = merges
             .into_iter()
             .enumerate()
diff --git a/tokenizers/src/utils/compact_vocab.rs b/tokenizers/src/utils/compact_vocab.rs
new file mode 100644
index 0000000000..2538913335
--- /dev/null
+++ b/tokenizers/src/utils/compact_vocab.rs
@@ -0,0 +1,225 @@
+/// A compact, cache-friendly id→token store.
+///
+/// All token strings are concatenated into a single contiguous byte buffer and
+/// indexed by a dense array of `u32` byte offsets. Reverse lookup (id → token)
+/// is two array reads — no hash-table indirection, no per-string heap
+/// allocation — and sequential id scans (serialization, iteration) stay in the
+/// same cache lines.
+///
+/// # Layout
+///
+/// ```text
+/// data:    [ h e l l o w o r l d ]
+/// offsets: [ 0, 5, 10 ]            (len = vocab_size + 1)
+/// id=0 → data[0..5]  = "hello"
+/// id=1 → data[5..10] = "world"
+/// ```
+///
+/// # Dense ids
+///
+/// Ids must form a range `0..N`, but **gaps are allowed**: a missing id is
+/// represented by an empty slice (`offsets[i] == offsets[i+1]`) and returns
+/// `None` from [`get`](Self::get). An empty-string token and a gap are
+/// therefore indistinguishable — avoid inserting empty tokens.
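+///
+/// # Example
+///
+/// A minimal usage sketch (the import path assumes the `pub use` re-export
+/// added in `utils/mod.rs` below):
+///
+/// ```
+/// use tokenizers::utils::CompactVocab;
+///
+/// // Pairs may arrive in any order, and ids may contain gaps.
+/// let v = CompactVocab::from_vocab([("world", 2u32), ("hello", 0)]);
+/// assert_eq!(v.get(0), Some("hello"));
+/// assert_eq!(v.get(1), None); // gap
+/// assert_eq!(v.get(2), Some("world"));
+/// assert_eq!(v.len(), 3); // slots 0, 1, 2
+/// ```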
+#[derive(Clone, Default, Debug)]
+pub struct CompactVocab {
+    /// Concatenated UTF-8 bytes of every token, in ascending id order.
+    data: Vec<u8>,
+    /// `offsets[i]` = byte start of token `i`; `offsets[i+1]` = exclusive end.
+    /// `offsets.len() == max_id + 2` (or 0 for an empty vocab).
+    offsets: Vec<u32>,
+}
+
+impl CompactVocab {
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Build from an iterator of `(token, id)` pairs.
+    ///
+    /// The pairs can arrive in any order and may contain gaps, but ids must be
+    /// unique: a duplicate id would desynchronize the offset table. The input
+    /// iterator is consumed in a single pass and the byte buffer is allocated
+    /// exactly once.
+    pub fn from_vocab(iter: impl IntoIterator<Item = (impl AsRef<str>, u32)>) -> Self {
+        let mut sorted: Vec<(u32, String)> = iter
+            .into_iter()
+            .map(|(s, id)| (id, s.as_ref().to_owned()))
+            .collect();
+
+        if sorted.is_empty() {
+            return Self::default();
+        }
+        sorted.sort_unstable_by_key(|(id, _)| *id);
+
+        let max_id = sorted.last().unwrap().0 as usize;
+        let n = max_id + 1; // number of slots (including gaps)
+
+        // Pre-calculate total data size to avoid reallocations.
+        let total_bytes: usize = sorted.iter().map(|(_, s)| s.len()).sum();
+        let mut data = Vec::with_capacity(total_bytes);
+        let mut offsets = Vec::with_capacity(n + 1);
+
+        let mut sorted_iter = sorted.into_iter().peekable();
+
+        for i in 0..n {
+            offsets.push(data.len() as u32);
+            if sorted_iter.peek().map(|(id, _)| *id as usize) == Some(i) {
+                let (_, token) = sorted_iter.next().unwrap();
+                data.extend_from_slice(token.as_bytes());
+            }
+            // gap → no bytes written → offsets[i] == offsets[i+1] (filled next iteration)
+        }
+        offsets.push(data.len() as u32); // sentinel
+
+        Self { data, offsets }
+    }
+
+    /// Return the token string for `id`, or `None` for an unknown / gap id.
+    ///
+    /// This is two array reads — no hash lookup.
+    #[inline]
+    pub fn get(&self, id: u32) -> Option<&str> {
+        let i = id as usize;
+        let (&start, &end) = (self.offsets.get(i)?, self.offsets.get(i + 1)?);
+        if start == end {
+            return None; // gap — id was never inserted
+        }
+        // SAFETY: `data` only ever receives bytes from valid `String` / `&str` values.
+        Some(unsafe { std::str::from_utf8_unchecked(&self.data[start as usize..end as usize]) })
+    }
+
+    /// Number of id slots (including gaps); equals `max_id + 1`, or 0 when empty.
+    pub fn len(&self) -> usize {
+        self.offsets.len().saturating_sub(1)
+    }
+
+    pub fn is_empty(&self) -> bool {
+        self.offsets.len() <= 1
+    }
+
+    /// Iterate `(token, id)` pairs in ascending id order, skipping gaps.
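+    ///
+    /// A short sketch of the gap-skipping behavior (import path as in the
+    /// struct-level example above):
+    ///
+    /// ```
+    /// use tokenizers::utils::CompactVocab;
+    ///
+    /// let v = CompactVocab::from_vocab([("a", 0u32), ("c", 2)]);
+    /// let pairs: Vec<(&str, u32)> = v.iter().collect();
+    /// assert_eq!(pairs, vec![("a", 0), ("c", 2)]); // id 1 (a gap) is skipped
+    /// ```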
+    pub fn iter(&self) -> impl Iterator<Item = (&str, u32)> + '_ {
+        let data = &self.data;
+        let offsets = &self.offsets;
+        let n = self.len() as u32;
+        (0..n).filter_map(move |id| {
+            let i = id as usize;
+            let (&start, &end) = (offsets.get(i)?, offsets.get(i + 1)?);
+            if start == end {
+                return None;
+            }
+            // SAFETY: `data` only receives bytes from valid `String` / `&str` values.
+            let s = unsafe {
+                std::str::from_utf8_unchecked(&data[start as usize..end as usize])
+            };
+            Some((s, id))
+        })
+    }
+}
+
+impl PartialEq for CompactVocab {
+    fn eq(&self, other: &Self) -> bool {
+        self.data == other.data && self.offsets == other.offsets
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Serde — JSON object {"token": id} in ascending id order, same format as
+// OrderedVocabIter so existing tokenizer files remain compatible.
+// ---------------------------------------------------------------------------
+
+use serde::{
+    de::Deserializer,
+    ser::{SerializeMap, Serializer},
+    Deserialize, Serialize,
+};
+
+impl Serialize for CompactVocab {
+    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
+        let mut map = serializer.serialize_map(Some(self.len()))?;
+        for (token, id) in self.iter() {
+            map.serialize_entry(token, &id)?;
+        }
+        map.end()
+    }
+}
+
+impl<'de> Deserialize<'de> for CompactVocab {
+    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
+        // Deserialize as the same {"token": id} map that the BPE JSON uses.
+        let raw: std::collections::HashMap<String, u32> =
+            Deserialize::deserialize(deserializer)?;
+        Ok(Self::from_vocab(raw))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn round_trip_dense() {
+        let pairs = vec![
+            ("hello".to_string(), 0u32),
+            ("world".to_string(), 1),
+            ("foo".to_string(), 2),
+        ];
+        let cv = CompactVocab::from_vocab(pairs);
+
+        assert_eq!(cv.get(0), Some("hello"));
+        assert_eq!(cv.get(1), Some("world"));
+        assert_eq!(cv.get(2), Some("foo"));
+        assert_eq!(cv.get(3), None);
+        assert_eq!(cv.len(), 3);
+    }
+
+    #[test]
+    fn round_trip_with_gaps() {
+        let pairs = vec![("a".to_string(), 0u32), ("c".to_string(), 2)];
+        let cv = CompactVocab::from_vocab(pairs);
+
+        assert_eq!(cv.get(0), Some("a"));
+        assert_eq!(cv.get(1), None); // gap
+        assert_eq!(cv.get(2), Some("c"));
+        assert_eq!(cv.len(), 3); // slots 0, 1, 2
+    }
+
+    #[test]
+    fn iter_skips_gaps() {
+        let pairs = vec![("a".to_string(), 0u32), ("c".to_string(), 2)];
+        let cv = CompactVocab::from_vocab(pairs);
+        let collected: Vec<(&str, u32)> = cv.iter().collect();
+        assert_eq!(collected, vec![("a", 0), ("c", 2)]);
+    }
+
+    #[test]
+    fn serde_round_trip() {
+        let pairs = vec![
+            ("hello".to_string(), 0u32),
+            ("world".to_string(), 1),
+        ];
+        let cv = CompactVocab::from_vocab(pairs);
+        let json = serde_json::to_string(&cv).unwrap();
+        assert_eq!(json, r#"{"hello":0,"world":1}"#);
+        let cv2: CompactVocab = serde_json::from_str(&json).unwrap();
+        assert_eq!(cv, cv2);
+    }
+
+    #[test]
+    fn unordered_input() {
+        // Input in reverse id order — should still reconstruct correctly.
+        let pairs = vec![("c".to_string(), 2u32), ("a".to_string(), 0), ("b".to_string(), 1)];
+        let cv = CompactVocab::from_vocab(pairs);
+        assert_eq!(cv.get(0), Some("a"));
+        assert_eq!(cv.get(1), Some("b"));
+        assert_eq!(cv.get(2), Some("c"));
+    }
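+
+    // Illustrates the caveat documented under `# Dense ids`: an empty token
+    // writes no bytes, so it reads back as `None`, exactly like a gap.
+    #[test]
+    fn empty_token_reads_as_gap() {
+        let cv = CompactVocab::from_vocab([("", 0u32), ("b", 1)]);
+        assert_eq!(cv.get(0), None); // empty token is indistinguishable from a gap
+        assert_eq!(cv.get(1), Some("b"));
+    }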
+
+    #[test]
+    fn single_contiguous_allocation() {
+        let pairs = vec![("ab".to_string(), 0u32), ("cd".to_string(), 1)];
+        let cv = CompactVocab::from_vocab(pairs);
+        // All 4 bytes live in one contiguous `Vec<u8>`.
+        assert_eq!(cv.data.len(), 4);
+        assert_eq!(&cv.data, b"abcd");
+    }
+}
diff --git a/tokenizers/src/utils/mod.rs b/tokenizers/src/utils/mod.rs
index c9450b3222..ef4a2c4370 100644
--- a/tokenizers/src/utils/mod.rs
+++ b/tokenizers/src/utils/mod.rs
@@ -1,4 +1,6 @@
 pub(crate) mod cache;
+pub mod compact_vocab;
+pub use compact_vocab::CompactVocab;
 #[cfg(feature = "http")]
 pub(crate) mod from_pretrained;