This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 31fb1cf: "PR comment, update _tokenizer_type to global"
Parent: 889f053

File tree: 1 file changed (+15, -10 lines)


dist_run.py: 15 additions & 10 deletions
```diff
@@ -50,6 +50,7 @@
 
 
 logger = SingletonLogger.get_logger()
+_tokenizer_type = None  # global variable to store the tokenizer type
 
 # Using model name to identify the model to load, for example "llama2-7b-chat".
 # You can change it to other values listed below.
@@ -85,8 +86,11 @@ def dict_to_args(dictionary: Dict[str, Any]) -> SimpleNamespace:
 def _build_chat_tokenizer(
     model_name: str,
     model_base_name: Optional[str] = None,
-) -> tuple[SentencePieceProcessor | TiktokenTokenizer, TokenizerType]:
-    """Builds a tokenizer for the given model name."""
+) -> SentencePieceProcessor | TiktokenTokenizer:
+    """Builds a tokenizer for the given model name, and sets the global tokenizer type variable"""
+
+    global _tokenizer_type
+
     # Try to infer the model base name from the model name:
     # e.g. "llama2-7b-chat" -> "llama2"
     if model_base_name is None:
@@ -113,15 +117,16 @@ def _build_chat_tokenizer(
     logger.info(
         f"using tokenizer = {tokenizer.__class__.__module__}.{tokenizer.__class__.__name__}"
     )
+    # set global variable _tokenizer_type
     if isinstance(tokenizer, TiktokenTokenizer):
-        tokenizer_type = TokenizerType.Tiktoken
+        _tokenizer_type = TokenizerType.Tiktoken
     elif isinstance(tokenizer, SentencePieceProcessor):
-        tokenizer_type = TokenizerType.SentencePiece
+        _tokenizer_type = TokenizerType.SentencePiece
     else:
         raise ValueError(f"Unknown tokenizer type: {tokenizer.__class__}")
 
-    logger.info(f"tokenizer type = {tokenizer_type}")
-    return tokenizer, tokenizer_type
+    logger.info(f"tokenizer type = {_tokenizer_type}")
+    return tokenizer
 
 
 def _load_model_weights(stage_module, distribution, device, model_config):
```
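
For reference, here is a minimal, runnable sketch of the pattern these hunks introduce: the builder records the detected type in a module-level global instead of returning it alongside the tokenizer. The `TokenizerType` values, the `_tokenizer_type` global, and the function name follow the diff; the two tokenizer classes and the name-based selection are dummy stand-ins, not the real torchchat, tiktoken, or sentencepiece implementations.

```python
from enum import Enum, auto


class TokenizerType(Enum):
    Tiktoken = auto()
    SentencePiece = auto()


class TiktokenTokenizer:
    """Stand-in for the real tiktoken-based tokenizer (decodes one sequence at a time)."""

    def decode(self, ids):
        return "".join(chr(i) for i in ids)


class SentencePieceProcessor:
    """Stand-in for sentencepiece.SentencePieceProcessor (decodes a whole 2D batch)."""

    def decode(self, batch):
        return ["".join(chr(i) for i in ids) for ids in batch]


_tokenizer_type = None  # global variable to store the tokenizer type


def _build_chat_tokenizer(model_name: str):
    """Build a tokenizer and record its type in the module-level global."""
    global _tokenizer_type
    # The real code infers the tokenizer from the model's tokenizer config;
    # picking by name here is purely illustrative.
    if "llama3" in model_name:
        tokenizer = TiktokenTokenizer()
    else:
        tokenizer = SentencePieceProcessor()
    if isinstance(tokenizer, TiktokenTokenizer):
        _tokenizer_type = TokenizerType.Tiktoken
    elif isinstance(tokenizer, SentencePieceProcessor):
        _tokenizer_type = TokenizerType.SentencePiece
    else:
        raise ValueError(f"Unknown tokenizer type: {tokenizer.__class__}")
    return tokenizer
```

The trade-off is visible in the remaining hunks: `_build_chat_tokenizer` now returns a single object, but the tokenizer type lives in hidden module state that decode-time code must consult.
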
```diff
@@ -309,7 +314,7 @@ def main(args):
     config = TransformerArgs.from_params(model_config.transformer_args["text"])
     logger.info(f"Transformer Config: {config}")
 
-    tokenizer, tokenizer_type = _build_chat_tokenizer(model_name)
+    tokenizer = _build_chat_tokenizer(model_name)
 
     set_precision(model_dtype)
     logger.info(f"Using cache precision {model_dtype}")
@@ -554,15 +559,15 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     # token ids. Thus cat'ing along dim 1.
     res = torch.cat(res, dim=1)
     res_list = res.tolist()
-    if tokenizer_type == TokenizerType.Tiktoken:
+    if _tokenizer_type == TokenizerType.Tiktoken:
         # For TiktokenTokenizer, we need to decode prompt by prompt.
         # TODO: is there a better way to do this?
         responses = [tokenizer.decode(sequence) for sequence in res_list]
-    elif tokenizer_type == TokenizerType.SentencePiece:  # SentencePieceProcessor
+    elif _tokenizer_type == TokenizerType.SentencePiece:  # SentencePieceProcessor
         # For SentencePieceProcessor, we can decode the entire 2D list at once.
         responses = tokenizer.decode(res_list)
     else:
-        raise ValueError(f"Unknown tokenizer type {tokenizer_type}")
+        raise ValueError(f"Unknown tokenizer type {_tokenizer_type}")
 
     # Show prompts and responses
     for prompt_text, response_text in zip(prompt, responses):
```
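
And the consumer side, matching the final hunk: decode per sequence for Tiktoken, or the entire 2D list at once for SentencePiece. This continues the sketch above (same module, same stand-in classes); `decode_responses`, `res_list`, and the sample token ids are hypothetical names for illustration.

```python
def decode_responses(tokenizer, res_list):
    """Decode a batch of token-id sequences, dispatching on the global type."""
    if _tokenizer_type == TokenizerType.Tiktoken:
        # Tiktoken-style tokenizers decode one sequence at a time.
        return [tokenizer.decode(sequence) for sequence in res_list]
    elif _tokenizer_type == TokenizerType.SentencePiece:
        # SentencePieceProcessor can decode the entire 2D list at once.
        return tokenizer.decode(res_list)
    else:
        raise ValueError(f"Unknown tokenizer type {_tokenizer_type}")


# Hypothetical usage, continuing the sketch above:
tok = _build_chat_tokenizer("llama3-8b-instruct")
print(decode_responses(tok, [[72, 105], [104, 105]]))  # ['Hi', 'hi']
```
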
