
import argparse
import os
+ from enum import auto, Enum
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict, List, Optional, Tuple
from torchchat.distributed.logging_utils import SingletonLogger

# TODO - these are not distributed specific, consider moving to new package
- from torchchat.distributed.safetensor_utils import (
+ from torchchat.distributed.checkpoint_utils import (
    get_hf_config_file,
-     get_hf_weight_map_and_path,
-     load_safetensor_weights,
+     load_weights_from_hf_format,
+     load_weights_from_torchchat_format,
)
from torchchat.distributed.utils import (
    bytes_to_readable,


logger = SingletonLogger.get_logger()
+ _tokenizer_type = None  # global variable to store the tokenizer type

# Using model name to identify the model to load, for example "llama2-7b-chat".
# You can change it to other values listed below.
}


+ class TokenizerType(Enum):
+     Tiktoken = auto()
+     SentencePiece = auto()
+
+
def _init_distributed():
    dist.init_process_group("nccl")
    rank = dist.get_rank()
@@ -80,7 +87,10 @@ def _build_chat_tokenizer(
    model_name: str,
    model_base_name: Optional[str] = None,
) -> SentencePieceProcessor | TiktokenTokenizer:
-     """Builds a tokenizer for the given model name."""
+     """Builds a tokenizer for the given model name, and sets the global tokenizer type variable"""
+
+     global _tokenizer_type
+
    # Try to infer the model base name from the model name:
    # e.g. "llama2-7b-chat" -> "llama2"
    if model_base_name is None:
@@ -107,29 +117,45 @@ def _build_chat_tokenizer(
    logger.info(
        f"using tokenizer = {tokenizer.__class__.__module__}.{tokenizer.__class__.__name__}"
    )
+     # set global variable _tokenizer_type
+     if isinstance(tokenizer, TiktokenTokenizer):
+         _tokenizer_type = TokenizerType.Tiktoken
+     elif isinstance(tokenizer, SentencePieceProcessor):
+         _tokenizer_type = TokenizerType.SentencePiece
+     else:
+         raise ValueError(f"Unknown tokenizer type: {tokenizer.__class__}")
+
+     logger.info(f"tokenizer type = {_tokenizer_type}")
    return tokenizer


- def _load_model_weights(stage_module, distribution, device, model_config):
+ def _load_model_weights(
+     stage_module: torch.nn.Module,
+     distribution: str,
+     device: torch.device,
+     model_config: ModelArgs,
+     chpt_from: str,
+ ):
    """Load the weights from the safetensor file(s) into the model stage.
    Model config is needed b/c we permute wq and wk weights based on attn heads.
-     """

-     weight_map, weight_path, key_map = get_hf_weight_map_and_path(distribution)
-
-     num_loaded_weights, num_missing_weights = load_safetensor_weights(
-         stage_module,
-         weight_map,
-         weight_path,
-         key_map,
-         device,
-         model_config=model_config,
-     )
-     logger.info(
-         f"Success - Loaded {num_loaded_weights} weights, {num_missing_weights} missing weights"
-     )
-     if num_missing_weights > 0:
-         raise ValueError(f"Missing {num_missing_weights} weights")
+     Args:
+         stage_module (torch.nn.Module): The model stage to load the weights into.
+         distribution (str): The distribution name, e.g. "meta-llama/Meta-Llama-3-8B-Instruct".
+         device (torch.device): The device to load the weights onto.
+         model_config (ModelArgs): The model config.
+         chpt_from (str): The checkpoint format to load the weights from, e.g. "torchchat" or "hf".
+     """
+     if chpt_from == "hf":
+         # This format stands for: index file + multiple binary files
+         load_weights_from_hf_format(stage_module, distribution, device, model_config)
+     elif chpt_from == "torchchat":
+         # This format stands for:
+         # single binary file, OR
+         # multiple binary files without index files.
+         load_weights_from_torchchat_format(stage_module, distribution, device, model_config)
+     else:
+         raise ValueError(f"Unknown checkpoint format: {chpt_from}")


def _encode_strings(
@@ -269,6 +295,7 @@ def _cleanup():

prompt = [
    "What is Snow?",
+     # "Can you explain what is the purpose of back propagation in neural networks?",
    "Who is Santa Claus?",
    "Where does Santa live?",
    # "Who is Abraham Lincoln?",
@@ -286,7 +313,7 @@ def main(args):
    logger.info(f"{color.yellow}{gpu_memory_monitor.get_device_info()}{color.reset}")

    distribution, model_dtype = NAME_TO_DISTRIBUTION_AND_DTYPE[model_name]
-     logger.info(f"Using HF model weights from {distribution} and dtype {model_dtype}")
+     logger.info(f"Using model weights from {distribution} and dtype {model_dtype}")

    # Model-level config
    model_config = ModelArgs.from_name(distribution)
@@ -348,7 +375,7 @@ def main(args):
    # Load weights
    logger.info(f"Loading weights for {pp_rank=} on {device=}")
    with CUDATrackTime() as timer:
-         _load_model_weights(model, distribution, device=device, model_config=config)
+         _load_model_weights(model, distribution, device, config, args.chpt_from)

    logger.info(
        f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
@@ -487,7 +514,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
        group=pp_group,
    )
    # create schedule
-     decorder = ScheduleGPipe(decode_stage, 1)
+     decoder = ScheduleGPipe(decode_stage, 1)

    # Decoding
    with torch.no_grad(), CUDATrackTime() as timer:
@@ -510,11 +537,11 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:

        # Run data through pipeline
        if pp_rank == first_pp_rank:
-             output = decorder.step(new_token, **kwargs)
+             output = decoder.step(new_token, **kwargs)
        elif pp_rank == last_pp_rank:
-             output = decorder.step(**kwargs)
+             output = decoder.step(**kwargs)
        else:  # middle pp ranks
-             decorder.step(**kwargs)
+             decoder.step(**kwargs)

        # Decode the output
        if pp_rank == last_pp_rank:
@@ -539,13 +566,16 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # token ids. Thus cat'ing along dim 1.
        res = torch.cat(res, dim=1)
        res_list = res.tolist()
-         if isinstance(tokenizer, TiktokenTokenizer):
+         if _tokenizer_type == TokenizerType.Tiktoken:
            # For TiktokenTokenizer, we need to decode prompt by prompt.
            # TODO: is there a better way to do this?
            responses = [tokenizer.decode(sequence) for sequence in res_list]
-         else:  # SentencePieceProcessor
+         elif _tokenizer_type == TokenizerType.SentencePiece:  # SentencePieceProcessor
            # For SentencePieceProcessor, we can decode the entire 2D list at once.
            responses = tokenizer.decode(res_list)
+         else:
+             raise ValueError(f"Unknown tokenizer type {_tokenizer_type}")
+
        # Show prompts and responses
        for prompt_text, response_text in zip(prompt, responses):
            logger.info(f"Prompt: {color.green}{prompt_text}{color.reset}")
@@ -579,6 +609,13 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
        default=False,
        help="Whether to decode token into string in flight",
    )
+     parser.add_argument(
+         "--chpt-from",
+         type=str,
+         default="hf",  # TODO: change to torchchat once we support it well
+         help="Checkpoint format to load from",
+         choices=["hf", "torchchat"],
+     )
    args = parser.parse_args()

    main(args)
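
Usage sketch for the new flag (assumptions, not part of the diff: the script name `dist_run.py`, the torchrun launch line, and the positional model name are illustrative only):

    # hypothetical launch: load weights from HF-format safetensors (index file + shards), the default
    torchrun --nproc-per-node 4 dist_run.py llama2-7b-chat --chpt-from hf
    # hypothetical launch: load a torchchat-format checkpoint (single file, or shards without an index)
    torchrun --nproc-per-node 4 dist_run.py llama2-7b-chat --chpt-from torchchat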