@@ -10,6 +10,7 @@

 import argparse
 import os
+from enum import auto, Enum
 from pathlib import Path
 from types import SimpleNamespace
 from typing import Any, Dict, List, Optional, Tuple
@@ -59,6 +60,11 @@
 }


+class TokenizerType(Enum):
+    Tiktoken = auto()
+    SentencePiece = auto()
+
+
 def _init_distributed():
     dist.init_process_group("nccl")
     rank = dist.get_rank()
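Aside, for readers new to this pattern: a minimal standalone sketch of what the new enum gives us. `auto()` assigns each member a distinct value, and members compare by identity, which is what makes the enum a safe dispatch key for the branches added later in this diff.

```python
from enum import auto, Enum

class TokenizerType(Enum):
    Tiktoken = auto()
    SentencePiece = auto()

# auto() hands out distinct values (1, 2, ...); members compare by identity.
assert TokenizerType.Tiktoken is not TokenizerType.SentencePiece
assert TokenizerType.Tiktoken == TokenizerType.Tiktoken
print(list(TokenizerType))
```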
@@ -79,7 +85,7 @@ def dict_to_args(dictionary: Dict[str, Any]) -> SimpleNamespace:
 def _build_chat_tokenizer(
     model_name: str,
     model_base_name: Optional[str] = None,
-) -> SentencePieceProcessor | TiktokenTokenizer:
+) -> tuple[SentencePieceProcessor | TiktokenTokenizer, TokenizerType]:
     """Builds a tokenizer for the given model name."""
     # Try to infer the model base name from the model name:
     # e.g. "llama2-7b-chat" -> "llama2"
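The comment above only gives an example of the base-name inference; here is a hypothetical standalone version of such a rule. The helper name and the split-on-dash heuristic are assumptions for illustration, not necessarily the repo's actual logic.

```python
def infer_model_base_name(model_name: str) -> str:
    # Hypothetical heuristic: keep the leading dash-separated token,
    # e.g. "llama2-7b-chat" -> "llama2".
    return model_name.split("-")[0]

assert infer_model_base_name("llama2-7b-chat") == "llama2"
```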
@@ -107,7 +113,15 @@ def _build_chat_tokenizer(
     logger.info(
         f"using tokenizer = {tokenizer.__class__.__module__}.{tokenizer.__class__.__name__}"
     )
-    return tokenizer
+    if isinstance(tokenizer, TiktokenTokenizer):
+        tokenizer_type = TokenizerType.Tiktoken
+    elif isinstance(tokenizer, SentencePieceProcessor):
+        tokenizer_type = TokenizerType.SentencePiece
+    else:
+        raise ValueError(f"Unknown tokenizer type: {tokenizer.__class__}")
+
+    logger.info(f"tokenizer type = {tokenizer_type}")
+    return tokenizer, tokenizer_type


 def _load_model_weights(stage_module, distribution, device, model_config):
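A self-contained sketch of the classify-and-return pattern the new function tail follows, with stand-in classes (the real `TiktokenTokenizer` and `SentencePieceProcessor` need model files to construct):

```python
from enum import auto, Enum

class TokenizerType(Enum):  # compacted copy of the enum from this commit
    Tiktoken = auto()
    SentencePiece = auto()

class FakeTiktoken: pass          # stand-in for TiktokenTokenizer
class FakeSentencePiece: pass     # stand-in for SentencePieceProcessor

def classify(tokenizer):
    # Map the concrete tokenizer class to the enum once, at build time,
    # so callers can branch on the enum instead of repeating isinstance checks.
    if isinstance(tokenizer, FakeTiktoken):
        return tokenizer, TokenizerType.Tiktoken
    elif isinstance(tokenizer, FakeSentencePiece):
        return tokenizer, TokenizerType.SentencePiece
    raise ValueError(f"Unknown tokenizer type: {tokenizer.__class__}")

_, tok_type = classify(FakeSentencePiece())
assert tok_type is TokenizerType.SentencePiece
```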
@@ -269,8 +283,9 @@ def _cleanup():

 prompt = [
     "What is Snow?",
-    "Who is Santa Claus?",
-    "Where does Santa live?",
+    "Can you explain what is the purpose of back propagation in neural networks?",
+    # "Who is Santa Claus?",
+    # "Where does Santa live?",
     # "Who is Abraham Lincoln?",
     # "How are models trained?",
 ]
@@ -294,7 +309,7 @@ def main(args):
     config = TransformerArgs.from_params(model_config.transformer_args["text"])
     logger.info(f"Transformer Config: {config}")

-    tokenizer = _build_chat_tokenizer(model_name)
+    tokenizer, tokenizer_type = _build_chat_tokenizer(model_name)

     set_precision(model_dtype)
     logger.info(f"Using cache precision {model_dtype}")
@@ -487,7 +502,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         group=pp_group,
     )
     # create schedule
-    decorder = ScheduleGPipe(decode_stage, 1)
+    decoder = ScheduleGPipe(decode_stage, 1)

     # Decoding
     with torch.no_grad(), CUDATrackTime() as timer:
@@ -510,11 +525,11 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:

         # Run data through pipeline
         if pp_rank == first_pp_rank:
-            output = decorder.step(new_token, **kwargs)
+            output = decoder.step(new_token, **kwargs)
         elif pp_rank == last_pp_rank:
-            output = decorder.step(**kwargs)
+            output = decoder.step(**kwargs)
         else:  # middle pp ranks
-            decorder.step(**kwargs)
+            decoder.step(**kwargs)

         # Decode the output
         if pp_rank == last_pp_rank:
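For readers unfamiliar with pipeline schedules: only the first stage feeds inputs and only the last stage yields outputs; middle stages just relay activations. A minimal sketch of that control flow with a stub schedule (the real `ScheduleGPipe.step` exchanges tensors between pipeline ranks over the process group):

```python
class StubSchedule:
    # Stand-in for torch.distributed.pipelining.ScheduleGPipe; the real
    # step() sends/receives activations between pipeline ranks.
    def step(self, *args, **kwargs):
        return f"stepped(args={args})"

def run_decode_step(schedule, pp_rank, first_pp_rank, last_pp_rank, new_token, **kwargs):
    if pp_rank == first_pp_rank:
        return schedule.step(new_token, **kwargs)  # first stage consumes the new token
    elif pp_rank == last_pp_rank:
        return schedule.step(**kwargs)             # last stage produces this step's output
    else:
        schedule.step(**kwargs)                    # middle stages relay only
        return None

print(run_decode_step(StubSchedule(), pp_rank=0, first_pp_rank=0, last_pp_rank=3, new_token="tok"))
```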
@@ -539,13 +554,16 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         # token ids. Thus cat'ing along dim 1.
         res = torch.cat(res, dim=1)
         res_list = res.tolist()
-        if isinstance(tokenizer, TiktokenTokenizer):
+        if tokenizer_type == TokenizerType.Tiktoken:
             # For TiktokenTokenizer, we need to decode prompt by prompt.
             # TODO: is there a better way to do this?
             responses = [tokenizer.decode(sequence) for sequence in res_list]
-        else:  # SentencePieceProcessor
+        elif tokenizer_type == TokenizerType.SentencePiece:  # SentencePieceProcessor
             # For SentencePieceProcessor, we can decode the entire 2D list at once.
             responses = tokenizer.decode(res_list)
+        else:
+            raise ValueError(f"Unknown tokenizer type {tokenizer_type}")
+
         # Show prompts and responses
         for prompt_text, response_text in zip(prompt, responses):
             logger.info(f"Prompt: {color.green}{prompt_text}{color.reset}")
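A runnable sketch of the decode dispatch above, using a dummy batch-decoding tokenizer (mimicking `SentencePieceProcessor.decode`, which accepts a 2D list of ids) so the branch can be exercised without model files:

```python
from enum import auto, Enum

class TokenizerType(Enum):  # compacted copy of the enum from this commit
    Tiktoken = auto()
    SentencePiece = auto()

class DummySP:
    def decode(self, ids):
        # Like SentencePieceProcessor.decode: accepts a 2D list of token ids.
        return [" ".join(str(i) for i in seq) for seq in ids]

def decode_responses(tokenizer, tokenizer_type, res_list):
    if tokenizer_type == TokenizerType.Tiktoken:
        return [tokenizer.decode(seq) for seq in res_list]  # one sequence at a time
    elif tokenizer_type == TokenizerType.SentencePiece:
        return tokenizer.decode(res_list)                   # whole batch at once
    raise ValueError(f"Unknown tokenizer type {tokenizer_type}")

print(decode_responses(DummySP(), TokenizerType.SentencePiece, [[1, 2], [3]]))
```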