Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion tokenizers/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,11 @@ monostate = "0.1.12"
ahash = { version = "0.8.11", features = ["serde"] }
dary_heap = { version = "0.3.6", features = ["serde"] }
compact_str = { version = "0.9", features = ["serde"] }
logos = "0.14"

[features]
default = ["progressbar", "onig", "esaxx_fast"]
default = ["progressbar", "onig", "esaxx_fast", "logos-pretok"]
logos-pretok = []
esaxx_fast = ["esaxx-rs/cpp"]
progressbar = ["indicatif"]
http = ["hf-hub"]
Expand Down
168 changes: 167 additions & 1 deletion tokenizers/src/pre_tokenizers/byte_level.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
use ahash::{AHashMap, AHashSet};
use std::sync::LazyLock;

#[cfg(not(feature = "logos-pretok"))]
use crate::utils::SysRegex;
#[cfg(feature = "logos-pretok")]
use logos::Logos;
use serde::{Deserialize, Serialize};

#[cfg(feature = "logos-pretok")]
use crate::tokenizer::pattern::Pattern;
#[cfg(feature = "logos-pretok")]
use crate::tokenizer::Offsets;
use crate::tokenizer::{
Decoder, Encoding, PostProcessor, PreTokenizedString, PreTokenizer, Result,
SplitDelimiterBehavior,
Expand Down Expand Up @@ -40,10 +47,159 @@ pub(crate) fn bytes_char() -> AHashMap<u8, char> {

/// Regex that matches exactly one token.
/// See https://github.com/openai/gpt-2/blob/master/src/encoder.py#L98
// Compiled lazily, once, on first use. The pattern is a fixed, known-good
// literal, so the `unwrap` can only fire on a typo in the pattern itself.
#[cfg(not(feature = "logos-pretok"))]
static RE: LazyLock<SysRegex> = LazyLock::new(|| {
    SysRegex::new(r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+")
        .unwrap()
});

/// Compile-time FSM equivalent of the GPT-2 split regex above. Variants are
/// declared in the same priority order as the regex alternation (logos uses
/// source order as tiebreaker + longest-match).
///
/// Note: logos cannot express `\s+(?!\S)` directly (no lookahead). Instead
/// we match the full `\s+` greedily and then replay the lookahead semantics
/// as a post-processing pass in `LogosByteLevel::find_matches` below. The
/// effect of `\s+(?!\S)` in the original pattern is: when a whitespace run
/// of length ≥ 2 is followed by a non-whitespace char, the match backtracks
/// by one char so the trailing space is left as a ` ?` prefix for the next
/// Letter/Number/Other token. That backtracking is what we emulate.
///
/// NOTE(review): the ` ?` in the content alternatives matches only a literal
/// space — when the backtracked char is another whitespace kind (`\n`, `\t`,
/// …) the legacy engine instead re-matches it as a standalone `\s+` span.
/// Verify the post-processing pass distinguishes these two cases.
#[cfg(feature = "logos-pretok")]
#[derive(Logos, Debug, Clone, Copy, PartialEq, Eq)]
enum BlTok {
    // The seven GPT-2 contraction literals (`'s|'t|'re|'ve|'m|'ll|'d`).
    #[token("'s")]
    #[token("'t")]
    #[token("'re")]
    #[token("'ve")]
    #[token("'m")]
    #[token("'ll")]
    #[token("'d")]
    Contraction,

    // ` ?\p{L}+` — an optional leading space plus a run of letters.
    #[regex(r" ?\p{L}+")]
    Letters,

    // ` ?\p{N}+` — an optional leading space plus a run of digits.
    #[regex(r" ?\p{N}+")]
    Numbers,

    // ` ?[^\s\p{L}\p{N}]+` — punctuation/symbol run, optional leading space.
    #[regex(r" ?[^\s\p{L}\p{N}]+")]
    Other,

    // `\s+` — maximal whitespace run; the `(?!\S)` lookahead is replayed
    // later in `find_matches`, not here.
    #[regex(r"\s+")]
    Whitespace,
}

/// Hidden re-export so equivalence tests can drive the logos `Pattern` impl
/// directly side-by-side with the legacy `SysRegex` one. Not a stable API.
/// Stateless marker type: all behaviour lives in the `Pattern` impl below.
#[cfg(feature = "logos-pretok")]
#[doc(hidden)]
pub struct LogosByteLevel;

#[cfg(feature = "logos-pretok")]
impl Pattern for &LogosByteLevel {
    /// Splits `inside` into `((start, end), is_match)` byte spans that fully
    /// cover the input, mirroring the legacy GPT-2 regex
    /// `'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`.
    ///
    /// Never errors in practice; the `Result` is part of the `Pattern` trait.
    fn find_matches(&self, inside: &str) -> Result<Vec<(Offsets, bool)>> {
        if inside.is_empty() {
            // Legacy contract: an empty haystack yields one empty non-match.
            return Ok(vec![((0, 0), false)]);
        }
        // (variant, byte_start, byte_end) per lexeme. Lexer errors keep their
        // span with `None` so the haystack stays fully covered.
        let mut tokens: Vec<(Option<BlTok>, usize, usize)> = Vec::with_capacity(inside.len());
        let mut lex = BlTok::lexer(inside);
        while let Some(result) = lex.next() {
            let span = lex.span();
            tokens.push((result.ok(), span.start, span.end));
        }

        // Replay the `\s+(?!\S)` lookahead: for each Whitespace span of char
        // length ≥ 2 directly followed by a content token, the legacy engine
        // backtracks by one char. What happens to that freed char depends on
        // what it is:
        //   * a literal space is donated to the next span as its ` ?` prefix;
        //   * any other whitespace (`\n`, `\t`, …) cannot satisfy ` ?`, so the
        //     legacy engine re-matches it as a standalone `\s+` span and the
        //     next token keeps no prefix.
        //
        // Contractions are a special case. Logos (longest-match) picks a
        // 2–3 char contraction literal (e.g. `'t`) over the shorter
        // ` ?[^\s\p{L}\p{N}]+` match (just `'`). Legacy (leftmost-first)
        // does the opposite: after backtracking the ws, its `Other`
        // alternative sits earlier in the alternation and consumes ` '`
        // as a 2-char span, leaving the remaining `tis` to be matched by
        // `\p{L}+`. To mirror this we split the contraction into
        // `Other(')` + `Letters(rest)`, extend Other backward to claim
        // the freed ws char, and merge `Letters(rest)` with a following
        // Letters span if contiguous and not starting with whitespace.
        let mut i = 0;
        while i < tokens.len().saturating_sub(1) {
            if !matches!(tokens[i].0, Some(BlTok::Whitespace)) {
                i += 1;
                continue;
            }
            let (start, end) = (tokens[i].1, tokens[i].2);
            let ws_slice = &inside[start..end];
            if ws_slice.chars().count() < 2 {
                // A single whitespace char never triggers the backtrack.
                i += 1;
                continue;
            }
            // Byte offset (within the run) and value of the run's final char.
            let (last_off, last_char) = ws_slice
                .char_indices()
                .last()
                .expect("whitespace run has at least two chars");
            let new_ws_end = start + last_off;

            if last_char != ' ' {
                // The freed char is not a literal space, so it can never
                // become a ` ?` prefix. Legacy behaviour: `\s+(?!\S)` takes
                // all but the last char, a plain `\s+` takes the final char
                // alone, and the content token is left untouched.
                match tokens[i + 1].0 {
                    Some(BlTok::Letters)
                    | Some(BlTok::Numbers)
                    | Some(BlTok::Other)
                    | Some(BlTok::Contraction) => {
                        tokens[i].2 = new_ws_end;
                        tokens.insert(i + 1, (Some(BlTok::Whitespace), new_ws_end, end));
                        i += 2;
                    }
                    // Whitespace can't follow Whitespace (maximal munch);
                    // `None` (lexer error) is left alone, as before.
                    _ => i += 1,
                }
                continue;
            }

            match tokens[i + 1].0 {
                Some(BlTok::Letters) | Some(BlTok::Numbers) | Some(BlTok::Other) => {
                    // Donate the trailing space: the ws span shrinks by one
                    // char and the neighbour grows a leading-space prefix.
                    tokens[i].2 = new_ws_end;
                    tokens[i + 1].1 = new_ws_end;
                    i += 1;
                }
                Some(BlTok::Contraction) => {
                    let (_, cstart, cend) = tokens[i + 1];
                    // Every contraction literal starts with ASCII `'` → 1 byte.
                    let quote_end = cstart + 1;
                    tokens[i].2 = new_ws_end;
                    // Legacy (leftmost-first) matches ` '` via the Other
                    // alternative rather than the contraction literal.
                    tokens[i + 1] = (Some(BlTok::Other), new_ws_end, quote_end);

                    // The rest of the contraction re-matches as Letters; fuse
                    // it with a directly adjoining Letters span so `'tis`
                    // yields `tis`, not `t` + `is`.
                    let letters_seg = (Some(BlTok::Letters), quote_end, cend);
                    let mut merged = false;
                    if let Some(next_next) = tokens.get(i + 2) {
                        if matches!(next_next.0, Some(BlTok::Letters)) && next_next.1 == cend {
                            let first_is_ws = inside[next_next.1..next_next.2]
                                .chars()
                                .next()
                                .map(|c| c.is_whitespace())
                                .unwrap_or(false);
                            // A ` word`-style span keeps its own ` ?` prefix
                            // and must not be fused.
                            if !first_is_ws {
                                tokens[i + 2].1 = quote_end;
                                merged = true;
                            }
                        }
                    }
                    if !merged && letters_seg.1 < letters_seg.2 {
                        tokens.insert(i + 2, letters_seg);
                    }
                    i += 2;
                }
                _ => {
                    i += 1;
                }
            }
        }

        // Convert token spans into the `Pattern` output shape, inserting
        // non-match filler for any gaps (only possible around error spans).
        let mut prev = 0;
        let mut splits = Vec::with_capacity(tokens.len());
        for (_variant, start, end) in tokens {
            if start == end {
                continue;
            }
            if prev != start {
                splits.push(((prev, start), false));
            }
            splits.push(((start, end), true));
            prev = end;
        }
        if prev != inside.len() {
            splits.push(((prev, inside.len()), false));
        }
        Ok(splits)
    }
}

/// Byte → printable-char table (GPT-2 byte-level alphabet), built once.
static BYTES_CHAR: LazyLock<AHashMap<u8, char>> = LazyLock::new(bytes_char);
/// Inverse of `BYTES_CHAR`: maps each printable char back to its byte.
static CHAR_BYTES: LazyLock<AHashMap<char, u8>> =
    LazyLock::new(|| bytes_char().into_iter().map(|(byte, ch)| (ch, byte)).collect());
Expand Down Expand Up @@ -118,13 +274,23 @@ impl ByteLevel {
// TODO: Give the ability to modify this regex
impl PreTokenizer for ByteLevel {
fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
#[cfg(not(feature = "logos-pretok"))]
let re_ref: &SysRegex = &RE;
#[cfg(feature = "logos-pretok")]
let logos_pat = LogosByteLevel;
pretokenized.split(|_, mut normalized| {
if self.add_prefix_space && !normalized.get().starts_with(' ') {
normalized.prepend(" ");
}
if self.use_regex {
normalized.split(re_ref, SplitDelimiterBehavior::Isolated)
#[cfg(feature = "logos-pretok")]
{
normalized.split(&logos_pat, SplitDelimiterBehavior::Isolated)
}
#[cfg(not(feature = "logos-pretok"))]
{
normalized.split(re_ref, SplitDelimiterBehavior::Isolated)
}
} else {
Ok(vec![normalized])
}
Expand Down
Loading
Loading