-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtokenizer.py
More file actions
40 lines (33 loc) · 1.3 KB
/
tokenizer.py
File metadata and controls
40 lines (33 loc) · 1.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace, Punctuation, Sequence
import glob, os, sys
# Byte-pair-encoding tokenizer with an explicit unknown token. Input is first
# split on whitespace and then on punctuation before BPE merges are applied.
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Sequence([Whitespace(), Punctuation()])

# Trainer configuration: 20k-entry vocabulary, BERT-style special tokens plus
# a custom end-of-book marker used by the corpus.
_SPECIAL_TOKENS = ["[UNK]", "[PAD]", "[CLS]", "[SEP]", "[MASK]", "<|endofbook|>"]
trainer = BpeTrainer(vocab_size=20000, special_tokens=_SPECIAL_TOKENS)
def trainTokenizer():
    """Train the module-level BPE tokenizer on every .txt file in data/small
    and save the result to data/tokenizer.json.

    Raises:
        FileNotFoundError: if the data directory contains no .txt files,
            instead of letting train() fail with an opaque error.
    """
    dataDir = "data/small"
    files = glob.glob(os.path.join(dataDir, "*.txt"))
    if not files:
        raise FileNotFoundError(f"No .txt files found in {dataDir}")
    # Fix: the original interpolated the whole file list here, not its length,
    # producing "Training tokenizer with ['a.txt', ...] files".
    print(f"Training tokenizer with {len(files)} files")
    tokenizer.train(files, trainer)
    path = "data/tokenizer.json"
    tokenizer.save(path)
    print(f"Saved tokenizer to {path}")
def testTokenizer():
    """Load the saved tokenizer and round-trip a few sample sentences,
    printing the token strings and the decoded text for each one."""
    saved_path = "data/tokenizer.json"
    loaded = Tokenizer.from_file(saved_path)
    # Mix of Chinese phrases and an English word to exercise the vocabulary.
    samples = ["杨过和小龙女在古墓。", "神雕大侠,为国为民。", "华山论剑!", "pretty", "黄蓉"]
    for text in samples:
        enc = loaded.encode(text)
        print(f"@@ Encoded: '{enc.tokens}'")
        round_trip = loaded.decode(enc.ids)
        print(f"@@ Decoded: '{round_trip}'")
if __name__ == "__main__":
    # "train" as the first CLI argument retrains and saves the tokenizer;
    # anything else (or no argument) runs the encode/decode demo.
    command = sys.argv[1] if len(sys.argv) > 1 else None
    if command == "train":
        trainTokenizer()
    else:
        testTokenizer()