
Commit 0e3efee

add TokenizerType enum, update decoder spelling
1 parent 32241ff commit 0e3efee
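In outline, this change makes _build_chat_tokenizer return the tokenizer together with a TokenizerType enum member, so the decode path can branch on the enum instead of repeating isinstance checks against the tokenizer classes. Below is a minimal, self-contained sketch of that pattern (not part of the commit); the Stub* classes and build_tokenizer are placeholders standing in for TiktokenTokenizer, SentencePieceProcessor, and the real builder.

from enum import Enum, auto


class TokenizerType(Enum):
    Tiktoken = auto()
    SentencePiece = auto()


class StubTiktoken:
    """Placeholder for TiktokenTokenizer: decodes one token sequence at a time."""

    def decode(self, ids):
        return " ".join(str(i) for i in ids)


class StubSentencePiece:
    """Placeholder for SentencePieceProcessor: decodes a whole 2D list at once."""

    def decode(self, batch):
        return [" ".join(str(i) for i in ids) for ids in batch]


def build_tokenizer(kind: str):
    """Resolve the tokenizer type once and return it alongside the tokenizer."""
    tokenizer = StubTiktoken() if kind == "tiktoken" else StubSentencePiece()
    if isinstance(tokenizer, StubTiktoken):
        tokenizer_type = TokenizerType.Tiktoken
    elif isinstance(tokenizer, StubSentencePiece):
        tokenizer_type = TokenizerType.SentencePiece
    else:
        raise ValueError(f"Unknown tokenizer type: {tokenizer.__class__}")
    return tokenizer, tokenizer_type


tokenizer, tokenizer_type = build_tokenizer("tiktoken")
res_list = [[1, 2, 3], [4, 5]]  # batches of generated token ids
if tokenizer_type == TokenizerType.Tiktoken:
    # Tiktoken-style: decode prompt by prompt.
    responses = [tokenizer.decode(sequence) for sequence in res_list]
elif tokenizer_type == TokenizerType.SentencePiece:
    # SentencePiece-style: decode the entire 2D list at once.
    responses = tokenizer.decode(res_list)
else:
    raise ValueError(f"Unknown tokenizer type {tokenizer_type}")
print(responses)

Branching on the enum keeps the decode logic independent of the concrete tokenizer classes, and the trailing else surfaces any tokenizer kind added later without a matching branch.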

File tree

1 file changed: +29 -11 lines changed

dist_run.py

Lines changed: 29 additions & 11 deletions
@@ -10,6 +10,7 @@
 
 import argparse
 import os
+from enum import auto, Enum
 from pathlib import Path
 from types import SimpleNamespace
 from typing import Any, Dict, List, Optional, Tuple
@@ -59,6 +60,11 @@
 }
 
 
+class TokenizerType(Enum):
+    Tiktoken = auto()
+    SentencePiece = auto()
+
+
 def _init_distributed():
     dist.init_process_group("nccl")
     rank = dist.get_rank()
@@ -79,7 +85,7 @@ def dict_to_args(dictionary: Dict[str, Any]) -> SimpleNamespace:
 def _build_chat_tokenizer(
     model_name: str,
     model_base_name: Optional[str] = None,
-) -> SentencePieceProcessor | TiktokenTokenizer:
+) -> tuple[SentencePieceProcessor | TiktokenTokenizer, TokenizerType]:
     """Builds a tokenizer for the given model name."""
     # Try to infer the model base name from the model name:
     # e.g. "llama2-7b-chat" -> "llama2"
@@ -107,7 +113,15 @@ def _build_chat_tokenizer(
     logger.info(
         f"using tokenizer = {tokenizer.__class__.__module__}.{tokenizer.__class__.__name__}"
     )
-    return tokenizer
+    if isinstance(tokenizer, TiktokenTokenizer):
+        tokenizer_type = TokenizerType.Tiktoken
+    elif isinstance(tokenizer, SentencePieceProcessor):
+        tokenizer_type = TokenizerType.SentencePiece
+    else:
+        raise ValueError(f"Unknown tokenizer type: {tokenizer.__class__}")
+
+    logger.info(f"tokenizer type = {tokenizer_type}")
+    return tokenizer, tokenizer_type
 
 
 def _load_model_weights(stage_module, distribution, device, model_config):
@@ -269,8 +283,9 @@ def _cleanup():
 
 prompt = [
     "What is Snow?",
-    "Who is Santa Claus?",
-    "Where does Santa live?",
+    "Can you explain what is the purpose of back propagation in neural networks?",
+    # "Who is Santa Claus?",
+    # "Where does Santa live?",
     # "Who is Abraham Lincoln?",
     # "How are models trained?",
 ]
@@ -294,7 +309,7 @@ def main(args):
     config = TransformerArgs.from_params(model_config.transformer_args["text"])
     logger.info(f"Transformer Config: {config}")
 
-    tokenizer = _build_chat_tokenizer(model_name)
+    tokenizer, tokenizer_type = _build_chat_tokenizer(model_name)
 
     set_precision(model_dtype)
     logger.info(f"Using cache precision {model_dtype}")
@@ -487,7 +502,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         group=pp_group,
     )
     # create schedule
-    decorder = ScheduleGPipe(decode_stage, 1)
+    decoder = ScheduleGPipe(decode_stage, 1)
 
     # Decoding
     with torch.no_grad(), CUDATrackTime() as timer:
@@ -510,11 +525,11 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
 
             # Run data through pipeline
             if pp_rank == first_pp_rank:
-                output = decorder.step(new_token, **kwargs)
+                output = decoder.step(new_token, **kwargs)
             elif pp_rank == last_pp_rank:
-                output = decorder.step(**kwargs)
+                output = decoder.step(**kwargs)
             else:  # middle pp ranks
-                decorder.step(**kwargs)
+                decoder.step(**kwargs)
 
             # Decode the output
             if pp_rank == last_pp_rank:
@@ -539,13 +554,16 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         # token ids. Thus cat'ing along dim 1.
         res = torch.cat(res, dim=1)
         res_list = res.tolist()
-        if isinstance(tokenizer, TiktokenTokenizer):
+        if tokenizer_type == TokenizerType.Tiktoken:
             # For TiktokenTokenizer, we need to decode prompt by prompt.
             # TODO: is there a better way to do this?
             responses = [tokenizer.decode(sequence) for sequence in res_list]
-        else:  # SentencePieceProcessor
+        elif tokenizer_type == TokenizerType.SentencePiece:  # SentencePieceProcessor
             # For SentencePieceProcessor, we can decode the entire 2D list at once.
             responses = tokenizer.decode(res_list)
+        else:
+            raise ValueError(f"Unknown tokenizer type {tokenizer_type}")
+
         # Show prompts and responses
         for prompt_text, response_text in zip(prompt, responses):
             logger.info(f"Prompt: {color.green}{prompt_text} {color.reset}")
