Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions tokenizers/src/models/bpe/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
use crate::tokenizer::{Model, Result, Token};
use crate::utils::cache::{DEFAULT_CACHE_CAPACITY, MAX_LENGTH};
use crate::utils::iter::ResultShunt;
use crate::utils::merge_table::MergeTable;
use ahash::AHashMap;
use serde_json::Value;
use std::borrow::Cow;
Expand All @@ -19,7 +20,7 @@

pub type Vocab = AHashMap<String, u32>;
type VocabR = AHashMap<u32, String>;
pub type MergeMap = AHashMap<Pair, (u32, u32)>;
pub type MergeMap = MergeTable;

/// Process-wide monotonic counter used to assign a unique generation id
/// to every `BpeCache`, so per-instance thread-local caches never collide.
Expand Down Expand Up @@ -248,8 +249,7 @@
} else {
0
};
let mut buffer: Vec<u8> = vec![0; max_len];
let merge_map: MergeMap = self
let merge_pairs = self
.config
.merges
.into_iter()
Expand All @@ -261,18 +261,19 @@
let b_id = vocab
.get(&b)
.ok_or_else(|| Error::MergeTokenOutOfVocabulary(b.to_owned()))?;
buffer[0..a.len()].copy_from_slice(a.as_bytes());

Check failure on line 264 in tokenizers/src/models/bpe/model.rs

View workflow job for this annotation

GitHub Actions / Check it builds for Windows 32-bit (3.14)

cannot find value `buffer` in this scope

Check failure on line 264 in tokenizers/src/models/bpe/model.rs

View workflow job for this annotation

GitHub Actions / Check it builds for Windows 32-bit (3.10)

cannot find value `buffer` in this scope

Check failure on line 264 in tokenizers/src/models/bpe/model.rs

View workflow job for this annotation

GitHub Actions / Check it builds for Windows 32-bit (3.13)

cannot find value `buffer` in this scope

Check failure on line 264 in tokenizers/src/models/bpe/model.rs

View workflow job for this annotation

GitHub Actions / Check everything builds & tests (ubuntu-latest, 3.14)

cannot find value `buffer` in this scope
let b_len = b.len() - prefix_len;
let merge_len = a.len() + b_len;
buffer[a.len()..merge_len].copy_from_slice(&b.as_bytes()[prefix_len..]);

Check failure on line 267 in tokenizers/src/models/bpe/model.rs

View workflow job for this annotation

GitHub Actions / Check it builds for Windows 32-bit (3.14)

cannot find value `buffer` in this scope

Check failure on line 267 in tokenizers/src/models/bpe/model.rs

View workflow job for this annotation

GitHub Actions / Check it builds for Windows 32-bit (3.10)

cannot find value `buffer` in this scope

Check failure on line 267 in tokenizers/src/models/bpe/model.rs

View workflow job for this annotation

GitHub Actions / Check it builds for Windows 32-bit (3.13)

cannot find value `buffer` in this scope

Check failure on line 267 in tokenizers/src/models/bpe/model.rs

View workflow job for this annotation

GitHub Actions / Check everything builds & tests (ubuntu-latest, 3.14)

cannot find value `buffer` in this scope
// SAFETY: buffer contains a concatenation of two valid UTF-8 strings, so it is itself valid UTF-8, even considering prefix_len
let new_token = unsafe { from_utf8_unchecked(&buffer[..merge_len]) };

Check failure on line 269 in tokenizers/src/models/bpe/model.rs

View workflow job for this annotation

GitHub Actions / Check it builds for Windows 32-bit (3.14)

cannot find value `buffer` in this scope

Check failure on line 269 in tokenizers/src/models/bpe/model.rs

View workflow job for this annotation

GitHub Actions / Check it builds for Windows 32-bit (3.10)

cannot find value `buffer` in this scope

Check failure on line 269 in tokenizers/src/models/bpe/model.rs

View workflow job for this annotation

GitHub Actions / Check it builds for Windows 32-bit (3.13)

cannot find value `buffer` in this scope

Check failure on line 269 in tokenizers/src/models/bpe/model.rs

View workflow job for this annotation

GitHub Actions / Check everything builds & tests (ubuntu-latest, 3.14)

cannot find value `buffer` in this scope
let new_id = vocab
.get(new_token)
.ok_or_else(|| Error::MergeTokenOutOfVocabulary(new_token.to_owned()))?;
Ok(((*a_id, *b_id), (i as u32, *new_id)))
})
.collect::<Result<MergeMap>>()?;
.collect::<Result<Vec<_>>>()?;
let merge_map = MergeTable::from_iter(merge_pairs);

// merges.insert(pair, (rank as u32, *new_id));

