@@ -102,32 +102,11 @@ def decode(
102102
103103
104104def _build_chat_tokenizer (
105- model_name : str ,
106- model_base_name : Optional [str ] = None ,
105+ tokenizer_args : TokenizerArgs ,
107106) -> SentencePieceProcessor | TiktokenTokenizer :
108107 """Builds a tokenizer for the given model name"""
109-
110- # Try to infer the model base name from the model name:
111- # e.g. "llama2-7b-chat" -> "llama2"
112- if model_base_name is None :
113- model_base_name = model_name .split ("-" )[0 ]
114- logger .info (
115- f"Using model base name '{ model_base_name } ' to build tokenizer. "
116- "If not found, please specify it using the `model_base_name` argument."
117- )
118-
119- # Create base args for tokenizer
120- default_model_dir = Path (
121- os .getenv ("TORCHCHAT_MODELDIR" , "~/.torchchat/model-cache" )
122- ).expanduser ()
123-
124- tokenconfig = {
125- "model_directory" : default_model_dir ,
126- "model" : model_base_name ,
127- "tokenizer_path" : None ,
128- }
129- args = dict_to_args (tokenconfig )
130- tokenizer_args = TokenizerArgs .from_args (args )
108+
109+ tokenizer_args = TokenizerArgs .from_args (tokenizer_args )
131110 tokenizer = tokenizer_args .t
132111 assert tokenizer is not None , f"Failed to get tokenizer using { tokenconfig = } "
133112 logger .info (
@@ -313,9 +292,14 @@ def _cleanup():
313292]
314293
315294
316- def main (args , pipe ):
295+ def main (
296+ builder_args ,
297+ tokenizer_args ,
298+ pipe ,
299+ ):
317300 model_name = "llama3" # args.model_name
318- pp_degree = args .pp
301+ # print(f"{builder_args.checkpoint_path=}")
302+ pp_degree = builder_args .pp
319303
320304 rank , world_size = _init_distributed ()
321305 logger .info (f"Worker started: { rank = } , { world_size = } " )
@@ -332,7 +316,7 @@ def main(args, pipe):
332316 config = TransformerArgs .from_params (model_config .transformer_args ["text" ])
333317 logger .info (f"Transformer Config: { config } " )
334318
335- tokenizer = _build_chat_tokenizer (model_name )
319+ tokenizer = _build_chat_tokenizer (tokenizer_args )
336320
337321 set_precision (model_dtype )
338322 logger .info (f"Using cache precision { model_dtype } " )
@@ -385,7 +369,7 @@ def main(args, pipe):
385369 # Load weights
386370 logger .info (f"Loading weights for { pp_rank = } on { device = } " )
387371 with CUDATrackTime () as timer :
388- _load_model_weights (model , distribution , device , config , args .chpt_from )
372+ _load_model_weights (model , distribution , device , config , builder_args .chpt_from )
389373
390374 logger .info (
391375 f"{ color .green } Total weight loading time: { timer .get_time ()} { timer .unit } for rank { rank } { color .reset } "
0 commit comments