# tokenizer_ext.py

import multiprocessing as mp

# ---------- CORE FUNCTIONS (single text, distributed across cores) ----------

# Per-worker globals: set once by _init_worker so the tokenizer does not have
# to be pickled and shipped with every individual task.
_tokenizer = None
_pad_token_id = None


def _init_worker(tokenizer, pad_token_id):
    global _tokenizer, _pad_token_id
    _tokenizer = tokenizer
    _pad_token_id = pad_token_id


def _encode_chunk(chunk):
    return _tokenizer.encode(chunk)


def _decode_chunk(tokens):
    # Drop padding tokens before decoding.
    tokens = [t for t in tokens if t != _pad_token_id]
    return _tokenizer.decode(tokens)


def encode_text_parallel(tokenizer, text, num_workers=None, chunk_size=256):
    """
    Encode a single text by splitting it into chunks and distributing the
    chunks across cores.

    Note: fixed-size character chunks can split a token at a chunk boundary,
    so subword tokenizers may produce slightly different ids than encoding
    the whole text at once.
    """
    # Split the text into fixed-size character chunks.
    chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]

    with mp.Pool(
        processes=num_workers,
        initializer=_init_worker,
        initargs=(tokenizer, None),  # no pad id needed for encoding
    ) as pool:
        encoded_chunks = pool.map(_encode_chunk, chunks)

    # Flatten the per-chunk token lists into one sequence.
    return [tok for chunk in encoded_chunks for tok in chunk]


def decode_text_parallel(tokenizer, tokens, pad_token_id=0, num_workers=None, chunk_size=256):
    """
    Decode a single sequence of tokens in parallel, ignoring padding.
    """
    # Split the token sequence into fixed-size chunks.
    chunks = [tokens[i:i + chunk_size] for i in range(0, len(tokens), chunk_size)]

    with mp.Pool(
        processes=num_workers,
        initializer=_init_worker,
        initargs=(tokenizer, pad_token_id),  # workers need the pad id to strip padding
    ) as pool:
        decoded_chunks = pool.map(_decode_chunk, chunks)

    return "".join(decoded_chunks)


# ---------- BATCH FUNCTIONS (multiple texts, built on the core helpers) ----------

def batch_encode_parallel(tokenizer, texts, pad_token_id=0, max_length=None,
                          num_workers=None, chunk_size=256):
    """
    Encode a batch of texts with the parallel encode_text_parallel helper,
    then pad the results to a common length.
    """
    # Worker pools cannot be nested (daemonic processes may not spawn
    # children), so the texts are encoded one after another, each one using
    # chunk-level parallelism internally.
    sequences = [
        encode_text_parallel(tokenizer, text,
                             num_workers=num_workers, chunk_size=chunk_size)
        for text in texts
    ]

    # Determine the padding length from the longest sequence if not given.
    if max_length is None:
        max_length = max(len(seq) for seq in sequences)

    # Truncate or right-pad every sequence to max_length.
    padded = [
        seq[:max_length] + [pad_token_id] * (max_length - len(seq))
        for seq in sequences
    ]
    return padded


def batch_decode_parallel(tokenizer, batch_ids, pad_token_id=0,
                          num_workers=None, chunk_size=256):
    """
    Decode a batch of token sequences with the parallel decode_text_parallel
    helper, ignoring padding.
    """
    # Same reasoning as batch_encode_parallel: no nested pools, so the
    # sequences are decoded one after another.
    texts = [
        decode_text_parallel(tokenizer, ids,
                             pad_token_id=pad_token_id,
                             num_workers=num_workers,
                             chunk_size=chunk_size)
        for ids in batch_ids
    ]
    return texts
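

# ---------- USAGE SKETCH ----------
# Minimal end-to-end demo. _DemoCharTokenizer is a hypothetical stand-in: any
# object exposing compatible encode/decode methods would work in its place.
class _DemoCharTokenizer:
    """Toy character-level tokenizer used only by the demo below."""

    def encode(self, text):
        return [ord(c) + 1 for c in text]  # +1 keeps id 0 free for padding

    def decode(self, tokens):
        return "".join(chr(t - 1) for t in tokens)


if __name__ == "__main__":
    # The guard is required: worker processes re-import this module and must
    # not re-run the demo.
    tok = _DemoCharTokenizer()
    texts = ["hello world", "parallel tokenization"]

    batch = batch_encode_parallel(tok, texts, pad_token_id=0, num_workers=2)
    print(batch)  # padded id lists, all the same length
    print(batch_decode_parallel(tok, batch, pad_token_id=0, num_workers=2))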