"""
How text becomes numbers -- the compression algorithm hiding inside every LLM.
Byte-Pair Encoding learns a vocabulary by iteratively merging the most frequent
adjacent token pairs, then encodes new text by replaying those merges in priority order.
"""
# Reference: Philip Gage, "A New Algorithm for Data Compression" (1994).
# GPT-2's byte-level BPE variant (Radford et al., 2019) starts from raw bytes
# rather than characters -- that's the version implemented here.
# === TRADEOFFS ===
# + Data-driven vocabulary: adapts to any language or domain without linguistic rules
# + Compresses common patterns into single tokens, reducing sequence lengths
# + Byte-level fallback guarantees encoding of any input (no unknown tokens)
# - Merge order is corpus-dependent: different data yields different tokenizations
# - Naive encoding costs O(n * num_merges): every learned merge takes a full pass over the sequence
# - Subword boundaries rarely align with morpheme boundaries (linguistic blindness)
# WHEN TO USE: Preprocessing text for any neural language model. BPE is the
# standard tokenizer for GPT-family models.
# WHEN NOT TO: Character-level tasks where token boundaries matter (e.g., spelling),
# or languages where word segmentation is well-solved by existing tools.
from __future__ import annotations
import os
import random
import urllib.request
from collections import Counter
random.seed(42) # repo convention; BPE itself is fully deterministic
# === CONSTANTS ===
NUM_MERGES = 256 # Final vocab = 256 byte tokens + 256 merges = 512 tokens.
# Signpost: production tokenizers (GPT-2, GPT-4) use 50K+ merges trained on
# hundreds of gigabytes. 256 merges on a small names corpus is a toy, but the algorithm is identical.
DATA_URL = "https://raw.githubusercontent.com/karpathy/makemore/master/names.txt"
DATA_FILE = "names.txt"
# === DATA LOADING ===
def load_data(url: str, filename: str) -> bytes:
"""Download dataset if not cached, return raw bytes."""
if not os.path.exists(filename):
print(f"Downloading {filename}...")
urllib.request.urlretrieve(url, filename)
with open(filename, "rb") as f:
return f.read()
# === BPE TRAINING ===
def get_pair_counts(token_ids: list[int]) -> Counter:
"""Count frequency of every adjacent token pair.
For sequence s = [s_0, s_1, ..., s_n], we count all (s_i, s_{i+1}) pairs.
Example: [a, b, c, b, c] -> {(a,b): 1, (b,c): 2, (c,b): 1}.
This is the core statistic BPE uses to decide what to merge next.
"""
# zip(ids, ids[1:]) pairs each element with its right neighbor -- O(n).
return Counter(zip(token_ids, token_ids[1:]))
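# Illustrative check (doctest-style sketch, values mirror the docstring
# example with a=97, b=98, c=99):
#   >>> get_pair_counts([97, 98, 99, 98, 99])
#   Counter({(98, 99): 2, (97, 98): 1, (99, 98): 1})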
def apply_merge(token_ids: list[int], pair: tuple[int, int], new_id: int) -> list[int]:
"""Replace every occurrence of `pair` with `new_id` in a single left-to-right pass.
Overlapping pairs resolve left-to-right: in [a, a, a] merging (a,a) produces
[new, a], not [a, new]. This matches the standard BPE convention and ensures
the merge operation is deterministic regardless of pair overlap patterns.
"""
# Signpost: this O(n) scan runs once per merge, giving O(n * M) total training
# cost for M merges. Production implementations (SentencePiece, tiktoken) use
# priority queues for O(n log n) total, but the output is identical.
merged = []
i = 0
while i < len(token_ids):
if i < len(token_ids) - 1 and (token_ids[i], token_ids[i + 1]) == pair:
merged.append(new_id)
i += 2 # consumed both tokens in the pair
else:
merged.append(token_ids[i])
i += 1
return merged
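# Worked trace of the left-to-right overlap rule (toy values, byte 97 = "a"):
#   >>> apply_merge([97, 97, 97], (97, 97), 256)
#   [256, 97]
# The first two 97s merge; the scan resumes after them, so the third 97
# cannot pair with an already-consumed token.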
def train_bpe(
token_ids: list[int], num_merges: int
) -> list[tuple[tuple[int, int], int]]:
"""Learn BPE merge rules by greedily merging the most frequent adjacent pair.
Each merge absorbs the single most redundant pair in the corpus -- a greedy
compression step that naturally discovers morphological units ("an" + "a",
"el" + "la") without any linguistic rules. The merge table is ordered by
priority: merge 0 was most frequent in the original corpus, merge 1 most
frequent after merge 0, and so on. This ordering is critical for encoding.
Returns: ordered list of (pair, new_id) tuples where new_id = 256 + merge_index.
"""
ids = list(token_ids) # work on a copy
merges: list[tuple[tuple[int, int], int]] = []
for i in range(num_merges):
counts = get_pair_counts(ids)
if not counts:
# Entire corpus collapsed to a single token (or is empty). Rare in
# practice, but correct to handle: no more pairs means no more merges.
break
# The pair with the highest count gets merged next.
pair = max(counts, key=counts.get) # type: ignore[arg-type]
new_id = 256 + i # byte IDs 0-255 reserved; merges start at 256
ids = apply_merge(ids, pair, new_id)
merges.append((pair, new_id))
if (i + 1) % 32 == 0 or i == 0:
a, b = pair
print(
f" merge {i + 1:>3}/{num_merges}: "
f"({a:>3}, {b:>3}) -> {new_id:>3} "
f"freq={counts[pair]:>5} corpus_len={len(ids)}"
)
return merges
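# Toy training sketch (hypothetical corpus, not the names dataset): on the
# bytes of "aaabdaaabac", the most frequent pair is (97, 97) ("aa"), which
# becomes token 256; after that merge, (256, 97) ("aaa") wins and becomes 257:
#   >>> train_bpe(list(b"aaabdaaabac"), 2)
#   [((97, 97), 256), ((256, 97), 257)]
# (train_bpe also prints a progress line for the first merge.)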
# === ENCODING & DECODING ===
def build_vocab(merges: list[tuple[tuple[int, int], int]]) -> dict[int, bytes]:
"""Build token ID -> bytes lookup table.
Base vocabulary: 256 entries mapping each byte value to its single-byte string.
Each merge extends the table: vocab[new_id] = vocab[a] + vocab[b].
This recursive expansion means decoding is just a table lookup -- no merge
replay needed, and round-trip correctness is guaranteed by construction.
"""
vocab: dict[int, bytes] = {i: bytes([i]) for i in range(256)}
for (a, b), new_id in merges:
vocab[new_id] = vocab[a] + vocab[b]
return vocab
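# Expansion example (continuing the toy merges above): each new entry is the
# concatenation of its two parents, so lookups need no recursion at decode time.
#   >>> build_vocab([((97, 97), 256), ((256, 97), 257)])[257]
#   b'aaa'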
def encode(text: str, merges: list[tuple[tuple[int, int], int]]) -> list[int]:
"""Encode a string to BPE token IDs by replaying merges in priority order.
Critical: merges are applied in the order they were learned (priority order),
NOT by re-counting frequencies on the new text. Priority order ensures
deterministic tokenization -- the same string always produces the same token
sequence, regardless of what other text the tokenizer was trained on.
Re-counting frequencies would make the output dependent on the input batch,
breaking the contract that tokenization is a pure function of the input string.
"""
# Signpost: this O(n * M) naive encoding checks every merge against the full
# sequence. Production tokenizers (tiktoken, HuggingFace) use trie structures
# for O(n) encoding, but produce identical output.
token_ids = list(text.encode("utf-8"))
for pair, new_id in merges:
token_ids = apply_merge(token_ids, pair, new_id)
return token_ids
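# Priority-order replay example (same toy merges as above): "aaab" first
# collapses "aa" -> 256, then (256, "a") -> 257, and decoding inverts it exactly:
#   >>> toy = [((97, 97), 256), ((256, 97), 257)]
#   >>> encode("aaab", toy)
#   [257, 98]
#   >>> decode([257, 98], build_vocab(toy))
#   'aaab'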
def decode(token_ids: list[int], vocab: dict[int, bytes]) -> str:
"""Decode token IDs back to a string via byte lookup and UTF-8 decoding.
Every token maps to a definite byte sequence through the vocab table, so
decode(encode(text)) == text is guaranteed for any input string.
Decoding is trivially simple by design -- all the complexity lives in encoding.
"""
raw_bytes = b"".join(vocab[tid] for tid in token_ids)
return raw_bytes.decode("utf-8")
# === INFERENCE DEMO ===
if __name__ == "__main__":
# -- Load and prepare data --
raw = load_data(DATA_URL, DATA_FILE)
corpus_ids = list(raw)
# Starting from raw bytes means every possible input is representable --
# there are no "unknown token" problems. This is the key insight of byte-level
# BPE: the base vocabulary covers all of Unicode (via UTF-8 byte sequences)
# without needing a character-level vocabulary for every writing system.
# (A non-ASCII round-trip demo follows the tests below.)
print(f"Corpus: {len(raw):,} bytes, base vocab: 256 byte tokens")
print(f"Training {NUM_MERGES} merges (final vocab: {256 + NUM_MERGES} tokens)\n")
# -- Train --
print("Training BPE...")
merges = train_bpe(corpus_ids, NUM_MERGES)
vocab = build_vocab(merges)
print(f"\nTraining complete: {len(merges)} merges learned\n")
# -- Round-trip tests --
# Verify encode-decode identity on diverse inputs: common name, uncommon name,
# hyphenated, apostrophe, empty string, single character.
test_strings = ["Emma", "Xiomara", "Mary-Jane", "O'Brien", "", "Z"]
print("Round-trip tests:")
all_pass = True
for s in test_strings:
encoded = encode(s, merges)
decoded = decode(encoded, vocab)
status = "PASS" if decoded == s else "FAIL"
if status == "FAIL":
all_pass = False
display = f'"{s}"' if s else '""'
print(f" [{status}] {display:<14} -> {len(encoded):>2} tokens -> {decoded!r}")
print()
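    # -- Byte-level fallback demo --
    # Added illustration: the training corpus is ASCII, so no learned merge
    # spans a multi-byte UTF-8 sequence, but the 256 base byte tokens still
    # cover it. "Zoë" is a hypothetical input, not drawn from the dataset.
    unicode_name = "Zoë"
    unicode_tokens = encode(unicode_name, merges)
    roundtrip_ok = decode(unicode_tokens, vocab) == unicode_name
    print(
        f'Byte-level fallback: "{unicode_name}" -> {len(unicode_tokens)} tokens, '
        f'round-trip {"OK" if roundtrip_ok else "FAILED"}\n'
    )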
# -- Compression ratio --
# compression_ratio = len(original_bytes) / len(bpe_tokens)
# Each BPE token represents `ratio` bytes on average. Higher is better --
# it means the tokenizer discovered more compressible structure.
corpus_text = raw.decode("utf-8")
corpus_encoded = encode(corpus_text, merges)
ratio = len(raw) / len(corpus_encoded)
print(
f"Compression: {len(raw):,} bytes -> {len(corpus_encoded):,} tokens "
f"(ratio: {ratio:.2f}x)\n"
)
# -- Top 20 merges --
print("Top 20 merges (earliest = highest priority):")
for i, ((a, b), new_id) in enumerate(merges[:20]):
a_str = vocab[a].decode("utf-8", errors="replace")
b_str = vocab[b].decode("utf-8", errors="replace")
merged_str = vocab[new_id].decode("utf-8", errors="replace")
print(f" {i + 1:>2}. {a_str!r:>6} + {b_str!r:<6} -> {merged_str!r}")
print()
# -- Tokenization example --
example = "Elizabeth"
example_tokens = encode(example, merges)
pieces = [vocab[tid].decode("utf-8", errors="replace") for tid in example_tokens]
print(f'Tokenization example: "{example}"')
print(f" Bytes: {list(example.encode('utf-8'))}")
print(f" Tokens: {example_tokens}")
print(f" Pieces: {pieces}")