@@ -10,6 +10,7 @@

import argparse
import os
+from enum import auto, Enum
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict, List, Optional, Tuple
@@ -49,6 +50,7 @@


logger = SingletonLogger.get_logger()
+_tokenizer_type = None  # global variable to store the tokenizer type

# Using model name to identify the model to load, for example "llama2-7b-chat".
# You can change it to other values listed below.
@@ -59,6 +61,11 @@
}


+class TokenizerType(Enum):
+    Tiktoken = auto()
+    SentencePiece = auto()
+
+
def _init_distributed():
    dist.init_process_group("nccl")
    rank = dist.get_rank()
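Aside on the new enum: `enum.auto()` assigns member values automatically, and enum members are singletons, so the equality comparisons used later in this diff are the idiomatic way to test them. A tiny self-contained check (illustrative only, not part of the PR):

from enum import auto, Enum

class TokenizerType(Enum):
    Tiktoken = auto()        # value assigned automatically (1)
    SentencePiece = auto()   # value assigned automatically (2)

# Members are singletons: lookup by name returns the same object,
# and distinct members never compare equal.
assert TokenizerType["Tiktoken"] is TokenizerType.Tiktoken
assert TokenizerType.Tiktoken != TokenizerType.SentencePiece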
@@ -80,7 +87,10 @@ def _build_chat_tokenizer(
    model_name: str,
    model_base_name: Optional[str] = None,
) -> SentencePieceProcessor | TiktokenTokenizer:
-    """Builds a tokenizer for the given model name."""
+    """Builds a tokenizer for the given model name, and sets the global tokenizer type variable."""
+
+    global _tokenizer_type
+
    # Try to infer the model base name from the model name:
    # e.g. "llama2-7b-chat" -> "llama2"
    if model_base_name is None:
@@ -107,6 +117,15 @@ def _build_chat_tokenizer(
    logger.info(
        f"using tokenizer = {tokenizer.__class__.__module__}.{tokenizer.__class__.__name__}"
    )
+    # set global variable _tokenizer_type
+    if isinstance(tokenizer, TiktokenTokenizer):
+        _tokenizer_type = TokenizerType.Tiktoken
+    elif isinstance(tokenizer, SentencePieceProcessor):
+        _tokenizer_type = TokenizerType.SentencePiece
+    else:
+        raise ValueError(f"Unknown tokenizer type: {tokenizer.__class__}")
+
+    logger.info(f"tokenizer type = {_tokenizer_type}")
    return tokenizer


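Taken together, the two hunks above classify the tokenizer once at build time, stash the result in a module-level global, and fail fast on anything unrecognized. In isolation the pattern looks roughly like this (a minimal sketch; the Fake* classes and build_tokenizer are hypothetical stand-ins, not torchchat's real tokenizers):

# Minimal sketch of the build-time classification added above; the
# Fake* classes are hypothetical stand-ins, not torchchat's tokenizers.
from enum import auto, Enum


class TokenizerType(Enum):
    Tiktoken = auto()
    SentencePiece = auto()


class FakeTiktoken: ...


class FakeSentencePiece: ...


_tokenizer_type = None


def build_tokenizer(kind: str):
    """Build a tokenizer and record its flavor in the module-level global."""
    global _tokenizer_type
    tokenizer = FakeTiktoken() if kind == "tiktoken" else FakeSentencePiece()
    if isinstance(tokenizer, FakeTiktoken):
        _tokenizer_type = TokenizerType.Tiktoken
    elif isinstance(tokenizer, FakeSentencePiece):
        _tokenizer_type = TokenizerType.SentencePiece
    else:
        raise ValueError(f"Unknown tokenizer type: {tokenizer.__class__}")
    return tokenizer


build_tokenizer("tiktoken")
assert _tokenizer_type is TokenizerType.Tiktoken

Once the global is set, later call sites can branch on the cheap enum value instead of repeating isinstance checks against the concrete tokenizer classes.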
@@ -269,6 +288,7 @@ def _cleanup():


prompt = [
    "What is Snow?",
+    # "Can you explain what is the purpose of back propagation in neural networks?",
    "Who is Santa Claus?",
    "Where does Santa live?",
    # "Who is Abraham Lincoln?",
@@ -487,7 +507,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
        group=pp_group,
    )
    # create schedule
-    decorder = ScheduleGPipe(decode_stage, 1)
+    decoder = ScheduleGPipe(decode_stage, 1)

    # Decoding
    with torch.no_grad(), CUDATrackTime() as timer:
@@ -510,11 +530,11 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:

            # Run data through pipeline
            if pp_rank == first_pp_rank:
-                output = decorder.step(new_token, **kwargs)
+                output = decoder.step(new_token, **kwargs)
            elif pp_rank == last_pp_rank:
-                output = decorder.step(**kwargs)
+                output = decoder.step(**kwargs)
            else:  # middle pp ranks
-                decorder.step(**kwargs)
+                decoder.step(**kwargs)

            # Decode the output
            if pp_rank == last_pp_rank:
@@ -539,13 +559,16 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # token ids. Thus cat'ing along dim 1.
        res = torch.cat(res, dim=1)
        res_list = res.tolist()
-        if isinstance(tokenizer, TiktokenTokenizer):
+        if _tokenizer_type == TokenizerType.Tiktoken:
            # For TiktokenTokenizer, we need to decode prompt by prompt.
            # TODO: is there a better way to do this?
            responses = [tokenizer.decode(sequence) for sequence in res_list]
-        else:  # SentencePieceProcessor
+        elif _tokenizer_type == TokenizerType.SentencePiece:  # SentencePieceProcessor
            # For SentencePieceProcessor, we can decode the entire 2D list at once.
            responses = tokenizer.decode(res_list)
+        else:
+            raise ValueError(f"Unknown tokenizer type {_tokenizer_type}")
+
        # Show prompts and responses
        for prompt_text, response_text in zip(prompt, responses):
            logger.info(f"Prompt: {color.green}{prompt_text}{color.reset}")
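For reference, the decode path forks because of the argument shape each tokenizer's decode accepts: per the comments above, the Tiktoken wrapper is fed one 1-D sequence of token ids at a time, while SentencePiece takes the whole 2-D batch. A rough self-contained illustration of that difference (the Stub* classes are hypothetical; the real APIs live in torchchat's tokenizer wrappers):

from typing import List


class StubTiktoken:
    # Stand-in: decode takes a single 1-D sequence of token ids.
    def decode(self, ids: List[int]) -> str:
        return " ".join(f"<{i}>" for i in ids)


class StubSentencePiece:
    # Stand-in: decode takes a 2-D batch and returns one string per row.
    def decode(self, batch: List[List[int]]) -> List[str]:
        return [" ".join(f"<{i}>" for i in row) for row in batch]


res_list = [[1, 2, 3], [4, 5, 6]]

# Tiktoken-style: decode prompt by prompt.
per_prompt = [StubTiktoken().decode(seq) for seq in res_list]

# SentencePiece-style: decode the entire 2-D list at once.
batched = StubSentencePiece().decode(res_list)

assert per_prompt == batched  # same responses either way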