@@ -50,6 +50,7 @@
 
 
 logger = SingletonLogger.get_logger()
+_tokenizer_type = None  # global variable to store the tokenizer type
 
 # Using model name to identify the model to load, for example "llama2-7b-chat".
 # You can change it to other values listed below.
@@ -85,8 +86,11 @@ def dict_to_args(dictionary: Dict[str, Any]) -> SimpleNamespace:
 def _build_chat_tokenizer(
     model_name: str,
     model_base_name: Optional[str] = None,
-) -> tuple[SentencePieceProcessor | TiktokenTokenizer, TokenizerType]:
-    """Builds a tokenizer for the given model name."""
+) -> SentencePieceProcessor | TiktokenTokenizer:
+    """Builds a tokenizer for the given model name, and sets the global tokenizer type variable."""
+
+    global _tokenizer_type
+
     # Try to infer the model base name from the model name:
     # e.g. "llama2-7b-chat" -> "llama2"
     if model_base_name is None:
@@ -113,15 +117,16 @@ def _build_chat_tokenizer(
     logger.info(
         f"using tokenizer = {tokenizer.__class__.__module__}.{tokenizer.__class__.__name__}"
     )
+    # set global variable _tokenizer_type
     if isinstance(tokenizer, TiktokenTokenizer):
-        tokenizer_type = TokenizerType.Tiktoken
+        _tokenizer_type = TokenizerType.Tiktoken
     elif isinstance(tokenizer, SentencePieceProcessor):
-        tokenizer_type = TokenizerType.SentencePiece
+        _tokenizer_type = TokenizerType.SentencePiece
     else:
         raise ValueError(f"Unknown tokenizer type: {tokenizer.__class__}")
 
-    logger.info(f"tokenizer type = {tokenizer_type}")
-    return tokenizer, tokenizer_type
+    logger.info(f"tokenizer type = {_tokenizer_type}")
+    return tokenizer
 
 
 def _load_model_weights(stage_module, distribution, device, model_config):
@@ -309,7 +314,7 @@ def main(args):
     config = TransformerArgs.from_params(model_config.transformer_args["text"])
     logger.info(f"Transformer Config: {config}")
 
-    tokenizer, tokenizer_type = _build_chat_tokenizer(model_name)
+    tokenizer = _build_chat_tokenizer(model_name)
 
     set_precision(model_dtype)
     logger.info(f"Using cache precision {model_dtype}")
@@ -554,15 +559,15 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     # token ids. Thus cat'ing along dim 1.
     res = torch.cat(res, dim=1)
     res_list = res.tolist()
-    if tokenizer_type == TokenizerType.Tiktoken:
+    if _tokenizer_type == TokenizerType.Tiktoken:
         # For TiktokenTokenizer, we need to decode prompt by prompt.
         # TODO: is there a better way to do this?
         responses = [tokenizer.decode(sequence) for sequence in res_list]
-    elif tokenizer_type == TokenizerType.SentencePiece:  # SentencePieceProcessor
+    elif _tokenizer_type == TokenizerType.SentencePiece:  # SentencePieceProcessor
         # For SentencePieceProcessor, we can decode the entire 2D list at once.
         responses = tokenizer.decode(res_list)
     else:
-        raise ValueError(f"Unknown tokenizer type {tokenizer_type}")
+        raise ValueError(f"Unknown tokenizer type {_tokenizer_type}")
 
     # Show prompts and responses
     for prompt_text, response_text in zip(prompt, responses):
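
For reference, here is a minimal, self-contained sketch of the pattern this diff adopts: the tokenizer type is recorded in a module-level global as a side effect of building the tokenizer, rather than being returned alongside it and threaded through every caller. The toy tokenizer classes below are stand-ins; only TokenizerType, _tokenizer_type, and the isinstance dispatch mirror the diff.

from enum import Enum, auto
from typing import Union


class TokenizerType(Enum):
    Tiktoken = auto()
    SentencePiece = auto()


# Stand-ins for the real tokenizer classes referenced in the diff.
class TiktokenTokenizer:
    def decode(self, ids):
        return f"<tiktoken: {len(ids)} tokens>"


class SentencePieceProcessor:
    def decode(self, ids):
        return f"<sentencepiece: {len(ids)} tokens>"


_tokenizer_type = None  # set as a side effect of _build_chat_tokenizer


def _build_chat_tokenizer(
    model_name: str,
) -> Union[TiktokenTokenizer, SentencePieceProcessor]:
    """Build a tokenizer and record its type in the module-level global."""
    global _tokenizer_type
    # The real code selects the tokenizer based on model_name; hard-coded here.
    tokenizer = TiktokenTokenizer()
    if isinstance(tokenizer, TiktokenTokenizer):
        _tokenizer_type = TokenizerType.Tiktoken
    elif isinstance(tokenizer, SentencePieceProcessor):
        _tokenizer_type = TokenizerType.SentencePiece
    else:
        raise ValueError(f"Unknown tokenizer type: {tokenizer.__class__}")
    return tokenizer


tokenizer = _build_chat_tokenizer("llama2-7b-chat")
assert _tokenizer_type is TokenizerType.Tiktoken  # later code branches on this

The trade-off is the usual one for globals: callers no longer need to unpack a tuple, but _tokenizer_type is only valid after _build_chat_tokenizer has run, and the decode-time branching silently depends on that ordering.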