dist_run.py: 36 changes (27 additions, 9 deletions)
@@ -10,6 +10,7 @@
 
 import argparse
 import os
+from enum import auto, Enum
 from pathlib import Path
 from types import SimpleNamespace
 from typing import Any, Dict, List, Optional, Tuple
@@ -59,6 +60,11 @@
 }
 
 
+class TokenizerType(Enum):
+    Tiktoken = auto()
+    SentencePiece = auto()
+
+
 def _init_distributed():
     dist.init_process_group("nccl")
     rank = dist.get_rank()
@@ -79,7 +85,7 @@ def dict_to_args(dictionary: Dict[str, Any]) -> SimpleNamespace:
 def _build_chat_tokenizer(
     model_name: str,
     model_base_name: Optional[str] = None,
-) -> SentencePieceProcessor | TiktokenTokenizer:
+) -> tuple[SentencePieceProcessor | TiktokenTokenizer, TokenizerType]:
     """Builds a tokenizer for the given model name."""
     # Try to infer the model base name from the model name:
     # e.g. "llama2-7b-chat" -> "llama2"
@@ -107,7 +113,15 @@ def _build_chat_tokenizer(
     logger.info(
         f"using tokenizer = {tokenizer.__class__.__module__}.{tokenizer.__class__.__name__}"
     )
-    return tokenizer
+    if isinstance(tokenizer, TiktokenTokenizer):
+        tokenizer_type = TokenizerType.Tiktoken
+    elif isinstance(tokenizer, SentencePieceProcessor):
+        tokenizer_type = TokenizerType.SentencePiece
+    else:
+        raise ValueError(f"Unknown tokenizer type: {tokenizer.__class__}")
+
+    logger.info(f"tokenizer type = {tokenizer_type}")
+    return tokenizer, tokenizer_type
 
 
 def _load_model_weights(stage_module, distribution, device, model_config):
@@ -269,6 +283,7 @@ def _cleanup():
 
 prompt = [
     "What is Snow?",
+    # "Can you explain what is the purpose of back propagation in neural networks?",
     "Who is Santa Claus?",
     "Where does Santa live?",
     # "Who is Abraham Lincoln?",
@@ -294,7 +309,7 @@ def main(args):
     config = TransformerArgs.from_params(model_config.transformer_args["text"])
     logger.info(f"Transformer Config: {config}")
 
-    tokenizer = _build_chat_tokenizer(model_name)
+    tokenizer, tokenizer_type = _build_chat_tokenizer(model_name)
 
     set_precision(model_dtype)
     logger.info(f"Using cache precision {model_dtype}")
@@ -487,7 +502,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         group=pp_group,
     )
     # create schedule
-    decorder = ScheduleGPipe(decode_stage, 1)
+    decoder = ScheduleGPipe(decode_stage, 1)
 
     # Decoding
     with torch.no_grad(), CUDATrackTime() as timer:
@@ -510,11 +525,11 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
 
             # Run data through pipeline
             if pp_rank == first_pp_rank:
-                output = decorder.step(new_token, **kwargs)
+                output = decoder.step(new_token, **kwargs)
             elif pp_rank == last_pp_rank:
-                output = decorder.step(**kwargs)
+                output = decoder.step(**kwargs)
             else:  # middle pp ranks
-                decorder.step(**kwargs)
+                decoder.step(**kwargs)
 
             # Decode the output
             if pp_rank == last_pp_rank:
@@ -539,13 +554,16 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         # token ids. Thus cat'ing along dim 1.
         res = torch.cat(res, dim=1)
        res_list = res.tolist()
-        if isinstance(tokenizer, TiktokenTokenizer):
+        if tokenizer_type == TokenizerType.Tiktoken:
             # For TiktokenTokenizer, we need to decode prompt by prompt.
             # TODO: is there a better way to do this?
             responses = [tokenizer.decode(sequence) for sequence in res_list]
-        else:  # SentencePieceProcessor
+        elif tokenizer_type == TokenizerType.SentencePiece:
             # For SentencePieceProcessor, we can decode the entire 2D list at once.
             responses = tokenizer.decode(res_list)
+        else:
+            raise ValueError(f"Unknown tokenizer type {tokenizer_type}")
 
         # Show prompts and responses
         for prompt_text, response_text in zip(prompt, responses):
             logger.info(f"Prompt: {color.green}{prompt_text} {color.reset}")