Skip to content

Commit 5ee9835

Browse files
committed
Fix potential race condition with multithreaded sampling and lazy tokenizer initialization
1 parent 1df7b04 commit 5ee9835

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

exllamav2/tokenizer/tokenizer.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,16 @@
88
ExLlamaV2TokenizerSPM,
99
ExLlamaV2TokenizerHF
1010
)
11+
import threading
12+
13+
14+
lock = threading.RLock()
15+
def synchronized_init(func):
16+
def wrapper(*args, **kwargs):
17+
with lock:
18+
return func(*args, **kwargs)
19+
return wrapper
20+
1121

1222
class ExLlamaV2Tokenizer:
1323

@@ -20,7 +30,6 @@ def __init__(self, children = None, leaf = None):
2030
self.children = children if children is not None else {}
2131
self.leaf = leaf if leaf is not None else []
2232

23-
2433
config: ExLlamaV2Config
2534

2635
tokenizer_model: ExLlamaV2TokenizerBase
@@ -567,8 +576,8 @@ def num_tokens(self, text):
567576

568577
# Get ordinals of single-byte tokens
569578

579+
@synchronized_init
570580
def get_id_to_ord_list(self):
571-
572581
if self.id_to_ord is not None: return self.id_to_ord
573582

574583
self.id_to_ord = []
@@ -594,6 +603,7 @@ def get_id_to_ord_list(self):
594603

595604
# Copy vocabulary from model
596605

606+
@synchronized_init
597607
def get_id_to_piece_list(self, include_special_tokens = False):
598608

599609
if include_special_tokens:
@@ -633,6 +643,7 @@ def get_id_to_piece_list(self, include_special_tokens = False):
633643
return self.id_to_piece
634644

635645

646+
@synchronized_init
636647
def get_piece_to_id_dict(self):
637648

638649
if self.piece_to_id is not None: return self.piece_to_id
@@ -644,6 +655,7 @@ def get_piece_to_id_dict(self):
644655

645656
# Create dictionary mapping prefixes to token IDs
646657

658+
@synchronized_init
647659
def get_prefix_to_ids_dict(self):
648660

649661
if self.prefix_to_ids is not None: return self.prefix_to_ids
@@ -671,6 +683,7 @@ def get_prefix_to_ids_dict(self):
671683

672684
# Create dictionary mapping each ID to any IDs that it prefixes
673685

686+
@synchronized_init
674687
def get_prefix_id_to_ids_dict(self):
675688

676689
if self.prefix_id_to_ids is not None: return self.prefix_id_to_ids
@@ -712,6 +725,7 @@ def _make_trie(self, ci):
712725
return trie
713726

714727

728+
@synchronized_init
715729
def get_char_trie(self):
716730

717731
if self.char_trie is not None: return self.char_trie
@@ -720,6 +734,7 @@ def get_char_trie(self):
720734
return self.char_trie
721735

722736

737+
@synchronized_init
723738
def get_char_trie_ci(self):
724739

725740
if self.char_trie_ci is not None: return self.char_trie_ci

0 commit comments

Comments
 (0)