diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs index 6cd243298b..1502b70b04 100644 --- a/tokenizers/src/models/bpe/model.rs +++ b/tokenizers/src/models/bpe/model.rs @@ -2,6 +2,7 @@ use super::{super::OrderedVocabIter, trainer::BpeTrainer, Error, Pair, Word}; use crate::tokenizer::{Model, Result, Token}; use crate::utils::cache::{DEFAULT_CACHE_CAPACITY, MAX_LENGTH}; use crate::utils::iter::ResultShunt; +use crate::utils::merge_table::MergeTable; use ahash::AHashMap; use serde_json::Value; use std::borrow::Cow; @@ -19,7 +20,7 @@ use std::{ pub type Vocab = AHashMap; type VocabR = AHashMap; -pub type MergeMap = AHashMap; +pub type MergeMap = MergeTable; /// Process-wide monotonic counter used to assign a unique generation id /// to every `BpeCache`, so per-instance thread-local caches never collide. @@ -248,8 +249,7 @@ impl BpeBuilder { } else { 0 }; - let mut buffer: Vec = vec![0; max_len]; - let merge_map: MergeMap = self + let merge_pairs = self .config .merges .into_iter() @@ -272,7 +272,8 @@ impl BpeBuilder { .ok_or_else(|| Error::MergeTokenOutOfVocabulary(new_token.to_owned()))?; Ok(((*a_id, *b_id), (i as u32, *new_id))) }) - .collect::>()?; + .collect::>>()?; + let merge_map = MergeTable::from_iter(merge_pairs); // merges.insert(pair, (rank as u32, *new_id)); @@ -644,12 +645,12 @@ impl Model for BPE { .iter() .collect(); let mut merges_file = File::create(&merges_path)?; - let mut merges: Vec<(&Pair, &u32)> = self + let mut merges: Vec<(Pair, u32)> = self .merges .iter() - .map(|(pair, (rank, _))| (pair, rank)) + .map(|(pair, rank, _)| (pair, rank)) .collect(); - merges.sort_unstable_by_key(|k| *k.1); + merges.sort_unstable_by_key(|k| k.1); merges_file.write_all(b"#version: 0.2\n")?; merges_file.write_all( &merges @@ -916,7 +917,7 @@ mod tests { let bpe = builder.build().unwrap(); // Check merges. - assert_eq!(bpe.merges.get(&(0, 1)).unwrap(), &(0u32, 3u32)); + assert_eq!(bpe.merges.get((0, 1)).unwrap(), (0u32, 3u32)); // Check vocab. assert_eq!(bpe.vocab.get("a").unwrap(), &0u32); diff --git a/tokenizers/src/models/bpe/serialization.rs b/tokenizers/src/models/bpe/serialization.rs index 98cf549445..899bd97ac2 100644 --- a/tokenizers/src/models/bpe/serialization.rs +++ b/tokenizers/src/models/bpe/serialization.rs @@ -24,12 +24,12 @@ impl Serialize for BPE { model.serialize_field("ignore_merges", &self.ignore_merges)?; // Then the large ones - let mut merges: Vec<(&Pair, &u32)> = self + let mut merges: Vec<(Pair, u32)> = self .merges .iter() - .map(|(pair, (rank, _))| (pair, rank)) + .map(|(pair, rank, _)| (pair, rank)) .collect(); - merges.sort_unstable_by_key(|k| *k.1); + merges.sort_unstable_by_key(|k| k.1); let merges = merges .into_iter() .map(|(pair, _)| (self.vocab_r[&pair.0].clone(), self.vocab_r[&pair.1].clone())) diff --git a/tokenizers/src/models/bpe/trainer.rs b/tokenizers/src/models/bpe/trainer.rs index df68c655e9..a3bd36bf3c 100644 --- a/tokenizers/src/models/bpe/trainer.rs +++ b/tokenizers/src/models/bpe/trainer.rs @@ -1,6 +1,7 @@ #![allow(clippy::map_entry)] use super::{Pair, WithFirstLastIterator, Word, BPE}; +use crate::utils::merge_table::MergeTable; use crate::parallelism::*; use crate::tokenizer::{AddedToken, Result, Trainer}; use crate::utils::progress::{ProgressBar, ProgressFormat, ProgressStyle}; @@ -617,11 +618,12 @@ impl BpeTrainer { .iter() .map(|(key, val)| (*val, key.to_owned())) .collect(); - model.merges = merges - .into_iter() - .enumerate() - .map(|(i, (pair, new_token_id))| (pair, (i as u32, new_token_id))) - .collect(); + model.merges = MergeTable::from_iter( + merges + .into_iter() + .enumerate() + .map(|(i, (pair, new_token_id))| (pair, (i as u32, new_token_id))), + ); model.continuing_subword_prefix = self.continuing_subword_prefix.clone(); model.end_of_word_suffix = self.end_of_word_suffix.clone(); @@ -677,8 +679,9 @@ impl Trainer for BpeTrainer { #[cfg(test)] mod tests { - use super::{BpeTrainer, Pair, BPE}; + use super::{BpeTrainer, BPE}; use ahash::AHashMap; + use crate::utils::merge_table::MergeTable; use compact_str::CompactString; #[test] @@ -744,14 +747,11 @@ mod tests { // where 'rank' determines the order in which this merge will be applied during // tokenization, and 'id' is the vocab id of the symbol resulting from merging // the pair of symbols in the corresponding key. - let expected_merges: AHashMap = [ + let expected_merges = MergeTable::from_iter([ ((17, 11), (0, 22)), // 'r' + 'e' -> 're' ((8, 22), (1, 23)), // 'a' + 're' -> 'are' ((13, 18), (2, 24)), // 'i' + 's' -> 'is' - ] - .iter() - .cloned() - .collect(); + ]); assert_eq!(model.merges, expected_merges); } #[test] diff --git a/tokenizers/src/models/bpe/word.rs b/tokenizers/src/models/bpe/word.rs index 7bf2dee566..15911fd943 100644 --- a/tokenizers/src/models/bpe/word.rs +++ b/tokenizers/src/models/bpe/word.rs @@ -1,5 +1,5 @@ use super::Pair; -use ahash::AHashMap; +use crate::utils::merge_table::MergeTable; use dary_heap::QuaternaryHeap; use rand::{rng, Rng}; use std::cmp::Ordering; @@ -159,7 +159,7 @@ impl Word { changes } - pub(super) fn merge_all(&mut self, merges: &AHashMap, dropout: Option) { + pub(super) fn merge_all(&mut self, merges: &MergeTable, dropout: Option) { let mut queue = QuaternaryHeap::with_capacity(self.symbols.len()); let mut skip = Vec::with_capacity(queue.len()); @@ -169,7 +169,7 @@ impl Word { .enumerate() .filter_map(|(index, window)| { let pair = (window[0].c, window[1].c); - merges.get(&pair).map(|m| Merge { + merges.get(pair).map(|m| Merge { pos: index, rank: m.0, new_id: m.1, @@ -198,8 +198,8 @@ impl Word { // Make sure we are not processing an expired queue entry let target_new_pair = (self.symbols[top.pos].c, right.c); if merges - .get(&target_new_pair) - .is_none_or(|(_, new_id)| *new_id != top.new_id) + .get(target_new_pair) + .is_none_or(|(_, new_id)| new_id != top.new_id) { continue; } @@ -220,11 +220,11 @@ impl Word { let prev = current.prev as usize; let prev_symbol = self.symbols[prev]; let new_pair = (prev_symbol.c, current.c); - if let Some((rank, new_id)) = merges.get(&new_pair) { + if let Some((rank, new_id)) = merges.get(new_pair) { queue.push(Merge { pos: current.prev as usize, - rank: *rank, - new_id: *new_id, + rank, + new_id, }); } } @@ -234,11 +234,11 @@ impl Word { if next < self.symbols.len() { let next_symbol = self.symbols[next]; let new_pair = (current.c, next_symbol.c); - if let Some((rank, new_id)) = merges.get(&new_pair) { + if let Some((rank, new_id)) = merges.get(new_pair) { queue.push(Merge { pos: top.pos, - rank: *rank, - new_id: *new_id, + rank, + new_id, }); } } diff --git a/tokenizers/src/utils/merge_table.rs b/tokenizers/src/utils/merge_table.rs new file mode 100644 index 0000000000..64fbf5fc76 --- /dev/null +++ b/tokenizers/src/utils/merge_table.rs @@ -0,0 +1,274 @@ +/// Open-addressing hash table tuned for BPE merge lookups. +/// +/// ## Layout +/// +/// Each slot holds a `(left: u32, right: u32, rank: u32, new_id: u32)` quad — +/// **16 bytes**. On a 64-byte cache line this packs exactly **4 slots**, so a +/// linear-probe sequence of ≤ 4 steps never leaves the cache line that was +/// fetched on the first miss. +/// +/// ## Why this beats a general-purpose hash map here +/// +/// * **Read-only after construction** — no deletion, no tombstones, no +/// re-hashing. The probe loop is two comparisons per slot and no branch on +/// a "tombstone" state. +/// * **Dense keys** — BPE pair IDs are small integers; the splitmix64 +/// finalizer distributes them uniformly without the overhead of a complex +/// hash function. +/// * **~60 % load factor** — keeps average probe length under 1.5, so most +/// lookups hit in the first cache line. +/// +/// ## Limitations +/// +/// Pair components must be < `u32::MAX` (valid for any model with < 4 billion +/// tokens). `u32::MAX` is used as the empty-slot sentinel. +use serde::{Deserialize, Serialize}; + +type Pair = (u32, u32); + +const EMPTY: u32 = u32::MAX; + +/// One slot in the table. 16 bytes → 4 per cache line. +#[derive(Clone, Copy)] +#[repr(C)] +struct Slot { + left: u32, + right: u32, + rank: u32, + new_id: u32, +} + +impl Slot { + #[inline] + fn is_empty(self) -> bool { + self.left == EMPTY + } +} + +/// Open-addressing hash table from `Pair → (rank, new_id)`. +#[derive(Clone)] +pub struct MergeTable { + slots: Box<[Slot]>, + /// `slots.len() - 1`; always a power-of-two mask. + mask: usize, +} + +impl std::fmt::Debug for MergeTable { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MergeTable") + .field("len", &self.len()) + .field("capacity", &self.slots.len()) + .finish() + } +} + +// --------------------------------------------------------------------------- +// Hash +// --------------------------------------------------------------------------- + +/// Splitmix64 finalizer applied to the 64-bit encoding of a pair. +/// Excellent avalanche even for small (e.g. BPE vocab-sized) integers. +#[inline(always)] +fn hash_pair(a: u32, b: u32) -> usize { + let h = (a as u64) | ((b as u64) << 32); + let h = (h ^ (h >> 30)).wrapping_mul(0xbf58476d1ce4e5b9); + let h = (h ^ (h >> 27)).wrapping_mul(0x94d049bb133111eb); + (h ^ (h >> 31)) as usize +} + +// --------------------------------------------------------------------------- +// Construction +// --------------------------------------------------------------------------- + +fn capacity_for(count: usize) -> usize { + // Target ~60 % load factor. Minimum 8 slots. + let min = (count * 8 / 5).max(8); + min.next_power_of_two() +} + +impl MergeTable { + /// Build from an iterator of `(pair, (rank, new_id))` entries. + pub fn from_iter(iter: impl IntoIterator) -> Self { + let entries: Vec<(Pair, (u32, u32))> = iter.into_iter().collect(); + let cap = capacity_for(entries.len()); + let mask = cap - 1; + + let mut slots = vec![ + Slot { + left: EMPTY, + right: EMPTY, + rank: 0, + new_id: 0, + }; + cap + ] + .into_boxed_slice(); + + for ((left, right), (rank, new_id)) in entries { + debug_assert!(left != EMPTY, "pair component must not equal u32::MAX"); + let mut i = hash_pair(left, right) & mask; + loop { + if slots[i].is_empty() { + slots[i] = Slot { + left, + right, + rank, + new_id, + }; + break; + } + i = (i + 1) & mask; + } + } + + Self { slots, mask } + } + + /// Look up a merge. Returns `Some((rank, new_id))` or `None`. + /// + /// Hot path: two u32 comparisons per slot, probe sequence stays within + /// the same 64-byte cache line for the first 4 steps. + #[inline] + pub fn get(&self, (left, right): Pair) -> Option<(u32, u32)> { + let mut i = hash_pair(left, right) & self.mask; + loop { + // SAFETY: i is always within bounds because i = (anything & mask) + // and mask = slots.len() - 1. + let s = unsafe { *self.slots.get_unchecked(i) }; + if s.left == left && s.right == right { + return Some((s.rank, s.new_id)); + } + if s.is_empty() { + return None; + } + i = (i + 1) & self.mask; + } + } + + /// Number of occupied entries. + pub fn len(&self) -> usize { + self.slots.iter().filter(|s| !s.is_empty()).count() + } + + pub fn is_empty(&self) -> bool { + self.slots.iter().all(|s| s.is_empty()) + } + + /// Iterate over all `(pair, rank, new_id)` entries in unspecified order. + pub fn iter(&self) -> impl Iterator + '_ { + self.slots.iter().filter(|s| !s.is_empty()).map(|s| { + ((s.left, s.right), s.rank, s.new_id) + }) + } +} + +// --------------------------------------------------------------------------- +// PartialEq — order-independent entry comparison +// --------------------------------------------------------------------------- + +impl PartialEq for MergeTable { + fn eq(&self, other: &Self) -> bool { + if self.len() != other.len() { + return false; + } + self.iter() + .all(|(pair, rank, new_id)| other.get(pair) == Some((rank, new_id))) + } +} + +// --------------------------------------------------------------------------- +// Serde — serialise as a Vec for JSON round-trips +// --------------------------------------------------------------------------- + +impl Serialize for MergeTable { + fn serialize(&self, serializer: S) -> Result { + // Collect and sort by rank so the on-disk order is deterministic. + let mut entries: Vec<(Pair, u32, u32)> = self.iter().collect(); + entries.sort_unstable_by_key(|&(_, rank, _)| rank); + entries.serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for MergeTable { + fn deserialize>(deserializer: D) -> Result { + let entries: Vec<(Pair, u32, u32)> = Deserialize::deserialize(deserializer)?; + Ok(Self::from_iter( + entries.into_iter().map(|(pair, rank, new_id)| (pair, (rank, new_id))), + )) + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn small_table() -> MergeTable { + MergeTable::from_iter([ + ((0, 1), (0, 4)), + ((4, 2), (1, 5)), + ((5, 3), (2, 6)), + ]) + } + + #[test] + fn basic_lookup() { + let t = small_table(); + assert_eq!(t.get((0, 1)), Some((0, 4))); + assert_eq!(t.get((4, 2)), Some((1, 5))); + assert_eq!(t.get((5, 3)), Some((2, 6))); + assert_eq!(t.get((9, 9)), None); + } + + #[test] + fn len_and_iter() { + let t = small_table(); + assert_eq!(t.len(), 3); + let mut entries: Vec<_> = t.iter().collect(); + entries.sort_by_key(|&(_, rank, _)| rank); + assert_eq!(entries, [((0, 1), 0, 4), ((4, 2), 1, 5), ((5, 3), 2, 6)]); + } + + #[test] + fn partial_eq() { + let a = small_table(); + let b = MergeTable::from_iter([ + ((5, 3), (2, 6)), + ((0, 1), (0, 4)), + ((4, 2), (1, 5)), + ]); + assert_eq!(a, b); + } + + #[test] + fn absent_pair_is_none() { + let t = small_table(); + // Pairs where left/right are present individually but not together + assert_eq!(t.get((0, 2)), None); + assert_eq!(t.get((1, 0)), None); + } + + #[test] + fn large_table_no_collisions() { + // Build a 10 000 entry table and verify every entry round-trips. + let entries: Vec<(Pair, (u32, u32))> = (0u32..10_000) + .map(|i| ((i, i + 1), (i, i + 10_000))) + .collect(); + let t = MergeTable::from_iter(entries.iter().copied()); + for (pair, (rank, new_id)) in &entries { + assert_eq!(t.get(*pair), Some((*rank, *new_id))); + } + assert_eq!(t.len(), 10_000); + } + + #[test] + fn capacity_is_power_of_two() { + // Sanity-check that the table never exceeds ~60 % load factor. + let t = MergeTable::from_iter((0u32..1_000).map(|i| ((i, i + 1), (i, 0)))); + assert!(t.slots.len().is_power_of_two()); + assert!((t.len() as f64 / t.slots.len() as f64) < 0.65); + } +} diff --git a/tokenizers/src/utils/mod.rs b/tokenizers/src/utils/mod.rs index c9450b3222..0d37455677 100644 --- a/tokenizers/src/utils/mod.rs +++ b/tokenizers/src/utils/mod.rs @@ -1,4 +1,6 @@ pub(crate) mod cache; +pub mod merge_table; +pub use merge_table::MergeTable; #[cfg(feature = "http")] pub(crate) mod from_pretrained;