52 changes: 25 additions & 27 deletions tokenizers/src/models/bpe/model.rs
@@ -1,6 +1,7 @@
use super::{super::OrderedVocabIter, trainer::BpeTrainer, Error, Pair, Word};
use super::{trainer::BpeTrainer, Error, Pair, Word};
use crate::tokenizer::{Model, Result, Token};
use crate::utils::cache::{Cache, DEFAULT_CACHE_CAPACITY, MAX_LENGTH};
use crate::utils::compact_vocab::CompactVocab;
use crate::utils::iter::ResultShunt;
use ahash::AHashMap;
use serde_json::Value;
@@ -15,7 +16,6 @@ use std::{
};

pub type Vocab = AHashMap<String, u32>;
type VocabR = AHashMap<u32, String>;
pub type MergeMap = AHashMap<Pair, (u32, u32)>;
pub type Merges = Vec<(String, String)>;

@@ -154,12 +154,9 @@ impl BpeBuilder {
self.config.merges = m;
}

let vocab_r = self
.config
.vocab
.iter()
.map(|(key, val)| (*val, key.to_owned()))
.collect();
let vocab_r = CompactVocab::from_vocab(
self.config.vocab.iter().map(|(key, val)| (key.as_str(), *val)),
);
let cache = match self.config.cache_capacity {
0 => None,
capacity => Some(Cache::new(capacity)),
@@ -214,8 +211,8 @@
pub struct BPE {
/// The vocabulary assigns a number to each token.
pub(crate) vocab: Vocab,
/// Reversed vocabulary, to rebuild sentences.
pub(crate) vocab_r: VocabR,
/// Compact id→token store: all strings in one allocation, O(1) indexed lookup.
pub(crate) vocab_r: CompactVocab,
/// Contains the mapping between Pairs and their (rank, new_id).
pub(crate) merges: MergeMap,
/// Contains the cache for optimizing the encoding step.
@@ -469,7 +466,10 @@ impl BPE {
fn word_to_tokens<'a>(&'a self, word: &'a Word) -> impl Iterator<Item = Token> + 'a {
word.get_chars_iter()
.zip(word.get_offsets_iter())
.map(move |(id, offsets)| Token::new(id, self.vocab_r[&id].clone(), offsets))
.map(move |(id, offsets)| {
let s = self.vocab_r.get(id).expect("id missing from vocab_r").to_owned();
Token::new(id, s, offsets)
})
}

fn tokenize_with_cache(&self, sequence: &str) -> Result<Vec<Token>> {
@@ -525,7 +525,7 @@ impl Model for BPE {
}

fn id_to_token(&self, id: u32) -> Option<String> {
self.vocab_r.get(&id).cloned()
self.vocab_r.get(id).map(str::to_owned)
}

fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>> {
@@ … @@
.iter()
.collect();
let mut vocab_file = File::create(&vocab_path)?;
let order_vocab_iter = OrderedVocabIter::new(&self.vocab_r);
let serialized = serde_json::to_string(&order_vocab_iter)?;
// CompactVocab serializes as the same {"token": id} JSON object.
let serialized = serde_json::to_string(&self.vocab_r)?;
vocab_file.write_all(serialized.as_bytes())?;

// Write merges.txt
@@ … @@
&merges
.into_iter()
.flat_map(|(pair, _)| {
format!("{} {}\n", self.vocab_r[&pair.0], self.vocab_r[&pair.1]).into_bytes()
let a = self.vocab_r.get(pair.0).expect("merge token missing from vocab");
let b = self.vocab_r.get(pair.1).expect("merge token missing from vocab");
format!("{a} {b}\n").into_bytes()
})
.collect::<Vec<_>>()[..],
)?;
@@ -583,18 +585,14 @@
use tempfile::NamedTempFile;

#[test]
fn test_ordered_vocab_iter() {
let vocab_r: VocabR = [
(0, "a".into()),
(1, "b".into()),
(2, "c".into()),
(3, "ab".into()),
]
.iter()
.cloned()
.collect();
let order_vocab_iter = OrderedVocabIter::new(&vocab_r);
let serialized = serde_json::to_string(&order_vocab_iter).unwrap();
fn test_compact_vocab_serialization() {
let vocab_r = CompactVocab::from_vocab([
("a", 0u32),
("b", 1),
("c", 2),
("ab", 3),
]);
let serialized = serde_json::to_string(&vocab_r).unwrap();
assert_eq!(serialized, "{\"a\":0,\"b\":1,\"c\":2,\"ab\":3}");
}

12 changes: 8 additions & 4 deletions tokenizers/src/models/bpe/serialization.rs
@@ -1,4 +1,4 @@
use super::{super::OrderedVocabIter, convert_merges_to_hashmap, BpeBuilder, Pair, BPE};
use super::{convert_merges_to_hashmap, BpeBuilder, Pair, BPE};
use ahash::AHashMap;
use serde::{
de::{Error, MapAccess, Visitor},
@@ -32,11 +32,15 @@ impl Serialize for BPE {
merges.sort_unstable_by_key(|k| *k.1);
let merges = merges
.into_iter()
.map(|(pair, _)| (self.vocab_r[&pair.0].clone(), self.vocab_r[&pair.1].clone()))
.map(|(pair, _)| {
let a = self.vocab_r.get(pair.0).expect("merge token missing").to_owned();
let b = self.vocab_r.get(pair.1).expect("merge token missing").to_owned();
(a, b)
})
.collect::<Vec<_>>();
let ordered_vocab = OrderedVocabIter::new(&self.vocab_r);

model.serialize_field("vocab", &ordered_vocab)?;
// CompactVocab serializes as {"token": id} in ascending id order.
model.serialize_field("vocab", &self.vocab_r)?;
model.serialize_field("merges", &merges)?;

model.end()
9 changes: 4 additions & 5 deletions tokenizers/src/models/bpe/trainer.rs
@@ -1,6 +1,7 @@
#![allow(clippy::map_entry)]

use super::{Pair, WithFirstLastIterator, Word, BPE};
use crate::parallelism::*;
use crate::tokenizer::{AddedToken, Result, Trainer};
use crate::utils::compact_vocab::CompactVocab;
use crate::utils::progress::{ProgressBar, ProgressFormat, ProgressStyle};
Expand Down Expand Up @@ -612,11 +613,9 @@ impl BpeTrainer {
// we have to look up the string in id_to_word because the key in word_to_id is a hash
.map(|(_key, val)| (id_to_word[val as usize].to_string(), val))
.collect();
model.vocab_r = model
.vocab
.iter()
.map(|(key, val)| (*val, key.to_owned()))
.collect();
model.vocab_r = CompactVocab::from_vocab(
model.vocab.iter().map(|(key, val)| (key.as_str(), *val)),
);
model.merges = merges
.into_iter()
.enumerate()
225 changes: 225 additions & 0 deletions tokenizers/src/utils/compact_vocab.rs
@@ -0,0 +1,225 @@
/// A compact, cache-friendly id→token store.
///
/// All token strings are concatenated into a single contiguous byte buffer and
/// indexed by a dense array of `u32` byte offsets. Reverse lookup (id → token)
/// is two array reads — no hash-table indirection, no per-string heap
/// allocation — and sequential id scans (serialization, iteration) stay in the
/// same cache lines.
///
/// # Layout
///
/// ```text
/// data: [ h e l l o w o r l d ]
/// offsets: [ 0, 5, 10 ] (len = vocab_size + 1)
/// id=0 → data[0..5] = "hello"
/// id=1 → data[5..10] = "world"
/// ```
///
/// # Dense ids
/// Ids must form a range `0..N` but **gaps are allowed**: a missing id is
/// represented by an empty slice (`offsets[i] == offsets[i+1]`) and returns
/// `None` from [`get`]. An empty-string token and a gap are therefore
/// indistinguishable — avoid inserting empty tokens.
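///
/// # Example
///
/// A minimal usage sketch; it assumes the `tokenizers::utils::CompactVocab`
/// re-export added in `utils/mod.rs` below.
///
/// ```
/// use tokenizers::utils::CompactVocab;
///
/// // Ids may arrive unordered; lookup is by dense id.
/// let cv = CompactVocab::from_vocab([("world", 1u32), ("hello", 0)]);
/// assert_eq!(cv.get(0), Some("hello"));
/// assert_eq!(cv.get(1), Some("world"));
/// assert_eq!(cv.get(2), None); // never inserted
/// assert_eq!(cv.len(), 2);
/// ```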
#[derive(Clone, Default, Debug)]
pub struct CompactVocab {
/// Concatenated UTF-8 bytes of every token, in ascending id order.
data: Vec<u8>,
/// `offsets[i]` = byte start of token `i`; `offsets[i+1]` = exclusive end.
/// `offsets.len() == max_id + 2` (or 0 for an empty vocab).
offsets: Vec<u32>,
}

impl CompactVocab {
pub fn new() -> Self {
Self::default()
}

/// Build from an iterator of `(token, id)` pairs.
///
/// The pairs can arrive in any order and may contain gaps. The input
/// iterator is consumed once, the pairs are sorted by id, and the final
/// byte buffer is allocated exactly once (its size is pre-computed).
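///
/// A gap sketch (it mirrors the `round_trip_with_gaps` test below):
///
/// ```
/// use tokenizers::utils::CompactVocab;
///
/// let cv = CompactVocab::from_vocab([("a", 0u32), ("c", 2)]);
/// assert_eq!(cv.get(1), None); // gap: id 1 was never inserted
/// assert_eq!(cv.len(), 3); // slots 0, 1, 2
/// ```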
pub fn from_vocab(iter: impl IntoIterator<Item = (impl AsRef<str>, u32)>) -> Self {
let mut sorted: Vec<(u32, String)> = iter
.into_iter()
.map(|(s, id)| (id, s.as_ref().to_owned()))
.collect();

if sorted.is_empty() {
return Self::default();
}
sorted.sort_unstable_by_key(|(id, _)| *id);

let max_id = sorted.last().unwrap().0 as usize;
let n = max_id + 1; // number of slots (including gaps)

// Pre-calculate total data size to avoid reallocations.
let total_bytes: usize = sorted.iter().map(|(_, s)| s.len()).sum();
let mut data = Vec::with_capacity(total_bytes);
let mut offsets = Vec::with_capacity(n + 1);

let mut sorted_iter = sorted.into_iter().peekable();

for i in 0..n {
offsets.push(data.len() as u32);
if sorted_iter.peek().map(|(id, _)| *id as usize) == Some(i) {
let (_, token) = sorted_iter.next().unwrap();
data.extend_from_slice(token.as_bytes());
// Defensive: drain any duplicate entries for this id so a repeated
// id cannot desynchronize the offset table (vocab ids are unique,
// so this loop normally never runs).
while sorted_iter.peek().map(|(id, _)| *id as usize) == Some(i) {
sorted_iter.next();
}
}
// gap → no bytes written → offsets[i] == offsets[i+1] (filled next iteration)
}
offsets.push(data.len() as u32); // sentinel

Self { data, offsets }
}

/// Return the token string for `id`, or `None` for an unknown / gap id.
///
/// This is two array reads — no hash lookup.
#[inline]
pub fn get(&self, id: u32) -> Option<&str> {
let i = id as usize;
let (&start, &end) = (self.offsets.get(i)?, self.offsets.get(i + 1)?);
if start == end {
return None; // gap — id was never inserted
}
// SAFETY: `data` only ever receives bytes from valid `String` / `&str` values.
Some(unsafe { std::str::from_utf8_unchecked(&self.data[start as usize..end as usize]) })
}

/// Number of id slots (including gaps); equals `max_id + 1`.
pub fn len(&self) -> usize {
self.offsets.len().saturating_sub(1)
}

pub fn is_empty(&self) -> bool {
self.offsets.len() <= 1
}

/// Iterate `(token, id)` pairs in ascending id order, skipping gaps.
pub fn iter(&self) -> impl Iterator<Item = (&str, u32)> + '_ {
let data = &self.data;
let offsets = &self.offsets;
let n = self.len() as u32;
(0..n).filter_map(move |id| {
let i = id as usize;
let (&start, &end) = (offsets.get(i)?, offsets.get(i + 1)?);
if start == end {
return None;
}
// SAFETY: data only receives bytes from valid String / &str values.
let s = unsafe {
std::str::from_utf8_unchecked(&data[start as usize..end as usize])
};
Some((s, id))
})
}
}

impl PartialEq for CompactVocab {
fn eq(&self, other: &Self) -> bool {
self.data == other.data && self.offsets == other.offsets
}
}

// ---------------------------------------------------------------------------
// Serde — JSON object {"token": id} in ascending id order, same format as
// OrderedVocabIter so existing tokenizer files remain compatible.
// ---------------------------------------------------------------------------

use serde::{
de::Deserializer,
ser::{SerializeMap, Serializer},
Deserialize, Serialize,
};

impl Serialize for CompactVocab {
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
// `len()` counts gap slots too; report the number of real entries so
// formats that trust the size hint (unlike serde_json) stay correct.
let mut map = serializer.serialize_map(Some(self.iter().count()))?;
for (token, id) in self.iter() {
map.serialize_entry(token, &id)?;
}
map.end()
}
}

impl<'de> Deserialize<'de> for CompactVocab {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
// Deserialize as the same {"token": id} map that the BPE JSON uses.
let raw: std::collections::HashMap<String, u32> =
Deserialize::deserialize(deserializer)?;
Ok(Self::from_vocab(raw))
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn round_trip_dense() {
let pairs = vec![
("hello".to_string(), 0u32),
("world".to_string(), 1),
("foo".to_string(), 2),
];
let cv = CompactVocab::from_vocab(pairs);

assert_eq!(cv.get(0), Some("hello"));
assert_eq!(cv.get(1), Some("world"));
assert_eq!(cv.get(2), Some("foo"));
assert_eq!(cv.get(3), None);
assert_eq!(cv.len(), 3);
}

#[test]
fn round_trip_with_gaps() {
let pairs = vec![("a".to_string(), 0u32), ("c".to_string(), 2)];
let cv = CompactVocab::from_vocab(pairs);

assert_eq!(cv.get(0), Some("a"));
assert_eq!(cv.get(1), None); // gap
assert_eq!(cv.get(2), Some("c"));
assert_eq!(cv.len(), 3); // slots 0, 1, 2
}

#[test]
fn iter_skips_gaps() {
let pairs = vec![("a".to_string(), 0u32), ("c".to_string(), 2)];
let cv = CompactVocab::from_vocab(pairs);
let collected: Vec<(&str, u32)> = cv.iter().collect();
assert_eq!(collected, vec![("a", 0), ("c", 2)]);
}

#[test]
fn serde_round_trip() {
let pairs = vec![
("hello".to_string(), 0u32),
("world".to_string(), 1),
];
let cv = CompactVocab::from_vocab(pairs);
let json = serde_json::to_string(&cv).unwrap();
assert_eq!(json, r#"{"hello":0,"world":1}"#);
let cv2: CompactVocab = serde_json::from_str(&json).unwrap();
assert_eq!(cv, cv2);
}

#[test]
fn unordered_input() {
// Input in reverse id order — should still reconstruct correctly.
let pairs = vec![("c".to_string(), 2u32), ("a".to_string(), 0), ("b".to_string(), 1)];
let cv = CompactVocab::from_vocab(pairs);
assert_eq!(cv.get(0), Some("a"));
assert_eq!(cv.get(1), Some("b"));
assert_eq!(cv.get(2), Some("c"));
}

#[test]
fn single_contiguous_allocation() {
let pairs = vec![("ab".to_string(), 0u32), ("cd".to_string(), 1)];
let cv = CompactVocab::from_vocab(pairs);
// All 4 bytes live in one Vec.
assert_eq!(cv.data.len(), 4);
assert_eq!(&cv.data, b"abcd");
}
}
2 changes: 2 additions & 0 deletions tokenizers/src/utils/mod.rs
@@ -1,4 +1,6 @@
pub(crate) mod cache;
pub mod compact_vocab;
pub use compact_vocab::CompactVocab;
#[cfg(feature = "http")]
pub(crate) mod from_pretrained;
