
Commit b217158

[Distributed] add TokenizerType enum, spelling (#1266)
* add TokenizerType enum, update decoder spelling
* revert prompts to same length
* PR comment, update _tokenizer_type to global
1 parent d0993b3 commit b217158
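For context, a minimal self-contained sketch of the pattern this commit introduces: record the tokenizer's concrete type once in a module-level enum, then branch on that enum at decode time instead of repeating isinstance checks. The two stub tokenizer classes here are hypothetical placeholders for the real TiktokenTokenizer and SentencePieceProcessor; the enum, the global, and the type-recording logic mirror the diff below.

from enum import auto, Enum
from typing import List


class TokenizerType(Enum):
    Tiktoken = auto()
    SentencePiece = auto()


# Hypothetical stubs standing in for the real tokenizer classes.
class TiktokenTokenizer:
    def decode(self, ids: List[int]) -> str:
        return "".join(chr(i) for i in ids)  # placeholder decoding


class SentencePieceProcessor:
    def decode(self, ids: List[List[int]]) -> List[str]:
        return ["".join(chr(i) for i in row) for row in ids]  # placeholder decoding


_tokenizer_type = None  # module-level cache, set once when the tokenizer is built


def _record_tokenizer_type(tokenizer) -> None:
    # Map the tokenizer instance to its enum value, as _build_chat_tokenizer now does.
    global _tokenizer_type
    if isinstance(tokenizer, TiktokenTokenizer):
        _tokenizer_type = TokenizerType.Tiktoken
    elif isinstance(tokenizer, SentencePieceProcessor):
        _tokenizer_type = TokenizerType.SentencePiece
    else:
        raise ValueError(f"Unknown tokenizer type: {tokenizer.__class__}")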

File tree

1 file changed: +30 −7 lines

dist_run.py

Lines changed: 30 additions & 7 deletions
@@ -10,6 +10,7 @@
 
 import argparse
 import os
+from enum import auto, Enum
 from pathlib import Path
 from types import SimpleNamespace
 from typing import Any, Dict, List, Optional, Tuple
@@ -49,6 +50,7 @@
 
 
 logger = SingletonLogger.get_logger()
+_tokenizer_type = None  # global variable to store the tokenizer type
 
 # Using model name to identify the model to load, for example "llama2-7b-chat".
 # You can change it to other values listed below.
@@ -59,6 +61,11 @@
 }
 
 
+class TokenizerType(Enum):
+    Tiktoken = auto()
+    SentencePiece = auto()
+
+
 def _init_distributed():
     dist.init_process_group("nccl")
     rank = dist.get_rank()
@@ -80,7 +87,10 @@ def _build_chat_tokenizer(
     model_name: str,
     model_base_name: Optional[str] = None,
 ) -> SentencePieceProcessor | TiktokenTokenizer:
-    """Builds a tokenizer for the given model name."""
+    """Builds a tokenizer for the given model name, and sets the global tokenizer type variable"""
+
+    global _tokenizer_type
+
     # Try to infer the model base name from the model name:
     # e.g. "llama2-7b-chat" -> "llama2"
     if model_base_name is None:
@@ -107,6 +117,15 @@ def _build_chat_tokenizer(
     logger.info(
         f"using tokenizer = {tokenizer.__class__.__module__}.{tokenizer.__class__.__name__}"
     )
+    # set global variable _tokenizer_type
+    if isinstance(tokenizer, TiktokenTokenizer):
+        _tokenizer_type = TokenizerType.Tiktoken
+    elif isinstance(tokenizer, SentencePieceProcessor):
+        _tokenizer_type = TokenizerType.SentencePiece
+    else:
+        raise ValueError(f"Unknown tokenizer type: {tokenizer.__class__}")
+
+    logger.info(f"tokenizer type = {_tokenizer_type}")
     return tokenizer
 
 
@@ -269,6 +288,7 @@ def _cleanup():
 
 prompt = [
     "What is Snow?",
+    # "Can you explain what is the purpose of back propagation in neural networks?",
     "Who is Santa Claus?",
     "Where does Santa live?",
     # "Who is Abraham Lincoln?",
@@ -487,7 +507,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         group=pp_group,
     )
     # create schedule
-    decorder = ScheduleGPipe(decode_stage, 1)
+    decoder = ScheduleGPipe(decode_stage, 1)
 
     # Decoding
     with torch.no_grad(), CUDATrackTime() as timer:
@@ -510,11 +530,11 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
 
         # Run data through pipeline
         if pp_rank == first_pp_rank:
-            output = decorder.step(new_token, **kwargs)
+            output = decoder.step(new_token, **kwargs)
         elif pp_rank == last_pp_rank:
-            output = decorder.step(**kwargs)
+            output = decoder.step(**kwargs)
         else:  # middle pp ranks
-            decorder.step(**kwargs)
+            decoder.step(**kwargs)
 
         # Decode the output
         if pp_rank == last_pp_rank:
@@ -539,13 +559,16 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         # token ids. Thus cat'ing along dim 1.
         res = torch.cat(res, dim=1)
         res_list = res.tolist()
-        if isinstance(tokenizer, TiktokenTokenizer):
+        if _tokenizer_type == TokenizerType.Tiktoken:
             # For TiktokenTokenizer, we need to decode prompt by prompt.
             # TODO: is there a better way to do this?
             responses = [tokenizer.decode(sequence) for sequence in res_list]
-        else:  # SentencePieceProcessor
+        elif _tokenizer_type == TokenizerType.SentencePiece:  # SentencePieceProcessor
             # For SentencePieceProcessor, we can decode the entire 2D list at once.
            responses = tokenizer.decode(res_list)
+        else:
+            raise ValueError(f"Unknown tokenizer type {_tokenizer_type}")
+
         # Show prompts and responses
         for prompt_text, response_text in zip(prompt, responses):
             logger.info(f"Prompt: {color.green}{prompt_text} {color.reset}")

0 commit comments