-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathgenerate_tokenizer.py
More file actions
525 lines (448 loc) · 21.8 KB
/
generate_tokenizer.py
File metadata and controls
525 lines (448 loc) · 21.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
# !!! AI ARTIFACT !!!
# This file was primarily written with AI.
import os
import argparse
from datasets import load_dataset, interleave_datasets
import sentencepiece as spm
import itertools
import random
from unidecode import unidecode
# --- Dataset 1: Wikipedia ---
DATASET_NAME_WIKI = "wikipedia"
DATASET_CONFIG_WIKI = "20220301.en"
DATASET_SPLIT_WIKI = "train"
TEXT_COLUMN_WIKI = "text"
# --- Dataset 2: DailyDialog ---
DATASET_NAME_DD = "daily_dialog"
DATASET_SPLIT_DD = "train"
UTTERANCE_COLUMN_DD = "dialog"
# --- Dataset 3: BlendedSkillTalk (includes Persona-Chat) ---
DATASET_NAME_BST = "blended_skill_talk"
DATASET_SPLIT_BST = "train"
# --- Dataset 4: OpenSubtitles ---
DATASET_NAME_OS = "Helsinki-NLP/open_subtitles"
DATASET_LANG_PAIR_OS = ("en", "fr") # Load en-fr and use the 'en' part
DATASET_SPLIT_OS = "train"
TEXT_COLUMN_OS = "en" # Access via item['translation']['en']
# Reserve one space for special "empty" token.
VOCAB_SIZE = 65535
OUTPUT_DIR = "./custom_unigram_tokenizer_65k"
MODEL_PREFIX = os.path.join(OUTPUT_DIR, "unigram")
MODEL_FILE = MODEL_PREFIX + ".model"
VOCAB_FILE = MODEL_PREFIX + ".vocab"
UNK_TOKEN = "[UNK]"
PAD_TOKEN = "[PAD]"
CONTROL_SYMBOLS = ["[CLS]", "[SEP]", "[MASK]"]
BATCH_SIZE = 1000
def wiki_iterator(dataset_wiki, batch_size=BATCH_SIZE):
if dataset_wiki:
for i in range(0, len(dataset_wiki), batch_size):
yield [unidecode(text) for text in dataset_wiki[i : i + batch_size][TEXT_COLUMN_WIKI] if text]
def dd_iterator(dataset_dd, batch_size=BATCH_SIZE):
if dataset_dd:
current_batch = []
for dialogue in dataset_dd:
utterances = dialogue[UTTERANCE_COLUMN_DD]
for utterance in utterances:
if utterance:
normalized_utterance = unidecode(utterance)
current_batch.append(normalized_utterance)
if len(current_batch) == batch_size:
yield current_batch
current_batch = []
if current_batch:
yield current_batch
def bst_iterator(dataset_bst, batch_size=BATCH_SIZE):
if dataset_bst:
current_batch = []
for session in dataset_bst:
texts_to_add = []
if session.get("previous_utterance"):
texts_to_add.append(session["previous_utterance"])
if session.get("free_messages"):
texts_to_add.extend(session["free_messages"])
if session.get("guided_messages"):
texts_to_add.extend(session["guided_messages"])
for text in texts_to_add:
if text and isinstance(text, str):
normalized_text = unidecode(text)
current_batch.append(normalized_text)
if len(current_batch) == batch_size:
yield current_batch
current_batch = []
if current_batch:
yield current_batch
def os_iterator(dataset_os, batch_size=BATCH_SIZE, lang_code=TEXT_COLUMN_OS):
if dataset_os:
current_batch = []
for item in dataset_os:
text = item['translation'][lang_code]
if text:
normalized_text = unidecode(text)
current_batch.append(normalized_text)
if len(current_batch) == batch_size:
yield current_batch
current_batch = []
if current_batch:
yield current_batch
def count_wiki_items_chars(dataset_wiki):
item_count = 0
char_count = 0
if dataset_wiki:
item_count = len(dataset_wiki)
try:
char_count = sum(len(unidecode(text)) for text in dataset_wiki[TEXT_COLUMN_WIKI] if text and isinstance(text, str))
except Exception:
print(f"Warning: Direct column access for char count failed for {DATASET_NAME_WIKI}. Iterating row by row (slower).")
char_count = sum(len(unidecode(row[TEXT_COLUMN_WIKI])) for row in dataset_wiki if row.get(TEXT_COLUMN_WIKI) and isinstance(row[TEXT_COLUMN_WIKI], str))
return item_count, char_count
def count_dd_items_chars(dataset_dd):
item_count = 0
char_count = 0
if dataset_dd:
for dialogue in dataset_dd:
utterances = dialogue[UTTERANCE_COLUMN_DD]
for utterance in utterances:
if utterance and isinstance(utterance, str):
item_count += 1
char_count += len(unidecode(utterance))
return item_count, char_count
def count_bst_items_chars(dataset_bst):
item_count = 0
char_count = 0
if dataset_bst:
for session in dataset_bst:
texts_to_process = []
# Gather all potential text strings first
prev_utt = session.get("previous_utterance")
if prev_utt and isinstance(prev_utt, list):
if len(prev_utt) > 0 and isinstance(prev_utt[0], str):
texts_to_process.append(prev_utt[0])
elif prev_utt and isinstance(prev_utt, str):
texts_to_process.append(prev_utt)
free_msgs = session.get("free_messages")
if free_msgs and isinstance(free_msgs, list):
for item in free_msgs:
if isinstance(item, list):
texts_to_process.extend(msg for msg in item if msg and isinstance(msg, str))
elif isinstance(item, str):
texts_to_process.append(item)
guided_msgs = session.get("guided_messages")
if guided_msgs and isinstance(guided_msgs, list):
for item in guided_msgs:
if isinstance(item, list):
texts_to_process.extend(msg for msg in item if msg and isinstance(msg, str))
elif isinstance(item, str):
texts_to_process.append(item)
# Count items and chars from the gathered list
for text in texts_to_process:
if text and isinstance(text, str):
normalized_text = unidecode(text)
if normalized_text:
item_count += 1
char_count += len(normalized_text)
return item_count, char_count
def count_os_items_chars(dataset_os, lang_code=TEXT_COLUMN_OS):
item_count = 0
char_count = 0
if dataset_os:
for item in dataset_os:
text = item['translation'][lang_code]
if text and isinstance(text, str):
normalized_text = unidecode(text)
if normalized_text: # Ensure not empty after unidecode
item_count += 1
char_count += len(normalized_text)
return item_count, char_count
def load_and_count_datasets(wiki_fraction, subtitles_fraction):
"""Loads, potentially shrinks, and counts items/chars in datasets."""
datasets = {}
counts = {}
total_items = 0
total_chars = 0
# --- Wikipedia ---
print(f"Loading dataset 1: {DATASET_NAME_WIKI} ({DATASET_CONFIG_WIKI}), split: {DATASET_SPLIT_WIKI}")
dataset_wiki_full = load_dataset(DATASET_NAME_WIKI, DATASET_CONFIG_WIKI, split=DATASET_SPLIT_WIKI, trust_remote_code=True)
print(f"Original Wikipedia dataset size: {len(dataset_wiki_full):,}")
# Shrink the Wikipedia dataset
split_test_size = 1.0 - wiki_fraction
shrunk_dataset_split = dataset_wiki_full.train_test_split(test_size=split_test_size, seed=random.randint(0, 1000000))
dataset_wiki = shrunk_dataset_split['train']
print(f"Using {wiki_fraction*100:.3f}% of Wikipedia dataset: {len(dataset_wiki):,} items")
count_wiki, chars_wiki = count_wiki_items_chars(dataset_wiki)
print(f"Wikipedia dataset loaded (shrunk). Precise text items: {count_wiki:,}, Characters: {chars_wiki:,}")
datasets['wiki'] = dataset_wiki
counts['wiki'] = (count_wiki, chars_wiki)
total_items += count_wiki
total_chars += chars_wiki
# --- DailyDialog ---
print(f"\nLoading dataset 2: {DATASET_NAME_DD}, split: {DATASET_SPLIT_DD}")
dataset_dd = load_dataset(DATASET_NAME_DD, split=DATASET_SPLIT_DD, trust_remote_code=True)
count_dd, chars_dd = count_dd_items_chars(dataset_dd)
print(f"DailyDialog dataset loaded. Precise text items (utterances): {count_dd:,}, Characters: {chars_dd:,}")
datasets['dd'] = dataset_dd
counts['dd'] = (count_dd, chars_dd)
total_items += count_dd
total_chars += chars_dd
# --- BlendedSkillTalk ---
print(f"\nLoading dataset 3: {DATASET_NAME_BST}, split: {DATASET_SPLIT_BST}")
dataset_bst = load_dataset(DATASET_NAME_BST, split=DATASET_SPLIT_BST, trust_remote_code=True)
count_bst, chars_bst = count_bst_items_chars(dataset_bst)
print(f"BlendedSkillTalk dataset loaded. Precise text items (extracted): {count_bst:,}, Characters: {chars_bst:,}")
datasets['bst'] = dataset_bst
counts['bst'] = (count_bst, chars_bst)
total_items += count_bst
total_chars += chars_bst
# --- OpenSubtitles ---
print(f"\nLoading dataset 4: {DATASET_NAME_OS}, lang_pair: {DATASET_LANG_PAIR_OS}, split: {DATASET_SPLIT_OS}")
# Note: OpenSubtitles can be very large. Consider streaming or specific configurations if memory is an issue.
# For now, loading a standard configuration.
dataset_os = load_dataset(DATASET_NAME_OS, lang1=DATASET_LANG_PAIR_OS[0], lang2=DATASET_LANG_PAIR_OS[1], split=DATASET_SPLIT_OS, trust_remote_code=True)
split_test_size = 1.0 - subtitles_fraction
dataset_os = dataset_os.train_test_split(test_size=split_test_size, seed=random.randint(0, 1000000))['train']
count_os, chars_os = count_os_items_chars(dataset_os, lang_code=TEXT_COLUMN_OS)
print(f"OpenSubtitles dataset loaded. Precise text items: {count_os:,}, Characters: {chars_os:,}")
datasets['os'] = dataset_os
counts['os'] = (count_os, chars_os)
total_items += count_os
total_chars += chars_os
print(f"\nTotal precise text items from loaded datasets: {total_items:,}")
print(f"Total characters from loaded datasets: {total_chars:,}")
return datasets, counts
def train_tokenizer(model_prefix, datasets, counts):
"""Trains a Unigram tokenizer using SentencePiece and saves it."""
total_items = sum(c[0] for c in counts.values())
total_chars = sum(c[1] for c in counts.values())
if total_items == 0:
print("Error: No items found in the datasets to train on. Exiting training.")
return
print(f"\nTotal precise text items for training: {total_items:,}")
print(f"Total characters for training: {total_chars:,}")
output_dir = os.path.dirname(model_prefix)
os.makedirs(output_dir, exist_ok=True)
print(f"\nStarting SentencePiece Unigram tokenizer training with vocab size: {VOCAB_SIZE}")
print(f"Using combined dataset iterator.")
print(f"Output model prefix: {model_prefix}")
iterators_to_chain = []
def flatten_iterator(iterator):
for batch in iterator:
for item in batch:
yield item
if datasets.get('wiki') and counts['wiki'][0] > 0:
iterators_to_chain.append(flatten_iterator(wiki_iterator(datasets['wiki'])))
if datasets.get('dd') and counts['dd'][0] > 0:
iterators_to_chain.append(flatten_iterator(dd_iterator(datasets['dd'])))
if datasets.get('bst') and counts['bst'][0] > 0:
iterators_to_chain.append(flatten_iterator(bst_iterator(datasets['bst'])))
if datasets.get('os') and counts['os'][0] > 0:
iterators_to_chain.append(flatten_iterator(os_iterator(datasets['os'])))
if not iterators_to_chain:
print("Error: No valid dataset iterators available for training. Exiting.")
return
combined_iterator = itertools.chain(*iterators_to_chain)
# Include whitespace symbols so we can efficiently break lines.
# If we include the single space, it prevents the tokenizer from merging
# spaces with regular words, and tanks the average chars/token.
# This many tokens is kinda overkill, but it gives us a way to efficiently
# clear even fairly large boards, so I think it's worth.
whitespace_symbols = []
for i in range(2, 40):
whitespace_symbols.append('▁' * i)
spm.SentencePieceTrainer.train(
sentence_iterator=combined_iterator,
model_prefix=model_prefix,
vocab_size=VOCAB_SIZE,
model_type='unigram',
character_coverage=1.0,
unk_piece=UNK_TOKEN,
pad_piece=PAD_TOKEN,
control_symbols=CONTROL_SYMBOLS,
user_defined_symbols=whitespace_symbols,
# These whitespace options must be false, or else whitespace won't
# be respected when encoding.
add_dummy_prefix=False,
remove_extra_whitespaces=False,
split_by_whitespace=False,
num_threads=os.cpu_count(),
input_sentence_size=total_items,
)
print("\nTraining finished.")
print(f"SentencePiece model saved to: {model_prefix}.model")
print(f"SentencePiece vocabulary saved to: {model_prefix}.vocab")
def extract_text_samples(dataset, count, num_samples, text_extractor_func):
"""Extracts a specified number of text samples from a dataset."""
samples = []
if dataset is None or count == 0 or num_samples == 0:
return samples
# Take samples from the beginning, ensure we don't exceed dataset size
actual_samples_to_take = min(num_samples, count)
# Use the provided function to extract text correctly for this dataset type
samples = text_extractor_func(dataset.select(range(actual_samples_to_take)))
return [unidecode(s) for s in samples if s and isinstance(s,str)]
def wiki_text_extractor(dataset_slice):
"""Extracts text from a slice of the Wikipedia dataset."""
return [text for text in dataset_slice[TEXT_COLUMN_WIKI] if text]
def dd_text_extractor(dataset_slice):
"""Extracts text from a slice of the DailyDialog dataset."""
texts = []
for dialogue in dataset_slice:
texts.extend(utterance for utterance in dialogue[UTTERANCE_COLUMN_DD] if utterance)
return texts
def bst_text_extractor(dataset_slice):
"""Extracts text from a slice of the BlendedSkillTalk dataset."""
texts = []
for session in dataset_slice:
if session.get("previous_utterance"):
texts.append(session["previous_utterance"])
if session.get("free_messages"):
texts.extend(session["free_messages"])
if session.get("guided_messages"):
texts.extend(session["guided_messages"])
return [t for t in texts if t and isinstance(t,str)] # Ensure only valid strings
def os_text_extractor(dataset_slice, lang_code=TEXT_COLUMN_OS):
"""Extracts text from a slice of the OpenSubtitles dataset."""
texts = []
for item in dataset_slice:
text = item['translation'][lang_code]
if text and isinstance(text, str):
texts.append(text)
return texts
def test_tokenizer(model_path, datasets, counts, sample_size=1000):
print("\n--- Testing the trained Unigram (SentencePiece) tokenizer on data sample ---")
if sample_size <= 0:
print("Error: Sample size for testing must be positive.")
return
sp = spm.SentencePieceProcessor()
sp.load(model_path)
print(f"Successfully loaded SentencePiece model from: {model_path}")
print(f"Vocabulary size: {sp.get_piece_size()}")
# Identify available datasets with content
available_datasets = {
name: data
for name, data in datasets.items()
if counts[name][0] > 0
}
num_available_datasets = len(available_datasets)
if num_available_datasets == 0:
print("Warning: No data available in loaded datasets to sample for testing.")
return
print(f"Found {num_available_datasets} non-empty dataset(s) for testing.")
print(f"Attempting to sample up to {sample_size} total items equally from these datasets...")
# Calculate equal number of samples per available dataset
samples_per_dataset = sample_size // num_available_datasets
print(f"Targeting approximately {samples_per_dataset} samples per dataset.")
test_samples = []
dataset_extractors = {
'wiki': wiki_text_extractor,
'dd': dd_text_extractor,
'bst': bst_text_extractor,
'os': os_text_extractor,
}
actual_samples_collected = {}
# Sample equally from each available dataset
for name, dataset in available_datasets.items():
count = counts[name][0]
extractor = dataset_extractors[name]
# Determine the target number of samples for *this* dataset
num_samples_to_target = min(samples_per_dataset, count)
if num_samples_to_target <= 0:
print(f" Skipping '{name}' (no items requested or available).")
actual_samples_collected[name] = 0
continue
print(f" Sampling from '{name}' (target: {num_samples_to_target} items)...")
items_before = len(test_samples)
# For OpenSubtitles, pass the lang_code if the extractor needs it
if name == 'os':
extracted_items = extract_text_samples(dataset, count, num_samples_to_target, lambda ds_slice: extractor(ds_slice, lang_code=TEXT_COLUMN_OS))
else:
extracted_items = extract_text_samples(dataset, count, num_samples_to_target, extractor)
# Limit the extracted items to the target number
final_samples_for_dataset = extracted_items[:num_samples_to_target]
test_samples.extend(final_samples_for_dataset)
items_added = len(final_samples_for_dataset)
actual_samples_collected[name] = items_added
print(f" Added {items_added} items from '{name}'.")
actual_sample_size = len(test_samples)
if actual_sample_size == 0:
print("\nCould not gather any samples for testing.")
print("--- Test finished ---")
return
print(f"\n--- Starting Test Run ---")
print(f"Testing on {actual_sample_size} sampled text items (final count).")
print(f"Samples breakdown: {actual_samples_collected}")
total_chars = 0
total_tokens = 0
examples_to_show = 5
examples_shown = 0
random.shuffle(test_samples)
for i, text_sample in enumerate(test_samples):
if not text_sample: # Should already be filtered by extractor, but double-check
continue
try:
tokens = sp.encode_as_pieces(text_sample)
num_tokens = len(tokens)
num_chars = len(text_sample)
total_tokens += num_tokens
total_chars += num_chars
if examples_shown < examples_to_show:
print(f"\nSample {examples_shown + 1} (Overall index {i}):") # Clarify sample index
print(f" Text ({num_chars} chars): {text_sample[:100]}{'...' if len(text_sample)>100 else ''}")
print(f" Tokens ({num_tokens}): {tokens}")
examples_shown += 1
except Exception as e:
print(f"\nWarning: Error encoding sample (Overall index {i}) with SentencePiece. Skipping sample. Error: {e}")
print(f"Problematic sample text: {text_sample[:100]}...") # Show problematic text
# --- Test Summary ---
if total_tokens > 0 and total_chars > 0 : # Avoid division by zero
avg_chars_per_token = total_chars / total_tokens
print(f"\n--- Test Summary ---")
print(f"Tested on {actual_sample_size} text items.")
print(f"Final samples breakdown: {actual_samples_collected}")
print(f"Total characters in final sample: {total_chars:,}")
print(f"Total tokens in final sample: {total_tokens:,}")
print(f"Average characters per token: {avg_chars_per_token:.4f}")
elif actual_sample_size > 0:
print("\n--- Test Summary ---")
print(f"Tested on {actual_sample_size} text items.")
print(f"Final samples breakdown: {actual_samples_collected}")
print("No valid tokens were generated from the sample data (or samples were empty after unidecode).")
else:
# This case should be caught earlier, but included for completeness
print("\n--- Test Summary ---")
print("No samples were collected or processed.")
print("--- Test finished ---")
def parse_args():
parser = argparse.ArgumentParser(description="Train and test a Unigram tokenizer on combined datasets.")
parser.add_argument("--train", action="store_true", help="Train the tokenizer on datasets")
parser.add_argument("--test", action="store_true", help="Test the trained tokenizer")
parser.add_argument("--sample_size", type=int, default=1000, help="Number of items to sample for testing")
parser.add_argument("--wiki_fraction", type=float, default=0.05, help="Fraction of Wikipedia dataset to use (e.g., 0.05 for 5%)")
parser.add_argument("--subtitles_fraction", type=float,
default=0.05, help="Fraction of opensubtitles dataset to use (e.g., 0.05 for 5%)")
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
# Load datasets and calculate counts once
print("--- Loading and Counting Datasets ---")
datasets, counts = load_and_count_datasets(wiki_fraction=args.wiki_fraction, subtitles_fraction=args.subtitles_fraction)
print("--- Dataset Loading Finished ---")
run_train = args.train
run_test = args.test
# If no arguments provided, run both by default
if not args.train and not args.test:
run_train = True
run_test = True
if run_train:
print("\n--- Starting Tokenizer Training ---")
train_tokenizer(MODEL_PREFIX, datasets, counts)
print("--- Training Finished ---")
if run_test:
# Ensure tokenizer file exists if training didn't just run
if not os.path.exists(MODEL_FILE):
print(f"\nError: SentencePiece model file {MODEL_FILE} not found. Cannot run test.")
print("Please run with --train first or ensure the file exists.")
else:
print("\n--- Starting Tokenizer Testing ---")
test_tokenizer(MODEL_FILE, datasets, counts, sample_size=args.sample_size)
print("--- Testing Finished ---")
print("\nScript finished.")