diff --git a/.github/workflows/python-release.yml b/.github/workflows/python-release.yml index 0547ef72bb..9979292d50 100644 --- a/.github/workflows/python-release.yml +++ b/.github/workflows/python-release.yml @@ -138,6 +138,31 @@ jobs: - run: twine check --strict dist/* working-directory: ./bindings/python + - name: Report wheel sizes + working-directory: ./bindings/python + run: | + echo "## 🐍 Python Wheel Size β€” ${{ matrix.os }} ${{ matrix.target }}" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + echo "Wheel (.whl) = compressed archive downloaded from PyPI." >> $GITHUB_STEP_SUMMARY + echo "Installed .so/.pyd = actual shared library loaded at runtime." >> $GITHUB_STEP_SUMMARY + echo "The installed size is what matters for on-device deployment." >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + echo "| Wheel | Wheel size | Installed .so/.pyd |" >> $GITHUB_STEP_SUMMARY + echo "|---|---|---|" >> $GITHUB_STEP_SUMMARY + EXTRACT_DIR=$(mktemp -d) + for f in dist/*.whl; do + WHL_SIZE=$(du -h "$f" | cut -f1) + NAME=$(basename "$f") + rm -rf "$EXTRACT_DIR"/* + (cd "$EXTRACT_DIR" && unzip -q "$(realpath -- "$OLDPWD/$f" 2>/dev/null || echo "$OLDPWD/$f")" 2>/dev/null) \ + || unzip -q -o "$f" -d "$EXTRACT_DIR" 2>/dev/null || true + SO_SIZE=$(find "$EXTRACT_DIR" \( -name '*.so' -o -name '*.pyd' -o -name '*.dylib' \) -exec du -h {} \; | head -1 | cut -f1) + [ -z "$SO_SIZE" ] && SO_SIZE="n/a" + echo "| \`${NAME}\` | ${WHL_SIZE} | ${SO_SIZE} |" >> $GITHUB_STEP_SUMMARY + done + rm -rf "$EXTRACT_DIR" + echo "" >> $GITHUB_STEP_SUMMARY + - uses: actions/upload-artifact@v4 with: name: pypi_files-${{ matrix.os }}-${{ matrix.target }}-${{ matrix.manylinux }} @@ -180,6 +205,42 @@ jobs: with: path: ./bindings/python/dist merge-multiple: true + + - name: Wheel size summary + working-directory: ./bindings/python + run: | + echo "## πŸ“¦ All Python Wheel Sizes" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + echo "| Wheel | Wheel size | Installed .so/.pyd |" >> $GITHUB_STEP_SUMMARY + echo "|---|---|---|" >> $GITHUB_STEP_SUMMARY + TOTAL_WHL=0 + TOTAL_SO=0 + EXTRACT_DIR=$(mktemp -d) + for f in dist/*.whl; do + WHL_BYTES=$(stat --format=%s "$f" 2>/dev/null || stat -f%z "$f") + WHL_SIZE=$(du -h "$f" | cut -f1) + NAME=$(basename "$f") + rm -rf "$EXTRACT_DIR"/* + unzip -q -o "$f" -d "$EXTRACT_DIR" 2>/dev/null || true + SO_FILE=$(find "$EXTRACT_DIR" \( -name '*.so' -o -name '*.pyd' -o -name '*.dylib' \) | head -1) + if [ -n "$SO_FILE" ]; then + SO_BYTES=$(stat --format=%s "$SO_FILE" 2>/dev/null || stat -f%z "$SO_FILE") + SO_SIZE=$(du -h "$SO_FILE" | cut -f1) + TOTAL_SO=$((TOTAL_SO + SO_BYTES)) + else + SO_SIZE="n/a" + fi + echo "| \`${NAME}\` | ${WHL_SIZE} | ${SO_SIZE} |" >> $GITHUB_STEP_SUMMARY + TOTAL_WHL=$((TOTAL_WHL + WHL_BYTES)) + done + rm -rf "$EXTRACT_DIR" + echo "" >> $GITHUB_STEP_SUMMARY + TOTAL_WHL_MB=$(echo "scale=2; $TOTAL_WHL / 1048576" | bc) + TOTAL_SO_MB=$(echo "scale=2; $TOTAL_SO / 1048576" | bc) + WHL_COUNT=$(ls dist/*.whl 2>/dev/null | wc -l | tr -d ' ') + echo "**Total**: ${WHL_COUNT} wheels | wheel: ${TOTAL_WHL_MB} MB | installed: ${TOTAL_SO_MB} MB" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + # Temporary deactivation while testing abi3 CI # - name: Upload to PyPi # working-directory: ./bindings/python diff --git a/.github/workflows/rust-release.yml b/.github/workflows/rust-release.yml index 05a75f072c..001c579a8e 100644 --- a/.github/workflows/rust-release.yml +++ b/.github/workflows/rust-release.yml @@ -24,6 +24,85 @@ jobs: path: ~/.cargo/registry 
key: ubuntu-latest-cargo-registry-${{ hashFiles('**/Cargo.toml') }} + - name: Measure crate size + working-directory: ./tokenizers + run: | + echo "## πŸ“¦ Crate Size Report" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + + # Packed crate size (what gets uploaded to crates.io) + cargo package --list --allow-dirty > /tmp/crate_files.txt + CRATE_FILE_COUNT=$(wc -l < /tmp/crate_files.txt | tr -d ' ') + PACKED_SIZE=$(cargo package --allow-dirty 2>&1 | grep -oP 'Packaged \d+ files?, \K[\d.]+ \w+' || echo "unknown") + echo "### Packed crate (crates.io)" >> $GITHUB_STEP_SUMMARY + echo "- **Size**: ${PACKED_SIZE}" >> $GITHUB_STEP_SUMMARY + echo "- **Files**: ${CRATE_FILE_COUNT}" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + + # Linked shared library size for various feature combinations. + # This is the actual on-device size β€” what ships to users β€” NOT the + # .rlib, which contains unused code the final linker strips. + # We build a minimal cdylib that uses the Tokenizer API and measure it. + TEST_DIR=$(mktemp -d) + TOK_PATH="$(pwd)" + mkdir -p "$TEST_DIR/src" + cat > "$TEST_DIR/src/lib.rs" << 'RS' + use tokenizers::Tokenizer; + #[no_mangle] + pub extern "C" fn tokenize(path: *const u8, len: usize, input: *const u8, input_len: usize) -> usize { + let path = unsafe { std::str::from_utf8_unchecked(std::slice::from_raw_parts(path, len)) }; + let input = unsafe { std::str::from_utf8_unchecked(std::slice::from_raw_parts(input, input_len)) }; + let tok = Tokenizer::from_file(path).unwrap(); + tok.encode(input, false).unwrap().get_ids().len() + } + RS + + measure() { + local LABEL="$1" + local FEATURES="$2" + cat > "$TEST_DIR/Cargo.toml" << TOML + [package] + name = "size-test" + version = "0.1.0" + edition = "2021" + [lib] + crate-type = ["cdylib"] + [dependencies] + tokenizers = { path = "$TOK_PATH", ${FEATURES} } + [profile.release] + lto = "fat" + opt-level = "s" + strip = true + codegen-units = 1 + panic = "abort" + TOML + (cd "$TEST_DIR" && cargo build --release >/dev/null 2>&1) + local LIB=$(find "$TEST_DIR/target/release" -maxdepth 1 \( -name '*.so' -o -name '*.dylib' -o -name '*.dll' \) | head -1) + if [ -n "$LIB" ]; then + local BYTES=$(stat --format=%s "$LIB" 2>/dev/null || stat -f%z "$LIB") + local KB=$((BYTES / 1024)) + echo "| ${LABEL} | ${KB} KB |" >> $GITHUB_STEP_SUMMARY + else + echo "| ${LABEL} | build failed |" >> $GITHUB_STEP_SUMMARY + fi + } + + echo "### Linked shared library size (stripped cdylib, LTO fat, opt-level=s, panic=abort)" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + echo "This is the actual on-device size β€” what ships to end users." 
>> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + echo "| Feature set | Size |" >> $GITHUB_STEP_SUMMARY + echo "|---|---|" >> $GITHUB_STEP_SUMMARY + + measure "default (all features)" '' + measure "inference (onig + unicode-norm + spm)" 'default-features = false, features = ["inference"]' + measure "minimal onig-only" 'default-features = false, features = ["onig"]' + measure "no-default + training" 'default-features = false, features = ["onig", "training"]' + measure "no-default + parallel" 'default-features = false, features = ["onig", "parallel"]' + + rm -rf "$TEST_DIR" + echo "" >> $GITHUB_STEP_SUMMARY + - name: Publish package rust working-directory: ./tokenizers if: ${{ !contains(github.ref, 'rc') }} diff --git a/bindings/node/Cargo.toml b/bindings/node/Cargo.toml index 42d3c97148..d834aa7617 100644 --- a/bindings/node/Cargo.toml +++ b/bindings/node/Cargo.toml @@ -14,7 +14,6 @@ napi = "2" napi-derive = "2" serde = { version = "1.0.163", features = ["derive"] } tokenizers = { path = "../../tokenizers/" } -ahash = { version = "0.8.11", features = ["serde"] } [build-dependencies] napi-build = "2" diff --git a/bindings/node/src/models.rs b/bindings/node/src/models.rs index 9ee7f60f7d..8beebb7f8a 100644 --- a/bindings/node/src/models.rs +++ b/bindings/node/src/models.rs @@ -1,7 +1,6 @@ use crate::arc_rwlock_serde; use crate::tasks::models::{BPEFromFilesTask, WordLevelFromFilesTask, WordPieceFromFilesTask}; use crate::trainers::Trainer; -use ahash::AHashMap; use napi::bindgen_prelude::*; use napi_derive::napi; use serde::{Deserialize, Serialize}; @@ -12,6 +11,7 @@ use tokenizers as tk; use tokenizers::models::bpe::{BpeBuilder, Merges}; use tokenizers::models::wordlevel::WordLevelBuilder; use tokenizers::models::wordpiece::WordPieceBuilder; +use tokenizers::utils::AHashMap; #[napi] #[derive(Clone, Serialize, Deserialize)] diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml index 52da7498c7..d15f68a985 100644 --- a/bindings/python/Cargo.toml +++ b/bindings/python/Cargo.toml @@ -21,7 +21,6 @@ once_cell = "1.19.0" numpy = "0.28" ndarray = "0.16" itertools = "0.14" -ahash = { version = "0.8.11", features = ["serde"] } pyo3-ffi = { version = "0.28" } [dependencies.tokenizers] @@ -34,3 +33,7 @@ pyo3 = { version = "0.28", features = ["auto-initialize", "experimental-inspect" [features] default = ["ext-module"] ext-module = ["pyo3/extension-module"] + +[profile.release] +strip = true +lto = "fat" diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs index d2f7bf7df1..a445f74c7d 100644 --- a/bindings/python/src/models.rs +++ b/bindings/python/src/models.rs @@ -4,7 +4,6 @@ use std::sync::{Arc, RwLock}; use crate::token::PyToken; use crate::trainers::PyTrainer; -use ahash::AHashMap; use pyo3::exceptions; use pyo3::prelude::*; use pyo3::types::*; @@ -14,6 +13,7 @@ use tk::models::unigram::Unigram; use tk::models::wordlevel::WordLevel; use tk::models::wordpiece::{WordPiece, WordPieceBuilder}; use tk::models::ModelWrapper; +use tk::utils::AHashMap; use tk::{Model, Token}; use tokenizers as tk; diff --git a/tokenizers/Cargo.toml b/tokenizers/Cargo.toml index 0e937f3cc5..ac7fa9d4c0 100644 --- a/tokenizers/Cargo.toml +++ b/tokenizers/Cargo.toml @@ -67,22 +67,19 @@ name = "ci_benchmark" harness = false [dependencies] -rand = "0.9" +rand = { version = "0.9", optional = true } onig = { version = "6.5.1", default-features = false, optional = true } -regex = "1.10" -regex-syntax = "0.8" -rayon = "1.10" -rayon-cond = "0.4" +regex = { version = "1.10", 
default-features = false, features = ["std", "perf", "unicode-perl"], optional = true } +rayon = { version = "1.10", optional = true } +rayon-cond = { version = "0.4", optional = true } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" -unicode-normalization-alignments = "0.1" +unicode-normalization-alignments = { version = "0.1", optional = true } unicode_categories = "0.1" -unicode-segmentation = "1.11" +unicode-segmentation = { version = "1.11", optional = true } indicatif = { version = "0.18", optional = true } -itertools = "0.14" log = "0.4" -derive_builder = "0.20" -spm_precompiled = "0.1.3" +spm_precompiled = { version = "0.1.3", optional = true } hf-hub = { version = "0.4.1", features = ["ureq"], default-features = false, optional = true } daachorse = "1.0.1" paste = "1.0.14" @@ -90,17 +87,21 @@ macro_rules_attribute = "0.2.0" thiserror = "2" fancy-regex = { version = "0.17", optional = true } getrandom = { version = "0.3" } -esaxx-rs = { version = "0.1.10", default-features = false, features = [] } -monostate = "0.1.12" -ahash = { version = "0.8.11", features = ["serde"] } -dary_heap = { version = "0.3.6", features = ["serde"] } -compact_str = { version = "0.9", features = ["serde"] } +esaxx-rs = { version = "0.1.10", default-features = false, features = [], optional = true } +foldhash = "0.2" +dary_heap = "0.3.6" +compact_str = { version = "0.9", features = ["serde"], optional = true } [features] -default = ["progressbar", "onig", "esaxx_fast"] -esaxx_fast = ["esaxx-rs/cpp"] +default = ["progressbar", "onig", "esaxx_fast", "spm", "training", "parallel", "unicode-normalization", "regex"] +unicode-normalization = ["dep:unicode-normalization-alignments"] +parallel = ["dep:rayon", "dep:rayon-cond"] +training = ["dep:rand", "dep:esaxx-rs", "dep:compact_str"] +spm = ["dep:spm_precompiled", "dep:unicode-segmentation"] +esaxx_fast = ["dep:esaxx-rs", "esaxx-rs/cpp"] progressbar = ["indicatif"] http = ["hf-hub"] +inference = ["onig", "unicode-normalization", "spm"] unstable_wasm = ["fancy-regex", "getrandom/wasm_js"] rustls-tls = ["hf-hub?/rustls-tls"] @@ -114,6 +115,21 @@ tracing-subscriber = "0.3.18" [profile.release] lto = "fat" +# Use this profile for minimal binary size (e.g. on-device deployment). +# Pair with the `inference` feature for all inference capabilities without training/parallel: +# cargo build --profile release-small --no-default-features --features inference +# For even smaller builds (nightly only): +# RUSTFLAGS="-Zlocation-detail=none -Zfmt-debug=none" cargo +nightly build \ +# -Z build-std=std,panic_abort -Z build-std-features="optimize_for_size" \ +# --target --profile release-small \ +# --no-default-features --features inference +[profile.release-small] +inherits = "release" +opt-level = "s" +strip = true +panic = "abort" +codegen-units = 1 + [profile.profiling] inherits = "release" debug = true diff --git a/tokenizers/README.md b/tokenizers/README.md index 173e0bc065..7f7ab46963 100644 --- a/tokenizers/README.md +++ b/tokenizers/README.md @@ -135,9 +135,134 @@ fn main() -> Result<()> { ## Features -- **progressbar**: The progress bar visualization is enabled by default. It might be disabled if - compilation for certain targets is not supported by the [termios](https://crates.io/crates/termios) - dependency of the [indicatif](https://crates.io/crates/indicatif) progress bar. +All features are **enabled by default** for backward compatibility. Disable them for on-device/embedded use. 
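+
+For example, one way to get a minimal inference-only build of the crate itself is the `inference` umbrella feature (`onig` + `unicode-normalization` + `spm`) defined in `Cargo.toml`; pair it with the `release-small` profile there for a size-optimized binary:
+
+```bash
+cargo build --release --no-default-features --features inference
+```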
-- **http**: This feature enables downloading the tokenizer via HTTP. It is disabled by default. - With this feature enabled, `Tokenizer::from_pretrained` becomes accessible. +| Feature | Default | Description | Deps saved | +|---------|---------|-------------|------------| +| `training` | on | Tokenizer training (trainers, `train()` method) | rand, esaxx-rs, compact_str | +| `parallel` | on | Multi-threaded encoding via rayon | rayon, rayon-cond, crossbeam | +| `spm` | on | SentencePiece precompiled normalizer (T5, mBART) | spm_precompiled, nom, unicode-segmentation | +| `unicode-normalization` | on | NFC/NFD/NFKC/NFKD normalizers | unicode-normalization-alignments | +| `progressbar` | on | Progress bars during training | indicatif | +| `onig` | on | Oniguruma regex engine (C binding) | onig, onig_sys | +| `http` | off | Download tokenizers from Hugging Face Hub | hf-hub, ureq | +| `unstable_wasm` | off | WASM target support (uses fancy-regex) | fancy-regex | + +### On-device / embedded configuration + +```toml +# Minimal inference-only (with Oniguruma regex): +tokenizers = { version = "0.22", default-features = false, features = ["onig"] } + +# WASM (pure Rust, no C dependencies): +tokenizers = { version = "0.22", default-features = false, features = ["unstable_wasm"] } +``` + +## Bundle size + +The deployed library size depends on how you link it. Here are measured sizes on macOS arm64: + +| Configuration | .dylib (shared) | .a (static) | After final link | +|---------------|----------------|-------------|-----------------| +| Default (all features) | 2.5 MB | 9.2 MB | ~2.5 MB | +| Inference-only (`onig`) | 2.0 MB | 8.0 MB | ~2.0 MB | + +> **Note**: `.a` (static archive) files contain all object code including unused functions. +> The linker strips dead code at final link time, so the actual contribution to your app +> binary is close to the `.dylib` size. The `.a` size is NOT what ships to users. + +### Comparison with Meta pytorch/tokenizers (C++) + +| | Meta (C++) | HuggingFace (Rust) | +|---|---|---| +| Stripped binary (all tokenizer types) | **0.8 MB** | **2.0 MB** | +| Static .a (pre-link, all deps) | 5.5 MB | 8.0 MB | +| Features | SP, Tiktoken, Llama2c | BPE, WordPiece, Unigram, WordLevel + normalizers, pre-tokenizers, decoders, added vocab | + +HuggingFace is ~2.5x larger because it includes full `tokenizer.json` parsing (serde), Unicode-aware +regex, all normalizer/pre-tokenizer/decoder types, and added vocabulary matching β€” features Meta's +library doesn't have. + +### How to measure bundle size + +**1. Measure the linked shared library (what ships to users):** + +```bash +# Create a test crate that links tokenizers as a cdylib +cargo new --lib measure-size && cd measure-size +cat >> Cargo.toml << 'EOF' +[lib] +crate-type = ["cdylib"] + +[dependencies] +tokenizers = { path = "../tokenizers", default-features = false, features = ["onig"] } + +[profile.release] +lto = "fat" +opt-level = "s" +strip = true +EOF + +echo 'use tokenizers::Tokenizer; +#[no_mangle] +pub extern "C" fn tokenize() { let _ = Tokenizer::from_file("t.json"); }' > src/lib.rs + +cargo build --release +ls -lh target/release/*.dylib # macOS +ls -lh target/release/*.so # Linux +``` + +**2. Measure per-crate contribution with cargo-bloat:** + +```bash +cargo install cargo-bloat +cargo bloat --release --crates -n 30 +``` + +**3. 
Measure dependency rlib sizes (compile-time cost):** + +```bash +# Total rlib for runtime deps only +cargo tree --edges=normal --prefix none -f '{p}' | awk '{print $1}' | sort -u | sed 's/-/_/g' > /tmp/deps.txt + +for f in target/release/deps/*.rlib; do + sz=$(stat -f%z "$f" 2>/dev/null || stat -c%s "$f" 2>/dev/null) + name=$(basename "$f" | sed 's/-[a-f0-9]*\.rlib//' | sed 's/^lib//') + echo "$sz $name" +done | sort -t' ' -k2 | awk '!seen[$2]++ {print}' | sort -k2 > /tmp/rlibs.txt + +join -1 2 -2 1 /tmp/rlibs.txt /tmp/deps.txt | awk '{ + total+=$2 + printf "%8.1f KB %s\n", $2/1024, $1 +} END { + printf "\nTOTAL: %.1f MB\n", total/1048576 +}' | sort -rn +``` + +**4. Track size in CI (regression test):** + +```bash +#!/bin/bash +# scripts/check-bundle-size.sh +set -e + +MAX_DYLIB_KB=2500 # 2.5 MB threshold + +cargo build --release --no-default-features --features "onig" \ + --target-dir /tmp/size-check + +SIZE=$(stat -f%z /tmp/size-check/release/libtokenizers.rlib 2>/dev/null \ + || stat -c%s /tmp/size-check/release/libtokenizers.rlib) +SIZE_KB=$((SIZE / 1024)) + +echo "libtokenizers.rlib: ${SIZE_KB} KB" + +# For the actual linked size, build a cdylib test crate +# (see step 1 above) and check the .dylib/.so size + +if [ "$SIZE_KB" -gt "$MAX_DYLIB_KB" ]; then + echo "FAIL: bundle size ${SIZE_KB} KB exceeds threshold ${MAX_DYLIB_KB} KB" + exit 1 +fi +echo "PASS: bundle size OK" +``` diff --git a/tokenizers/src/decoders/byte_fallback.rs b/tokenizers/src/decoders/byte_fallback.rs index 57b7b63cd7..32d03d0729 100644 --- a/tokenizers/src/decoders/byte_fallback.rs +++ b/tokenizers/src/decoders/byte_fallback.rs @@ -1,23 +1,22 @@ use crate::tokenizer::{Decoder, Result}; -use monostate::MustBe; -use serde::{Deserialize, Serialize}; +impl_serde_type! { + #[derive(Clone, Debug)] + /// ByteFallback is a simple trick which converts tokens looking like `<0x61>` + /// to pure bytes, and attempts to make them into a string. If the tokens + /// cannot be decoded you will get οΏ½ instead for each inconvertible byte token + pub struct ByteFallback; +} -#[derive(Deserialize, Clone, Debug, Serialize, Default)] -/// ByteFallback is a simple trick which converts tokens looking like `<0x61>` -/// to pure bytes, and attempts to make them into a string. 
If the tokens -/// cannot be decoded you will get οΏ½ instead for each inconvertible byte token -#[non_exhaustive] -pub struct ByteFallback { - #[serde(rename = "type")] - type_: MustBe!("ByteFallback"), +impl Default for ByteFallback { + fn default() -> Self { + ByteFallback + } } impl ByteFallback { pub fn new() -> Self { - Self { - type_: MustBe!("ByteFallback"), - } + ByteFallback } } diff --git a/tokenizers/src/decoders/ctc.rs b/tokenizers/src/decoders/ctc.rs index 9d5a571886..c2e529cd3f 100644 --- a/tokenizers/src/decoders/ctc.rs +++ b/tokenizers/src/decoders/ctc.rs @@ -1,7 +1,6 @@ use crate::decoders::wordpiece; use crate::tokenizer::{Decoder, Result}; -use itertools::Itertools; use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize)] @@ -43,22 +42,22 @@ impl Default for CTC { impl Decoder for CTC { fn decode_chain(&self, tokens: Vec) -> Result> { - Ok(tokens - .into_iter() - .dedup() - .filter_map(|token| { - let mut replaced = token.replace(&self.pad_token, ""); - if self.cleanup { - replaced = - wordpiece::cleanup(&replaced).replace(&self.word_delimiter_token, " "); - } - if replaced.is_empty() { - None - } else { - Some(replaced) - } - }) - .collect()) + let mut prev: Option = None; + let mut result = Vec::new(); + for token in tokens { + if prev.as_ref() == Some(&token) { + continue; + } + prev = Some(token.clone()); + let mut replaced = token.replace(&self.pad_token, ""); + if self.cleanup { + replaced = wordpiece::cleanup(&replaced).replace(&self.word_delimiter_token, " "); + } + if !replaced.is_empty() { + result.push(replaced); + } + } + Ok(result) } } diff --git a/tokenizers/src/decoders/fuse.rs b/tokenizers/src/decoders/fuse.rs index 5e4a1c1197..1d0cc269b8 100644 --- a/tokenizers/src/decoders/fuse.rs +++ b/tokenizers/src/decoders/fuse.rs @@ -1,23 +1,23 @@ use crate::tokenizer::{Decoder, Result}; -use monostate::MustBe; -use serde::{Deserialize, Serialize}; -#[derive(Clone, Debug, Serialize, Deserialize, Default)] -/// Fuse simply fuses all tokens into one big string. -/// It's usually the last decoding step anyway, but this -/// decoder exists incase some decoders need to happen after that -/// step -#[non_exhaustive] -pub struct Fuse { - #[serde(rename = "type")] - type_: MustBe!("Fuse"), +impl_serde_type! { + #[derive(Clone, Debug)] + /// Fuse simply fuses all tokens into one big string. + /// It's usually the last decoding step anyway, but this + /// decoder exists incase some decoders need to happen after that + /// step + pub struct Fuse; +} + +impl Default for Fuse { + fn default() -> Self { + Fuse + } } impl Fuse { pub fn new() -> Self { - Self { - type_: MustBe!("Fuse"), - } + Fuse } } diff --git a/tokenizers/src/lib.rs b/tokenizers/src/lib.rs index 1233a8bd2a..9a47a75510 100644 --- a/tokenizers/src/lib.rs +++ b/tokenizers/src/lib.rs @@ -124,19 +124,141 @@ //! //! # Features //! -//! - **progressbar**: The progress bar visualization is enabled by default. It might be disabled if -//! compilation for certain targets is not supported by the [termios](https://crates.io/crates/termios) -//! dependency of the [indicatif](https://crates.io/crates/indicatif) progress bar. +//! All features are **enabled by default** for backward compatibility. Disable them for on-device/embedded use. //! -//! - **http**: This feature enables downloading the tokenizer via HTTP. It is disabled by default. -//! With this feature enabled, `Tokenizer::from_pretrained` becomes accessible. +//! | Feature | Default | Description | Deps saved | +//! 
|---------|---------|-------------|------------| +//! | `training` | on | Tokenizer training (trainers, `train()` method) | rand, esaxx-rs, compact_str | +//! | `parallel` | on | Multi-threaded encoding via rayon | rayon, rayon-cond, crossbeam | +//! | `spm` | on | SentencePiece precompiled normalizer (T5, mBART) | spm_precompiled, nom, unicode-segmentation | +//! | `unicode-normalization` | on | NFC/NFD/NFKC/NFKD normalizers | unicode-normalization-alignments | +//! | `progressbar` | on | Progress bars during training | indicatif | +//! | `onig` | on | Oniguruma regex engine (C binding) | onig, onig_sys | +//! | `http` | off | Download tokenizers from Hugging Face Hub | hf-hub, ureq | +//! | `unstable_wasm` | off | WASM target support (uses fancy-regex) | fancy-regex | +//! +//! ## On-device / embedded configuration +//! +//! ```toml +//! # Minimal inference-only (with Oniguruma regex): +//! tokenizers = { version = "0.22", default-features = false, features = ["onig"] } +//! +//! # WASM (pure Rust, no C dependencies): +//! tokenizers = { version = "0.22", default-features = false, features = ["unstable_wasm"] } +//! ``` +//! +//! # Bundle size +//! +//! The deployed library size depends on how you link it. Here are measured sizes on macOS arm64: +//! +//! | Configuration | .dylib (shared) | .a (static) | After final link | +//! |---------------|----------------|-------------|-----------------| +//! | Default (all features) | 2.5 MB | 9.2 MB | ~2.5 MB | +//! | Inference-only (`onig`) | 2.0 MB | 8.0 MB | ~2.0 MB | +//! +//! > **Note**: `.a` (static archive) files contain all object code including unused functions. +//! > The linker strips dead code at final link time, so the actual contribution to your app +//! > binary is close to the `.dylib` size. The `.a` size is NOT what ships to users. +//! +//! ## Comparison with Meta pytorch/tokenizers (C++) +//! +//! | | Meta (C++) | HuggingFace (Rust) | +//! |---|---|---| +//! | Stripped binary (all tokenizer types) | **0.8 MB** | **2.0 MB** | +//! | Static .a (pre-link, all deps) | 5.5 MB | 8.0 MB | +//! | Features | SP, Tiktoken, Llama2c | BPE, WordPiece, Unigram, WordLevel + normalizers, pre-tokenizers, decoders, added vocab | +//! +//! HuggingFace is ~2.5x larger because it includes full `tokenizer.json` parsing (serde), Unicode-aware +//! regex, all normalizer/pre-tokenizer/decoder types, and added vocabulary matching β€” features Meta's +//! library doesn't have. +//! +//! ## How to measure bundle size +//! +//! **1. Measure the linked shared library (what ships to users):** +//! +//! ```bash +//! # Create a test crate that links tokenizers as a cdylib +//! cargo new --lib measure-size && cd measure-size +//! cat >> Cargo.toml << 'EOF' +//! [lib] +//! crate-type = ["cdylib"] +//! +//! [dependencies] +//! tokenizers = { path = "../tokenizers", default-features = false, features = ["onig"] } +//! +//! [profile.release] +//! lto = "fat" +//! opt-level = "s" +//! strip = true +//! EOF +//! +//! echo 'use tokenizers::Tokenizer; +//! #[no_mangle] +//! pub extern "C" fn tokenize() { let _ = Tokenizer::from_file("t.json"); }' > src/lib.rs +//! +//! cargo build --release +//! ls -lh target/release/*.dylib # macOS +//! ls -lh target/release/*.so # Linux +//! ``` +//! +//! **2. Measure per-crate contribution with cargo-bloat:** +//! +//! ```bash +//! cargo install cargo-bloat +//! cargo bloat --release --crates -n 30 +//! ``` +//! +//! **3. Measure dependency rlib sizes (compile-time cost):** +//! +//! ```bash +//! 
# Total rlib for runtime deps only +//! cargo tree --edges=normal --prefix none -f '{p}' | awk '{print $1}' | sort -u | sed 's/-/_/g' > /tmp/deps.txt +//! +//! for f in target/release/deps/*.rlib; do +//! sz=$(stat -f%z "$f" 2>/dev/null || stat -c%s "$f" 2>/dev/null) +//! name=$(basename "$f" | sed 's/-[a-f0-9]*\.rlib//' | sed 's/^lib//') +//! echo "$sz $name" +//! done | sort -t' ' -k2 | awk '!seen[$2]++ {print}' | sort -k2 > /tmp/rlibs.txt +//! +//! join -1 2 -2 1 /tmp/rlibs.txt /tmp/deps.txt | awk '{ +//! total+=$2 +//! printf "%8.1f KB %s\n", $2/1024, $1 +//! } END { +//! printf "\nTOTAL: %.1f MB\n", total/1048576 +//! }' | sort -rn +//! ``` +//! +//! **4. Track size in CI (regression test):** +//! +//! ```bash +//! #!/bin/bash +//! # scripts/check-bundle-size.sh +//! set -e +//! +//! MAX_DYLIB_KB=2500 # 2.5 MB threshold +//! +//! cargo build --release --no-default-features --features "onig" \ +//! --target-dir /tmp/size-check +//! +//! SIZE=$(stat -f%z /tmp/size-check/release/libtokenizers.rlib 2>/dev/null \ +//! || stat -c%s /tmp/size-check/release/libtokenizers.rlib) +//! SIZE_KB=$((SIZE / 1024)) +//! +//! echo "libtokenizers.rlib: ${SIZE_KB} KB" +//! +//! # For the actual linked size, build a cdylib test crate +//! # (see step 1 above) and check the .dylib/.so size +//! +//! if [ "$SIZE_KB" -gt "$MAX_DYLIB_KB" ]; then +//! echo "FAIL: bundle size ${SIZE_KB} KB exceeds threshold ${MAX_DYLIB_KB} KB" +//! exit 1 +//! fi +//! echo "PASS: bundle size OK" +//! ``` #[macro_use] extern crate log; -#[macro_use] -extern crate derive_builder; - #[macro_use] pub mod utils; pub mod decoders; diff --git a/tokenizers/src/models/bpe/mod.rs b/tokenizers/src/models/bpe/mod.rs index f0d40b2df6..f337f9b9fa 100644 --- a/tokenizers/src/models/bpe/mod.rs +++ b/tokenizers/src/models/bpe/mod.rs @@ -1,8 +1,10 @@ //! [Byte Pair Encoding](https://www.aclweb.org/anthology/P16-1162/) model. 
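+//!
+//! The `trainer` submodule and the iterator helpers below are compiled only
+//! when the `training` feature (part of the default set) is enabled. A minimal
+//! sketch of downstream code gated the same way (the function name here is
+//! illustrative only):
+//!
+//! ```ignore
+//! #[cfg(feature = "training")]
+//! fn default_bpe_trainer() -> tokenizers::models::bpe::BpeTrainer {
+//!     tokenizers::models::bpe::BpeTrainer::default()
+//! }
+//! ```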
+#[cfg(feature = "training")] use std::{iter, mem}; mod model; mod serialization; +#[cfg(feature = "training")] pub mod trainer; mod word; @@ -35,11 +37,13 @@ pub enum Error { InvalidDropout, } +#[cfg(feature = "training")] /// Provides access to the `FirstLastIterator` to any Iterator pub(crate) trait WithFirstLastIterator: Iterator + Sized { fn with_first_and_last(self) -> FirstLastIterator; } +#[cfg(feature = "training")] impl WithFirstLastIterator for I where I: Iterator, @@ -52,6 +56,7 @@ where } } +#[cfg(feature = "training")] /// Provides information about whether an item is the first and/or the last of the iterator pub(crate) struct FirstLastIterator where @@ -61,6 +66,7 @@ where iter: iter::Peekable, } +#[cfg(feature = "training")] impl Iterator for FirstLastIterator where I: Iterator, @@ -78,5 +84,6 @@ where // Re-export pub use model::*; +#[cfg(feature = "training")] pub use trainer::*; use word::*; diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs index c0e4f7d84d..9ba46ebe19 100644 --- a/tokenizers/src/models/bpe/model.rs +++ b/tokenizers/src/models/bpe/model.rs @@ -1,8 +1,10 @@ -use super::{super::OrderedVocabIter, trainer::BpeTrainer, Error, Pair, Word}; +#[cfg(feature = "training")] +use super::trainer::BpeTrainer; +use super::{super::OrderedVocabIter, Error, Pair, Word}; use crate::tokenizer::{Model, Result, Token}; use crate::utils::cache::{Cache, DEFAULT_CACHE_CAPACITY, MAX_LENGTH}; use crate::utils::iter::ResultShunt; -use ahash::AHashMap; +use crate::utils::{AHashMap, HashMapExt}; use serde_json::Value; use std::borrow::Cow; @@ -510,6 +512,7 @@ impl BPE { } impl Model for BPE { + #[cfg(feature = "training")] type Trainer = BpeTrainer; fn get_vocab(&self) -> HashMap { @@ -585,6 +588,7 @@ impl Model for BPE { Ok(vec![vocab_path, merges_path]) } + #[cfg(feature = "training")] fn get_trainer(&self) -> BpeTrainer { BpeTrainer::default() } diff --git a/tokenizers/src/models/bpe/serialization.rs b/tokenizers/src/models/bpe/serialization.rs index 98cf549445..3c0bdfb1a3 100644 --- a/tokenizers/src/models/bpe/serialization.rs +++ b/tokenizers/src/models/bpe/serialization.rs @@ -1,5 +1,5 @@ use super::{super::OrderedVocabIter, convert_merges_to_hashmap, BpeBuilder, Pair, BPE}; -use ahash::AHashMap; +use crate::utils::AHashMap; use serde::{ de::{Error, MapAccess, Visitor}, ser::SerializeStruct, diff --git a/tokenizers/src/models/bpe/trainer.rs b/tokenizers/src/models/bpe/trainer.rs index df68c655e9..3721d66584 100644 --- a/tokenizers/src/models/bpe/trainer.rs +++ b/tokenizers/src/models/bpe/trainer.rs @@ -4,7 +4,7 @@ use super::{Pair, WithFirstLastIterator, Word, BPE}; use crate::parallelism::*; use crate::tokenizer::{AddedToken, Result, Trainer}; use crate::utils::progress::{ProgressBar, ProgressFormat, ProgressStyle}; -use ahash::{AHashMap, AHashSet}; +use crate::utils::{AHashMap, AHashSet, HashMapExt, HashSetExt}; use compact_str::CompactString; use dary_heap::OctonaryHeap; use serde::{Deserialize, Serialize}; @@ -608,9 +608,9 @@ impl BpeTrainer { // Transfer new vocab & options to model //model.vocab = word_to_id; model.vocab = word_to_id - .into_iter() + .into_values() // we have to look up the string in id_to_word because the key in word_to_id is a hash - .map(|(_key, val)| (id_to_word[val as usize].to_string(), val)) + .map(|val| (id_to_word[val as usize].to_string(), val)) .collect(); model.vocab_r = model .vocab @@ -678,7 +678,7 @@ impl Trainer for BpeTrainer { #[cfg(test)] mod tests { use super::{BpeTrainer, Pair, BPE}; - use 
ahash::AHashMap; + use crate::utils::AHashMap; use compact_str::CompactString; #[test] diff --git a/tokenizers/src/models/bpe/word.rs b/tokenizers/src/models/bpe/word.rs index 7bf2dee566..a4b75b113f 100644 --- a/tokenizers/src/models/bpe/word.rs +++ b/tokenizers/src/models/bpe/word.rs @@ -1,6 +1,7 @@ use super::Pair; -use ahash::AHashMap; +use crate::utils::AHashMap; use dary_heap::QuaternaryHeap; +#[cfg(feature = "training")] use rand::{rng, Rng}; use std::cmp::Ordering; @@ -75,6 +76,7 @@ impl std::fmt::Debug for Word { } impl Word { + #[cfg_attr(not(feature = "training"), allow(dead_code))] pub(super) fn new() -> Self { Word { symbols: vec![] } } @@ -104,6 +106,7 @@ impl Word { }); } + #[cfg_attr(not(feature = "training"), allow(dead_code))] pub(super) fn merge( &mut self, c1: u32, @@ -178,7 +181,18 @@ impl Word { ); while let Some(top) = queue.pop() { - if dropout.map(|d| rng().random::() < d).unwrap_or(false) { + let should_skip = { + #[cfg(feature = "training")] + { + dropout.map(|d| rng().random::() < d).unwrap_or(false) + } + #[cfg(not(feature = "training"))] + { + let _ = &dropout; + false + } + }; + if should_skip { skip.push(top); } else { // Re-insert the skipped elements @@ -249,6 +263,7 @@ impl Word { self.symbols.retain(|s| s.len != 0); } + #[cfg_attr(not(feature = "training"), allow(dead_code))] pub(super) fn get_chars(&self) -> Vec { self.symbols.iter().map(|s| s.c).collect() } diff --git a/tokenizers/src/models/mod.rs b/tokenizers/src/models/mod.rs index 041e3b629b..cdb2ac8d8f 100644 --- a/tokenizers/src/models/mod.rs +++ b/tokenizers/src/models/mod.rs @@ -5,17 +5,27 @@ pub mod unigram; pub mod wordlevel; pub mod wordpiece; -use ahash::AHashMap; +use crate::utils::AHashMap; use std::collections::HashMap; use std::path::{Path, PathBuf}; use serde::{Deserialize, Deserializer, Serialize, Serializer}; -use crate::models::bpe::{BpeTrainer, BPE}; -use crate::models::unigram::{Unigram, UnigramTrainer}; -use crate::models::wordlevel::{WordLevel, WordLevelTrainer}; -use crate::models::wordpiece::{WordPiece, WordPieceTrainer}; -use crate::{AddedToken, Model, Result, Token, Trainer}; +#[cfg(feature = "training")] +use crate::models::bpe::BpeTrainer; +use crate::models::bpe::BPE; +use crate::models::unigram::Unigram; +#[cfg(feature = "training")] +use crate::models::unigram::UnigramTrainer; +use crate::models::wordlevel::WordLevel; +#[cfg(feature = "training")] +use crate::models::wordlevel::WordLevelTrainer; +use crate::models::wordpiece::WordPiece; +#[cfg(feature = "training")] +use crate::models::wordpiece::WordPieceTrainer; +#[cfg(feature = "training")] +use crate::{AddedToken, Trainer}; +use crate::{Model, Result, Token}; /// Wraps a vocab mapping (ID -> token) to a struct that will be serialized in order /// of token ID, smallest to largest. 
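`ahash` is no longer a dependency of the crate or of the bindings; code that previously imported `ahash::AHashMap` to build vocabularies for these models now goes through the public re-export, exactly as the node and python bindings above were updated to do. A minimal sketch of the downstream change (constructing via `collect()` so no extension trait is assumed):

```rust
// Before: use ahash::AHashMap;
use tokenizers::utils::AHashMap;

fn tiny_vocab() -> AHashMap<String, u32> {
    [("hello".to_string(), 0u32), ("world".to_string(), 1)]
        .into_iter()
        .collect()
}
```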
@@ -141,6 +151,7 @@ impl_enum_from!(BPE, ModelWrapper, BPE); impl_enum_from!(Unigram, ModelWrapper, Unigram); impl Model for ModelWrapper { + #[cfg(feature = "training")] type Trainer = TrainerWrapper; fn tokenize(&self, tokens: &str) -> Result> { @@ -197,6 +208,7 @@ impl Model for ModelWrapper { } } + #[cfg(feature = "training")] fn get_trainer(&self) -> Self::Trainer { match self { Self::WordLevel(t) => t.get_trainer().into(), @@ -224,6 +236,7 @@ impl ModelWrapper { } } +#[cfg(feature = "training")] #[derive(Clone, Serialize, Deserialize)] pub enum TrainerWrapper { BpeTrainer(BpeTrainer), @@ -232,6 +245,7 @@ pub enum TrainerWrapper { UnigramTrainer(UnigramTrainer), } +#[cfg(feature = "training")] impl Trainer for TrainerWrapper { type Model = ModelWrapper; @@ -280,9 +294,13 @@ impl Trainer for TrainerWrapper { } } +#[cfg(feature = "training")] impl_enum_from!(BpeTrainer, TrainerWrapper, BpeTrainer); +#[cfg(feature = "training")] impl_enum_from!(WordPieceTrainer, TrainerWrapper, WordPieceTrainer); +#[cfg(feature = "training")] impl_enum_from!(UnigramTrainer, TrainerWrapper, UnigramTrainer); +#[cfg(feature = "training")] impl_enum_from!(WordLevelTrainer, TrainerWrapper, WordLevelTrainer); #[cfg(test)] @@ -302,7 +320,7 @@ mod tests { #[test] fn incomplete_ordered_vocab() { let vocab_r: AHashMap = - AHashMap::from([(0, "Hi".to_string()), (2, "There".to_string())]); + IntoIterator::into_iter([(0u32, "Hi".to_string()), (2, "There".to_string())]).collect(); let ordered = OrderedVocabIter::new(&vocab_r); diff --git a/tokenizers/src/models/unigram/lattice.rs b/tokenizers/src/models/unigram/lattice.rs index 5464671f1a..93e75b1cb8 100644 --- a/tokenizers/src/models/unigram/lattice.rs +++ b/tokenizers/src/models/unigram/lattice.rs @@ -1,5 +1,7 @@ use dary_heap::QuaternaryHeap; +#[cfg(feature = "training")] use rand::distr::weighted::WeightedIndex; +#[cfg(feature = "training")] use rand::{prelude::*, rng}; use std::cell::RefCell; use std::cmp::{min, Ordering}; @@ -377,6 +379,7 @@ impl<'a> Lattice<'a> { freq * z } + #[cfg(feature = "training")] pub fn sample(&self, theta: f64) -> Vec { let len = self.len(); if len == 0 { @@ -422,6 +425,7 @@ impl<'a> Lattice<'a> { results } + #[cfg(feature = "training")] pub fn sample_token(&self, theta: f64) -> Vec { self.sample(theta) .iter() @@ -429,6 +433,7 @@ impl<'a> Lattice<'a> { .collect() } + #[cfg(feature = "training")] pub fn sample_nbest(&mut self, n: usize, theta: f64) -> Vec { let nbest_paths = self.nbest(n); if nbest_paths.is_empty() { diff --git a/tokenizers/src/models/unigram/mod.rs b/tokenizers/src/models/unigram/mod.rs index d408b5c8f0..f219cb80b6 100644 --- a/tokenizers/src/models/unigram/mod.rs +++ b/tokenizers/src/models/unigram/mod.rs @@ -2,9 +2,11 @@ mod lattice; mod model; mod serialization; +#[cfg(feature = "training")] mod trainer; mod trie; pub use lattice::*; pub use model::*; +#[cfg(feature = "training")] pub use trainer::*; diff --git a/tokenizers/src/models/unigram/model.rs b/tokenizers/src/models/unigram/model.rs index 3a9a6bddbd..488065a5f5 100644 --- a/tokenizers/src/models/unigram/model.rs +++ b/tokenizers/src/models/unigram/model.rs @@ -1,13 +1,14 @@ +#[cfg(feature = "training")] +use super::trainer::UnigramTrainer; use super::{ lattice::Lattice, - trainer::UnigramTrainer, trie::{Trie, TrieBuilder}, }; use crate::tokenizer::{Model, Result, Token}; use crate::utils::cache::{Cache, MAX_LENGTH}; use std::collections::HashMap; -use ahash::AHashMap; +use crate::utils::{AHashMap, HashMapExt}; use std::convert::TryInto; use 
std::fs::read_to_string; use std::path::{Path, PathBuf}; @@ -346,10 +347,19 @@ impl Unigram { fn encode_unoptimized(&self, sentence: &str) -> Result> { let mut lattice = Lattice::from(sentence, self.bos_id, self.eos_id); self.populate_nodes(&mut lattice); - let path = match (self.nbest_size, self.alpha) { - (Some(n), Some(alpha)) if n > 0 => lattice.sample_nbest(n, alpha), - (_, Some(alpha)) => lattice.sample(alpha), - _ => lattice.viterbi(), + let path = { + #[cfg(feature = "training")] + { + match (self.nbest_size, self.alpha) { + (Some(n), Some(alpha)) if n > 0 => lattice.sample_nbest(n, alpha), + (_, Some(alpha)) => lattice.sample(alpha), + _ => lattice.viterbi(), + } + } + #[cfg(not(feature = "training"))] + { + lattice.viterbi() + } }; if self.fuse_unk { let mut results = vec![]; @@ -430,6 +440,7 @@ impl<'a> Iterator for UnigramIterator<'a> { } impl Model for Unigram { + #[cfg(feature = "training")] type Trainer = UnigramTrainer; fn get_vocab(&self) -> HashMap { @@ -497,6 +508,7 @@ impl Model for Unigram { Ok(vec![fullpath]) } + #[cfg(feature = "training")] fn get_trainer(&self) -> Self::Trainer { UnigramTrainer::default() } diff --git a/tokenizers/src/models/unigram/trainer.rs b/tokenizers/src/models/unigram/trainer.rs index ff5ca9428a..575701f648 100644 --- a/tokenizers/src/models/unigram/trainer.rs +++ b/tokenizers/src/models/unigram/trainer.rs @@ -2,7 +2,7 @@ use crate::models::unigram::{lattice::Lattice, model::Unigram}; use crate::tokenizer::{AddedToken, Result, Trainer}; use crate::utils::parallelism::*; use crate::utils::progress::{ProgressBar, ProgressStyle}; -use ahash::{AHashMap, AHashSet}; +use crate::utils::{AHashMap, AHashSet, HashMapExt, HashSetExt}; use log::debug; use serde::{Deserialize, Serialize}; use std::cmp::Reverse; @@ -45,35 +45,108 @@ fn to_log_prob(pieces: &mut [SentencePiece]) { /// A `UnigramTrainer` can train a `Unigram` model from `word_counts`. #[non_exhaustive] -#[derive(Builder, Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct UnigramTrainer { - #[builder(default = "true")] pub show_progress: bool, - #[builder(default = "8000")] pub vocab_size: u32, - #[builder(default = "2")] pub n_sub_iterations: u32, - #[builder(default = "0.75")] pub shrinking_factor: f64, - #[builder(default = "vec![]")] pub special_tokens: Vec, - #[builder(default = "AHashSet::new()")] pub initial_alphabet: AHashSet, - - #[builder(default = "None")] pub unk_token: Option, - - #[builder(default = "16")] pub max_piece_length: usize, - #[builder(default = "1_000_000")] seed_size: usize, - #[builder(default = "AHashMap::new()")] words: AHashMap, } impl Default for UnigramTrainer { fn default() -> Self { - Self::builder().build().unwrap() + Self { + show_progress: true, + vocab_size: 8000, + n_sub_iterations: 2, + shrinking_factor: 0.75, + special_tokens: vec![], + initial_alphabet: AHashSet::new(), + unk_token: None, + max_piece_length: 16, + seed_size: 1_000_000, + words: AHashMap::new(), + } + } +} + +/// Builder for `UnigramTrainer`. 
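+///
+/// Hand-written replacement for the builder that `derive_builder` used to
+/// generate; a typical call chain (a sketch using the setters defined below) is
+/// `UnigramTrainerBuilder::default().vocab_size(8000).show_progress(false).build()`.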
+#[derive(Debug, Clone, Default)] +pub struct UnigramTrainerBuilder { + show_progress: Option, + vocab_size: Option, + n_sub_iterations: Option, + shrinking_factor: Option, + special_tokens: Option>, + initial_alphabet: Option>, + unk_token: Option>, + max_piece_length: Option, + seed_size: Option, +} + +impl UnigramTrainerBuilder { + pub fn show_progress(&mut self, show_progress: bool) -> &mut Self { + self.show_progress = Some(show_progress); + self + } + pub fn vocab_size(&mut self, vocab_size: u32) -> &mut Self { + self.vocab_size = Some(vocab_size); + self + } + pub fn n_sub_iterations(&mut self, n_sub_iterations: u32) -> &mut Self { + self.n_sub_iterations = Some(n_sub_iterations); + self + } + pub fn shrinking_factor(&mut self, shrinking_factor: f64) -> &mut Self { + self.shrinking_factor = Some(shrinking_factor); + self + } + pub fn special_tokens(&mut self, special_tokens: Vec) -> &mut Self { + self.special_tokens = Some(special_tokens); + self + } + pub fn initial_alphabet(&mut self, initial_alphabet: AHashSet) -> &mut Self { + self.initial_alphabet = Some(initial_alphabet); + self + } + pub fn unk_token(&mut self, unk_token: Option) -> &mut Self { + self.unk_token = Some(unk_token); + self + } + pub fn max_piece_length(&mut self, max_piece_length: usize) -> &mut Self { + self.max_piece_length = Some(max_piece_length); + self + } + pub fn seed_size(&mut self, seed_size: usize) -> &mut Self { + self.seed_size = Some(seed_size); + self + } + pub fn build(&self) -> Result { + let default = UnigramTrainer::default(); + Ok(UnigramTrainer { + show_progress: self.show_progress.unwrap_or(default.show_progress), + vocab_size: self.vocab_size.unwrap_or(default.vocab_size), + n_sub_iterations: self.n_sub_iterations.unwrap_or(default.n_sub_iterations), + shrinking_factor: self.shrinking_factor.unwrap_or(default.shrinking_factor), + special_tokens: self + .special_tokens + .clone() + .unwrap_or(default.special_tokens), + initial_alphabet: self + .initial_alphabet + .clone() + .unwrap_or(default.initial_alphabet), + unk_token: self.unk_token.clone().unwrap_or(default.unk_token), + max_piece_length: self.max_piece_length.unwrap_or(default.max_piece_length), + seed_size: self.seed_size.unwrap_or(default.seed_size), + words: AHashMap::new(), + }) } } diff --git a/tokenizers/src/models/unigram/trie.rs b/tokenizers/src/models/unigram/trie.rs index 7c7149d00a..894cae3cf7 100644 --- a/tokenizers/src/models/unigram/trie.rs +++ b/tokenizers/src/models/unigram/trie.rs @@ -1,4 +1,4 @@ -use ahash::AHashMap; +use crate::utils::{AHashMap, HashMapExt}; use std::hash::Hash; #[derive(Default)] diff --git a/tokenizers/src/models/wordlevel/mod.rs b/tokenizers/src/models/wordlevel/mod.rs index 94e7c86b4f..bd62f1b4e8 100644 --- a/tokenizers/src/models/wordlevel/mod.rs +++ b/tokenizers/src/models/wordlevel/mod.rs @@ -1,6 +1,6 @@ use super::OrderedVocabIter; use crate::tokenizer::{Model, Result, Token}; -use ahash::AHashMap; +use crate::utils::{AHashMap, HashMapExt}; use serde_json::Value; use std::collections::HashMap; use std::fs::File; @@ -8,9 +8,11 @@ use std::io::{BufReader, Read, Write}; use std::path::{Path, PathBuf}; mod serialization; +#[cfg(feature = "training")] mod trainer; // Re-export +#[cfg(feature = "training")] pub use trainer::*; type Vocab = AHashMap; @@ -157,6 +159,7 @@ impl Default for WordLevel { } impl Model for WordLevel { + #[cfg(feature = "training")] type Trainer = WordLevelTrainer; fn tokenize(&self, token: &str) -> Result> { @@ -211,6 +214,7 @@ impl Model for WordLevel { 
Ok(vec![vocab_path]) } + #[cfg(feature = "training")] fn get_trainer(&self) -> Self::Trainer { WordLevelTrainer::default() } diff --git a/tokenizers/src/models/wordlevel/serialization.rs b/tokenizers/src/models/wordlevel/serialization.rs index 1cc79339e0..281e217ea9 100644 --- a/tokenizers/src/models/wordlevel/serialization.rs +++ b/tokenizers/src/models/wordlevel/serialization.rs @@ -1,5 +1,5 @@ use super::{super::OrderedVocabIter, WordLevel, WordLevelBuilder}; -use ahash::AHashSet; +use crate::utils::AHashSet; use serde::{ de::{MapAccess, Visitor}, ser::SerializeStruct, diff --git a/tokenizers/src/models/wordlevel/trainer.rs b/tokenizers/src/models/wordlevel/trainer.rs index bf980b0d32..3e2d39d484 100644 --- a/tokenizers/src/models/wordlevel/trainer.rs +++ b/tokenizers/src/models/wordlevel/trainer.rs @@ -1,33 +1,75 @@ use super::WordLevel; use crate::utils::parallelism::*; +use crate::utils::{AHashMap, HashMapExt}; use crate::{AddedToken, Result, Trainer}; -use ahash::AHashMap; use serde::{Deserialize, Serialize}; use std::cmp::Ordering; #[non_exhaustive] -#[derive(Debug, Clone, Builder, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct WordLevelTrainer { /// The minimum frequency a word must have to be part of the vocabulary - #[builder(default = "0")] pub min_frequency: u64, /// The target vocabulary size - #[builder(default = "30_000")] pub vocab_size: usize, /// Whether to show progress while training - #[builder(default = "true")] pub show_progress: bool, /// A list of special tokens that the model should know of - #[builder(default)] pub special_tokens: Vec, - #[builder(default, private)] words: AHashMap, } impl Default for WordLevelTrainer { fn default() -> Self { - Self::builder().build().unwrap() + Self { + min_frequency: 0, + vocab_size: 30_000, + show_progress: true, + special_tokens: vec![], + words: AHashMap::new(), + } + } +} + +/// Builder for `WordLevelTrainer`. 
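+///
+/// Mirrors the builder API previously generated by `derive_builder`; e.g. (a
+/// sketch using the setters below)
+/// `WordLevelTrainerBuilder::default().vocab_size(30_000).min_frequency(2).build()`.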
+#[derive(Debug, Clone, Default)] +pub struct WordLevelTrainerBuilder { + min_frequency: Option, + vocab_size: Option, + show_progress: Option, + special_tokens: Option>, +} + +impl WordLevelTrainerBuilder { + pub fn min_frequency(&mut self, min_frequency: u64) -> &mut Self { + self.min_frequency = Some(min_frequency); + self + } + pub fn vocab_size(&mut self, vocab_size: usize) -> &mut Self { + self.vocab_size = Some(vocab_size); + self + } + pub fn show_progress(&mut self, show_progress: bool) -> &mut Self { + self.show_progress = Some(show_progress); + self + } + pub fn special_tokens(&mut self, special_tokens: Vec) -> &mut Self { + self.special_tokens = Some(special_tokens); + self + } + pub fn build(&self) -> Result { + let default = WordLevelTrainer::default(); + Ok(WordLevelTrainer { + min_frequency: self.min_frequency.unwrap_or(default.min_frequency), + vocab_size: self.vocab_size.unwrap_or(default.vocab_size), + show_progress: self.show_progress.unwrap_or(default.show_progress), + special_tokens: self + .special_tokens + .clone() + .unwrap_or(default.special_tokens), + words: AHashMap::new(), + }) } } diff --git a/tokenizers/src/models/wordpiece/mod.rs b/tokenizers/src/models/wordpiece/mod.rs index 61fa44f071..790094cb9d 100644 --- a/tokenizers/src/models/wordpiece/mod.rs +++ b/tokenizers/src/models/wordpiece/mod.rs @@ -3,7 +3,7 @@ use crate::models::bpe::BPE; use crate::tokenizer::{Model, Result, Token}; -use ahash::AHashMap; +use crate::utils::{AHashMap, HashMapExt}; use std::collections::HashMap; use std::{ borrow::Cow, @@ -14,7 +14,9 @@ use std::{ }; mod serialization; +#[cfg(feature = "training")] mod trainer; +#[cfg(feature = "training")] pub use trainer::*; #[derive(thiserror::Error, Debug)] @@ -211,6 +213,7 @@ impl WordPiece { } impl Model for WordPiece { + #[cfg(feature = "training")] type Trainer = WordPieceTrainer; fn get_vocab(&self) -> HashMap { @@ -313,6 +316,7 @@ impl Model for WordPiece { Ok(vec![vocab_path]) } + #[cfg(feature = "training")] fn get_trainer(&self) -> Self::Trainer { WordPieceTrainer::builder().build() } diff --git a/tokenizers/src/models/wordpiece/serialization.rs b/tokenizers/src/models/wordpiece/serialization.rs index 7ba496d63c..f832b6753a 100644 --- a/tokenizers/src/models/wordpiece/serialization.rs +++ b/tokenizers/src/models/wordpiece/serialization.rs @@ -1,5 +1,5 @@ use super::{super::OrderedVocabIter, WordPiece, WordPieceBuilder}; -use ahash::{AHashMap, AHashSet}; +use crate::utils::{AHashMap, AHashSet}; use serde::{ de::{MapAccess, Visitor}, ser::SerializeStruct, diff --git a/tokenizers/src/models/wordpiece/trainer.rs b/tokenizers/src/models/wordpiece/trainer.rs index 29d4561521..cba8590dc7 100644 --- a/tokenizers/src/models/wordpiece/trainer.rs +++ b/tokenizers/src/models/wordpiece/trainer.rs @@ -3,7 +3,7 @@ use std::collections::HashSet; use super::WordPiece; use crate::models::bpe::{BpeTrainer, BpeTrainerBuilder, BPE}; use crate::tokenizer::{AddedToken, Result, Trainer}; -use ahash::AHashSet; +use crate::utils::{AHashSet, HashSetExt}; use serde::{Deserialize, Serialize}; /// A `WordPieceTrainerBuilder` can be used to create a `WordPieceTrainer` with a custom diff --git a/tokenizers/src/normalizers/bert.rs b/tokenizers/src/normalizers/bert.rs index 90d982c680..81d8adb89e 100644 --- a/tokenizers/src/normalizers/bert.rs +++ b/tokenizers/src/normalizers/bert.rs @@ -108,7 +108,12 @@ impl BertNormalizer { } fn do_strip_accents(&self, normalized: &mut NormalizedString) { + #[cfg(feature = "unicode-normalization")] normalized.nfd().filter(|c| 
!c.is_mark_nonspacing()); + #[cfg(not(feature = "unicode-normalization"))] + { + let _ = normalized; + } } fn do_lowercase(&self, normalized: &mut NormalizedString) { diff --git a/tokenizers/src/normalizers/byte_level.rs b/tokenizers/src/normalizers/byte_level.rs index 41fd416156..f23d0cefb9 100644 --- a/tokenizers/src/normalizers/byte_level.rs +++ b/tokenizers/src/normalizers/byte_level.rs @@ -1,7 +1,7 @@ use crate::processors::byte_level::bytes_char; use crate::tokenizer::{NormalizedString, Normalizer, Result}; use crate::utils::macro_rules_attribute; -use ahash::{AHashMap, AHashSet}; +use crate::utils::{AHashMap, AHashSet}; use std::sync::LazyLock; #[derive(Clone, Debug)] diff --git a/tokenizers/src/normalizers/mod.rs b/tokenizers/src/normalizers/mod.rs index f400f13da9..d0a0989d4e 100644 --- a/tokenizers/src/normalizers/mod.rs +++ b/tokenizers/src/normalizers/mod.rs @@ -1,5 +1,6 @@ pub mod bert; pub mod byte_level; +#[cfg(feature = "spm")] pub mod precompiled; pub mod prepend; pub mod replace; @@ -8,11 +9,14 @@ pub mod unicode; pub mod utils; pub use crate::normalizers::bert::BertNormalizer; pub use crate::normalizers::byte_level::ByteLevel; +#[cfg(feature = "spm")] pub use crate::normalizers::precompiled::Precompiled; pub use crate::normalizers::prepend::Prepend; pub use crate::normalizers::replace::Replace; pub use crate::normalizers::strip::{Strip, StripAccents}; -pub use crate::normalizers::unicode::{Nmt, NFC, NFD, NFKC, NFKD}; +pub use crate::normalizers::unicode::Nmt; +#[cfg(feature = "unicode-normalization")] +pub use crate::normalizers::unicode::{NFC, NFD, NFKC, NFKD}; pub use crate::normalizers::utils::{Lowercase, Sequence}; use serde::{Deserialize, Deserializer, Serialize}; @@ -25,13 +29,18 @@ pub enum NormalizerWrapper { BertNormalizer(BertNormalizer), StripNormalizer(Strip), StripAccents(StripAccents), + #[cfg(feature = "unicode-normalization")] NFC(NFC), + #[cfg(feature = "unicode-normalization")] NFD(NFD), + #[cfg(feature = "unicode-normalization")] NFKC(NFKC), + #[cfg(feature = "unicode-normalization")] NFKD(NFKD), Sequence(Sequence), Lowercase(Lowercase), Nmt(Nmt), + #[cfg(feature = "spm")] Precompiled(Precompiled), Replace(Replace), Prepend(Prepend), @@ -81,14 +90,17 @@ impl<'de> Deserialize<'de> for NormalizerWrapper { BertNormalizer(BertNormalizer), StripNormalizer(Strip), StripAccents(StripAccents), + #[cfg(feature = "unicode-normalization")] NFC(NFC), + #[cfg(feature = "unicode-normalization")] NFD(NFD), + #[cfg(feature = "unicode-normalization")] NFKC(NFKC), + #[cfg(feature = "unicode-normalization")] NFKD(NFKD), Sequence(Sequence), Lowercase(Lowercase), Nmt(Nmt), - Precompiled(Precompiled), Replace(Replace), Prepend(Prepend), ByteLevel(ByteLevel), @@ -114,18 +126,62 @@ impl<'de> Deserialize<'de> for NormalizerWrapper { EnumType::StripAccents => NormalizerWrapper::StripAccents( serde_json::from_value(values).map_err(serde::de::Error::custom)?, ), - EnumType::NFC => NormalizerWrapper::NFC( - serde_json::from_value(values).map_err(serde::de::Error::custom)?, - ), - EnumType::NFD => NormalizerWrapper::NFD( - serde_json::from_value(values).map_err(serde::de::Error::custom)?, - ), - EnumType::NFKC => NormalizerWrapper::NFKC( - serde_json::from_value(values).map_err(serde::de::Error::custom)?, - ), - EnumType::NFKD => NormalizerWrapper::NFKD( - serde_json::from_value(values).map_err(serde::de::Error::custom)?, - ), + EnumType::NFC => { + #[cfg(feature = "unicode-normalization")] + { + NormalizerWrapper::NFC( + 
serde_json::from_value(values).map_err(serde::de::Error::custom)?, + ) + } + #[cfg(not(feature = "unicode-normalization"))] + { + return Err(serde::de::Error::custom( + "NFC normalizer requires the `unicode-normalization` feature", + )); + } + } + EnumType::NFD => { + #[cfg(feature = "unicode-normalization")] + { + NormalizerWrapper::NFD( + serde_json::from_value(values).map_err(serde::de::Error::custom)?, + ) + } + #[cfg(not(feature = "unicode-normalization"))] + { + return Err(serde::de::Error::custom( + "NFD normalizer requires the `unicode-normalization` feature", + )); + } + } + EnumType::NFKC => { + #[cfg(feature = "unicode-normalization")] + { + NormalizerWrapper::NFKC( + serde_json::from_value(values).map_err(serde::de::Error::custom)?, + ) + } + #[cfg(not(feature = "unicode-normalization"))] + { + return Err(serde::de::Error::custom( + "NFKC normalizer requires the `unicode-normalization` feature", + )); + } + } + EnumType::NFKD => { + #[cfg(feature = "unicode-normalization")] + { + NormalizerWrapper::NFKD( + serde_json::from_value(values).map_err(serde::de::Error::custom)?, + ) + } + #[cfg(not(feature = "unicode-normalization"))] + { + return Err(serde::de::Error::custom( + "NFKD normalizer requires the `unicode-normalization` feature", + )); + } + } EnumType::Sequence => NormalizerWrapper::Sequence( serde_json::from_value(values).map_err(serde::de::Error::custom)?, ), @@ -135,13 +191,24 @@ impl<'de> Deserialize<'de> for NormalizerWrapper { EnumType::Nmt => NormalizerWrapper::Nmt( serde_json::from_value(values).map_err(serde::de::Error::custom)?, ), - EnumType::Precompiled => NormalizerWrapper::Precompiled( - serde_json::from_str( - &serde_json::to_string(&values).expect("Can reserialize precompiled"), - ) - // .map_err(serde::de::Error::custom) - .expect("Precompiled"), - ), + EnumType::Precompiled => { + #[cfg(feature = "spm")] + { + NormalizerWrapper::Precompiled( + serde_json::from_str( + &serde_json::to_string(&values) + .expect("Can reserialize precompiled"), + ) + .expect("Precompiled"), + ) + } + #[cfg(not(feature = "spm"))] + { + return Err(serde::de::Error::custom( + "Precompiled normalizer requires the `spm` feature", + )); + } + } EnumType::Replace => NormalizerWrapper::Replace( serde_json::from_value(values).map_err(serde::de::Error::custom)?, ), @@ -164,14 +231,17 @@ impl<'de> Deserialize<'de> for NormalizerWrapper { NormalizerWrapper::StripNormalizer(bpe) } NormalizerUntagged::StripAccents(bpe) => NormalizerWrapper::StripAccents(bpe), + #[cfg(feature = "unicode-normalization")] NormalizerUntagged::NFC(bpe) => NormalizerWrapper::NFC(bpe), + #[cfg(feature = "unicode-normalization")] NormalizerUntagged::NFD(bpe) => NormalizerWrapper::NFD(bpe), + #[cfg(feature = "unicode-normalization")] NormalizerUntagged::NFKC(bpe) => NormalizerWrapper::NFKC(bpe), + #[cfg(feature = "unicode-normalization")] NormalizerUntagged::NFKD(bpe) => NormalizerWrapper::NFKD(bpe), NormalizerUntagged::Sequence(seq) => NormalizerWrapper::Sequence(seq), NormalizerUntagged::Lowercase(bpe) => NormalizerWrapper::Lowercase(bpe), NormalizerUntagged::Nmt(bpe) => NormalizerWrapper::Nmt(bpe), - NormalizerUntagged::Precompiled(bpe) => NormalizerWrapper::Precompiled(bpe), NormalizerUntagged::Replace(bpe) => NormalizerWrapper::Replace(bpe), NormalizerUntagged::Prepend(bpe) => NormalizerWrapper::Prepend(bpe), NormalizerUntagged::ByteLevel(bpe) => NormalizerWrapper::ByteLevel(bpe), @@ -187,13 +257,18 @@ impl Normalizer for NormalizerWrapper { Self::BertNormalizer(bn) => bn.normalize(normalized), 
Self::StripNormalizer(sn) => sn.normalize(normalized), Self::StripAccents(sn) => sn.normalize(normalized), + #[cfg(feature = "unicode-normalization")] Self::NFC(nfc) => nfc.normalize(normalized), + #[cfg(feature = "unicode-normalization")] Self::NFD(nfd) => nfd.normalize(normalized), + #[cfg(feature = "unicode-normalization")] Self::NFKC(nfkc) => nfkc.normalize(normalized), + #[cfg(feature = "unicode-normalization")] Self::NFKD(nfkd) => nfkd.normalize(normalized), Self::Sequence(sequence) => sequence.normalize(normalized), Self::Lowercase(lc) => lc.normalize(normalized), Self::Nmt(lc) => lc.normalize(normalized), + #[cfg(feature = "spm")] Self::Precompiled(lc) => lc.normalize(normalized), Self::Replace(lc) => lc.normalize(normalized), Self::Prepend(lc) => lc.normalize(normalized), @@ -203,15 +278,20 @@ impl Normalizer for NormalizerWrapper { } impl_enum_from!(BertNormalizer, NormalizerWrapper, BertNormalizer); +#[cfg(feature = "unicode-normalization")] impl_enum_from!(NFKD, NormalizerWrapper, NFKD); +#[cfg(feature = "unicode-normalization")] impl_enum_from!(NFKC, NormalizerWrapper, NFKC); +#[cfg(feature = "unicode-normalization")] impl_enum_from!(NFC, NormalizerWrapper, NFC); +#[cfg(feature = "unicode-normalization")] impl_enum_from!(NFD, NormalizerWrapper, NFD); impl_enum_from!(Strip, NormalizerWrapper, StripNormalizer); impl_enum_from!(StripAccents, NormalizerWrapper, StripAccents); impl_enum_from!(Sequence, NormalizerWrapper, Sequence); impl_enum_from!(Lowercase, NormalizerWrapper, Lowercase); impl_enum_from!(Nmt, NormalizerWrapper, Nmt); +#[cfg(feature = "spm")] impl_enum_from!(Precompiled, NormalizerWrapper, Precompiled); impl_enum_from!(Replace, NormalizerWrapper, Replace); impl_enum_from!(Prepend, NormalizerWrapper, Prepend); diff --git a/tokenizers/src/normalizers/replace.rs b/tokenizers/src/normalizers/replace.rs index 5657574830..f9730802a8 100644 --- a/tokenizers/src/normalizers/replace.rs +++ b/tokenizers/src/normalizers/replace.rs @@ -67,7 +67,7 @@ impl Replace { pub fn new, C: Into>(pattern: I, content: C) -> Result { let pattern: ReplacePattern = pattern.into(); let regex = match &pattern { - ReplacePattern::String(s) => SysRegex::new(®ex::escape(s))?, + ReplacePattern::String(s) => SysRegex::new(&crate::utils::regex_escape(s))?, ReplacePattern::Regex(r) => SysRegex::new(r)?, }; diff --git a/tokenizers/src/normalizers/strip.rs b/tokenizers/src/normalizers/strip.rs index 19f5ff314d..f618be4d7e 100644 --- a/tokenizers/src/normalizers/strip.rs +++ b/tokenizers/src/normalizers/strip.rs @@ -1,8 +1,15 @@ use crate::tokenizer::{NormalizedString, Normalizer, Result}; use crate::utils::macro_rules_attribute; use serde::{Deserialize, Serialize}; +#[cfg(feature = "unicode-normalization")] use unicode_normalization_alignments::char::is_combining_mark; +#[cfg(not(feature = "unicode-normalization"))] +fn is_combining_mark(_c: char) -> bool { + // Without unicode-normalization feature, accent stripping is a no-op. 
diff --git a/tokenizers/src/normalizers/unicode.rs b/tokenizers/src/normalizers/unicode.rs
index 502b4239b4..2ec4be0634 100644
--- a/tokenizers/src/normalizers/unicode.rs
+++ b/tokenizers/src/normalizers/unicode.rs
@@ -1,46 +1,54 @@
 use crate::tokenizer::{NormalizedString, Normalizer, Result};
 use crate::utils::macro_rules_attribute;
 
-#[derive(Default, Copy, Clone, Debug)]
-#[macro_rules_attribute(impl_serde_type!)]
-pub struct NFD;
-impl Normalizer for NFD {
-    fn normalize(&self, normalized: &mut NormalizedString) -> Result<()> {
-        normalized.nfd();
-        Ok(())
+#[cfg(feature = "unicode-normalization")]
+mod nf_normalizers {
+    use super::*;
+
+    #[derive(Default, Copy, Clone, Debug)]
+    #[macro_rules_attribute(impl_serde_type!)]
+    pub struct NFD;
+    impl Normalizer for NFD {
+        fn normalize(&self, normalized: &mut NormalizedString) -> Result<()> {
+            normalized.nfd();
+            Ok(())
+        }
     }
-}
 
-#[derive(Default, Copy, Clone, Debug)]
-#[macro_rules_attribute(impl_serde_type!)]
-pub struct NFKD;
-impl Normalizer for NFKD {
-    fn normalize(&self, normalized: &mut NormalizedString) -> Result<()> {
-        normalized.nfkd();
-        Ok(())
+    #[derive(Default, Copy, Clone, Debug)]
+    #[macro_rules_attribute(impl_serde_type!)]
+    pub struct NFKD;
+    impl Normalizer for NFKD {
+        fn normalize(&self, normalized: &mut NormalizedString) -> Result<()> {
+            normalized.nfkd();
+            Ok(())
+        }
     }
-}
 
-#[derive(Default, Copy, Clone, Debug)]
-#[macro_rules_attribute(impl_serde_type!)]
-pub struct NFC;
-impl Normalizer for NFC {
-    fn normalize(&self, normalized: &mut NormalizedString) -> Result<()> {
-        normalized.nfc();
-        Ok(())
+    #[derive(Default, Copy, Clone, Debug)]
+    #[macro_rules_attribute(impl_serde_type!)]
+    pub struct NFC;
+    impl Normalizer for NFC {
+        fn normalize(&self, normalized: &mut NormalizedString) -> Result<()> {
+            normalized.nfc();
+            Ok(())
+        }
     }
-}
 
-#[derive(Default, Copy, Clone, Debug)]
-#[macro_rules_attribute(impl_serde_type!)]
-pub struct NFKC;
-impl Normalizer for NFKC {
-    fn normalize(&self, normalized: &mut NormalizedString) -> Result<()> {
-        normalized.nfkc();
-        Ok(())
+    #[derive(Default, Copy, Clone, Debug)]
+    #[macro_rules_attribute(impl_serde_type!)]
+    pub struct NFKC;
+    impl Normalizer for NFKC {
+        fn normalize(&self, normalized: &mut NormalizedString) -> Result<()> {
+            normalized.nfkc();
+            Ok(())
+        }
     }
 }
 
+#[cfg(feature = "unicode-normalization")]
+pub use nf_normalizers::*;
+
 fn do_nmt(normalized: &mut NormalizedString) {
     // Ascii Control characters
     normalized
diff --git a/tokenizers/src/pre_tokenizers/byte_level.rs b/tokenizers/src/pre_tokenizers/byte_level.rs
index 8bc0f30af0..13ca230625 100644
--- a/tokenizers/src/pre_tokenizers/byte_level.rs
+++ b/tokenizers/src/pre_tokenizers/byte_level.rs
@@ -1,4 +1,6 @@
-use ahash::{AHashMap, AHashSet};
+#[cfg(test)]
+use crate::utils::HashMapExt;
+use crate::utils::{AHashMap, AHashSet};
 use std::sync::LazyLock;
 
 use crate::utils::SysRegex;
diff --git a/tokenizers/src/pre_tokenizers/metaspace.rs b/tokenizers/src/pre_tokenizers/metaspace.rs
index d821f11841..c781fa34ec 100644
--- a/tokenizers/src/pre_tokenizers/metaspace.rs
+++ b/tokenizers/src/pre_tokenizers/metaspace.rs
@@ -174,8 +174,6 @@ impl Decoder for Metaspace {
 
 #[cfg(test)]
 mod tests {
-    use regex::Regex;
-
     use super::*;
     use crate::{OffsetReferential, OffsetType};
 
@@ -278,7 +276,7 @@ mod tests {
         let pretok = Metaspace::new('▁', PrependScheme::First, false);
         let mut pretokenized = PreTokenizedString::from("Hey my friend <s>how▁are you");
-        let re_ref = Regex::new(r"(<s>)").unwrap();
+        let re_ref = crate::utils::SysRegex::new(r"(<s>)").unwrap();
         pretokenized
             .split(|_, sequence| sequence.split(&re_ref, SplitDelimiterBehavior::Isolated))
             .expect("Bad split");
diff --git a/tokenizers/src/pre_tokenizers/split.rs b/tokenizers/src/pre_tokenizers/split.rs
index 5f7362f71e..c176901ba1 100644
--- a/tokenizers/src/pre_tokenizers/split.rs
+++ b/tokenizers/src/pre_tokenizers/split.rs
@@ -80,7 +80,7 @@ impl Split {
     ) -> Result<Self> {
         let pattern: SplitPattern = pattern.into();
         let regex = match &pattern {
-            SplitPattern::String(s) => SysRegex::new(&regex::escape(s))?,
+            SplitPattern::String(s) => SysRegex::new(&crate::utils::regex_escape(s))?,
             SplitPattern::Regex(r) => SysRegex::new(r)?,
         };
diff --git a/tokenizers/src/pre_tokenizers/whitespace.rs b/tokenizers/src/pre_tokenizers/whitespace.rs
index 20cfb65193..15bb115c29 100644
--- a/tokenizers/src/pre_tokenizers/whitespace.rs
+++ b/tokenizers/src/pre_tokenizers/whitespace.rs
@@ -1,11 +1,10 @@
 use std::sync::LazyLock;
 
-use regex::Regex;
-
 use crate::tokenizer::{
     pattern::Invert, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior,
 };
 use crate::utils::macro_rules_attribute;
+use crate::utils::SysRegex;
 
 #[derive(Clone, Debug, PartialEq, Eq)]
 #[macro_rules_attribute(impl_serde_type!)]
@@ -19,8 +18,8 @@ impl Default for Whitespace {
 
 impl PreTokenizer for Whitespace {
     fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
-        static RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\w+|[^\w\s]+").unwrap());
-        let re_ref: &Regex = &RE;
+        static RE: LazyLock<SysRegex> = LazyLock::new(|| SysRegex::new(r"\w+|[^\w\s]+").unwrap());
+        let re_ref: &SysRegex = &RE;
 
         pretokenized.split(|_, normalized| {
             normalized.split(Invert(re_ref), SplitDelimiterBehavior::Removed)
diff --git a/tokenizers/src/processors/bert.rs b/tokenizers/src/processors/bert.rs
index a1cab8abd1..7f2907d7ba 100644
--- a/tokenizers/src/processors/bert.rs
+++ b/tokenizers/src/processors/bert.rs
@@ -1,5 +1,5 @@
 use crate::tokenizer::{Encoding, PostProcessor, Result};
-use ahash::AHashMap;
+use crate::utils::AHashMap;
 use serde::{Deserialize, Serialize};
 use std::iter::FromIterator;
 
diff --git a/tokenizers/src/processors/roberta.rs b/tokenizers/src/processors/roberta.rs
index f2a47a9d38..5e7a3ca0be 100644
--- a/tokenizers/src/processors/roberta.rs
+++ b/tokenizers/src/processors/roberta.rs
@@ -1,6 +1,6 @@
 use crate::processors::byte_level::process_offsets;
 use crate::tokenizer::{Encoding, PostProcessor, Result};
-use ahash::AHashMap;
+use crate::utils::AHashMap;
 use serde::{Deserialize, Serialize};
 use std::iter::FromIterator;
 
diff --git a/tokenizers/src/processors/sequence.rs b/tokenizers/src/processors/sequence.rs
index f44cf54ac8..8043a455ae 100644
--- a/tokenizers/src/processors/sequence.rs
+++ b/tokenizers/src/processors/sequence.rs
@@ -73,7 +73,7 @@ mod tests {
     use super::*;
     use crate::processors::{ByteLevel, PostProcessorWrapper};
     use crate::tokenizer::{Encoding, PostProcessor};
-    use ahash::AHashMap;
+    use crate::utils::{AHashMap, HashMapExt};
    use std::iter::FromIterator;
 
     #[test]
diff --git a/tokenizers/src/processors/template.rs b/tokenizers/src/processors/template.rs
index 50fac99dfc..71377672dc 100644
--- a/tokenizers/src/processors/template.rs
+++ b/tokenizers/src/processors/template.rs
@@ -56,9 +56,8 @@
 //!
 //! [`TemplateProcessing`]: struct.TemplateProcessing.html
 //!
+use crate::utils::{AHashMap, AHashSet, HashMapExt};
 use crate::{Encoding, PostProcessor, Result};
-use ahash::{AHashMap, AHashSet};
-use itertools::Itertools;
 use serde::{Deserialize, Serialize};
 use std::convert::{TryFrom, TryInto};
 use std::result::Result as StdResult;
@@ -333,21 +332,15 @@ impl From<Vec<SpecialToken>> for Tokens {
 /// .unwrap();
 /// ```
 ///
-#[derive(Debug, Clone, PartialEq, Builder, Serialize, Deserialize, Eq)]
+#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)]
 #[serde(tag = "type", from = "TemplateProcessingDeserializer")]
-#[builder(build_fn(validate = "Self::validate"))]
 pub struct TemplateProcessing {
-    #[builder(try_setter, default = "\"$0\".try_into().unwrap()")]
     pub single: Template,
-    #[builder(try_setter, default = "\"$A:0 $B:1\".try_into().unwrap()")]
     pair: Template,
-    #[builder(setter(skip), default = "self.default_added(true)")]
     #[serde(skip)]
     added_single: usize,
-    #[builder(setter(skip), default = "self.default_added(false)")]
     #[serde(skip)]
     added_pair: usize,
-    #[builder(setter(into), default)]
     special_tokens: Tokens,
 }
 
@@ -405,7 +398,13 @@ impl TemplateProcessing {
 
 impl From<&str> for TemplateProcessingBuilderError {
     fn from(e: &str) -> Self {
-        e.to_string().into()
+        TemplateProcessingBuilderError(e.to_string())
+    }
+}
+
+impl From<String> for TemplateProcessingBuilderError {
+    fn from(e: String) -> Self {
+        TemplateProcessingBuilderError(e)
     }
 }
 
@@ -439,33 +438,66 @@ impl From<TemplateProcessingDeserializer> for TemplateProcessing {
     }
 }
 
-/// Count the number of added tokens in the given template
-fn count_added(container: &Template, special_tokens: Option<&Tokens>) -> usize {
-    container
-        .0
-        .iter()
-        .map(|p| match p {
-            Piece::Sequence { .. } => 0,
-            Piece::SpecialToken { id, .. } => {
-                special_tokens.map_or(0, |spt| spt.0.get(id).map_or(0, |s| s.ids.len()))
-            }
-        })
-        .sum()
+/// Error type for `TemplateProcessingBuilder`.
+#[derive(Debug, Clone)]
+pub struct TemplateProcessingBuilderError(String);
+
+impl std::fmt::Display for TemplateProcessingBuilderError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}", self.0)
+    }
+}
+
+impl std::error::Error for TemplateProcessingBuilderError {}
+
+/// Builder for `TemplateProcessing`.
+#[derive(Debug, Clone, Default)]
+pub struct TemplateProcessingBuilder {
+    single: Option