Expand Down Expand Up @@ -644,12 +645,12 @@
.iter()
.collect();
let mut merges_file = File::create(&merges_path)?;
let mut merges: Vec<(&Pair, &u32)> = self
let mut merges: Vec<(Pair, u32)> = self
.merges
.iter()
.map(|(pair, (rank, _))| (pair, rank))
.map(|(pair, rank, _)| (pair, rank))
.collect();
merges.sort_unstable_by_key(|k| *k.1);
merges.sort_unstable_by_key(|k| k.1);
merges_file.write_all(b"#version: 0.2\n")?;
merges_file.write_all(
&merges
Expand Down Expand Up @@ -916,7 +917,7 @@
let bpe = builder.build().unwrap();

// Check merges.
assert_eq!(bpe.merges.get(&(0, 1)).unwrap(), &(0u32, 3u32));
assert_eq!(bpe.merges.get((0, 1)).unwrap(), (0u32, 3u32));

// Check vocab.
assert_eq!(bpe.vocab.get("a").unwrap(), &0u32);
Expand Down
6 changes: 3 additions & 3 deletions tokenizers/src/models/bpe/serialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ impl Serialize for BPE {
model.serialize_field("ignore_merges", &self.ignore_merges)?;

// Then the large ones
let mut merges: Vec<(&Pair, &u32)> = self
let mut merges: Vec<(Pair, u32)> = self
.merges
.iter()
.map(|(pair, (rank, _))| (pair, rank))
.map(|(pair, rank, _)| (pair, rank))
.collect();
merges.sort_unstable_by_key(|k| *k.1);
merges.sort_unstable_by_key(|k| k.1);
let merges = merges
.into_iter()
.map(|(pair, _)| (self.vocab_r[&pair.0].clone(), self.vocab_r[&pair.1].clone()))
Expand Down
22 changes: 11 additions & 11 deletions tokenizers/src/models/bpe/trainer.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#![allow(clippy::map_entry)]

use super::{Pair, WithFirstLastIterator, Word, BPE};
use crate::utils::merge_table::MergeTable;
use crate::parallelism::*;
use crate::tokenizer::{AddedToken, Result, Trainer};
use crate::utils::progress::{ProgressBar, ProgressFormat, ProgressStyle};
Expand Down Expand Up @@ -617,11 +618,12 @@ impl BpeTrainer {
.iter()
.map(|(key, val)| (*val, key.to_owned()))
.collect();
model.merges = merges
.into_iter()
.enumerate()
.map(|(i, (pair, new_token_id))| (pair, (i as u32, new_token_id)))
.collect();
model.merges = MergeTable::from_iter(
merges
.into_iter()
.enumerate()
.map(|(i, (pair, new_token_id))| (pair, (i as u32, new_token_id))),
);

model.continuing_subword_prefix = self.continuing_subword_prefix.clone();
model.end_of_word_suffix = self.end_of_word_suffix.clone();
Expand Down Expand Up @@ -677,8 +679,9 @@ impl Trainer for BpeTrainer {

#[cfg(test)]
mod tests {
use super::{BpeTrainer, Pair, BPE};
use super::{BpeTrainer, BPE};
use ahash::AHashMap;
use crate::utils::merge_table::MergeTable;
use compact_str::CompactString;

#[test]
Expand Down Expand Up @@ -744,14 +747,11 @@ mod tests {
// where 'rank' determines the order in which this merge will be applied during
// tokenization, and 'id' is the vocab id of the symbol resulting from merging
// the pair of symbols in the corresponding key.
let expected_merges: AHashMap<Pair, (u32, u32)> = [
let expected_merges = MergeTable::from_iter([
((17, 11), (0, 22)), // 'r' + 'e' -> 're'
((8, 22), (1, 23)), // 'a' + 're' -> 'are'
((13, 18), (2, 24)), // 'i' + 's' -> 'is'
]
.iter()
.cloned()
.collect();
]);
assert_eq!(model.merges, expected_merges);
}
#[test]
Expand Down
22 changes: 11 additions & 11 deletions tokenizers/src/models/bpe/word.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::Pair;
use ahash::AHashMap;
use crate::utils::merge_table::MergeTable;
use dary_heap::QuaternaryHeap;
use rand::{rng, Rng};
use std::cmp::Ordering;
Expand Down Expand Up @@ -159,7 +159,7 @@ impl Word {
changes
}

pub(super) fn merge_all(&mut self, merges: &AHashMap<Pair, (u32, u32)>, dropout: Option<f32>) {
pub(super) fn merge_all(&mut self, merges: &MergeTable, dropout: Option<f32>) {
let mut queue = QuaternaryHeap::with_capacity(self.symbols.len());
let mut skip = Vec::with_capacity(queue.len());

Expand All @@ -169,7 +169,7 @@ impl Word {
.enumerate()
.filter_map(|(index, window)| {
let pair = (window[0].c, window[1].c);
merges.get(&pair).map(|m| Merge {
merges.get(pair).map(|m| Merge {
pos: index,
rank: m.0,
new_id: m.1,
Expand Down Expand Up @@ -198,8 +198,8 @@ impl Word {
// Make sure we are not processing an expired queue entry
let target_new_pair = (self.symbols[top.pos].c, right.c);
if merges
.get(&target_new_pair)
.is_none_or(|(_, new_id)| *new_id != top.new_id)
.get(target_new_pair)
.is_none_or(|(_, new_id)| new_id != top.new_id)
{
continue;
}
Expand All @@ -220,11 +220,11 @@ impl Word {
let prev = current.prev as usize;
let prev_symbol = self.symbols[prev];
let new_pair = (prev_symbol.c, current.c);
if let Some((rank, new_id)) = merges.get(&new_pair) {
if let Some((rank, new_id)) = merges.get(new_pair) {
queue.push(Merge {
pos: current.prev as usize,
rank: *rank,
new_id: *new_id,
rank,
new_id,
});
}
}
Expand All @@ -234,11 +234,11 @@ impl Word {
if next < self.symbols.len() {
let next_symbol = self.symbols[next];
let new_pair = (current.c, next_symbol.c);
if let Some((rank, new_id)) = merges.get(&new_pair) {
if let Some((rank, new_id)) = merges.get(new_pair) {
queue.push(Merge {
pos: top.pos,
rank: *rank,
new_id: *new_id,
rank,
new_id,
});
}
}
Expand Down
Loading
Loading