 from torchchat.distributed.logging_utils import SingletonLogger

 # TODO - these are not distributed specific, consider moving to new package
-from torchchat.distributed.safetensor_utils import (
+from torchchat.distributed.checkpoint_utils import (
     get_hf_config_file,
-    get_hf_weight_map_and_path,
-    load_safetensor_weights,
+    load_weights_from_hf_format,
+    load_weights_from_torchchat_format,
 )
 from torchchat.distributed.utils import (
     bytes_to_readable,
@@ -129,26 +129,33 @@ def _build_chat_tokenizer( |
     return tokenizer


-def _load_model_weights(stage_module, distribution, device, model_config):
+def _load_model_weights(
+    stage_module: torch.nn.Module,
+    distribution: str,
+    device: torch.device,
+    model_config: ModelArgs,
+    chpt_from: str,
+):
     """Load the weights from the safetensor file(s) into the model stage.
     Model config is needed b/c we permute wq and wk weights based on attn heads.
-    """

-    weight_map, weight_path, key_map = get_hf_weight_map_and_path(distribution)
-
-    num_loaded_weights, num_missing_weights = load_safetensor_weights(
-        stage_module,
-        weight_map,
-        weight_path,
-        key_map,
-        device,
-        model_config=model_config,
-    )
-    logger.info(
-        f"Success - Loaded {num_loaded_weights} weights, {num_missing_weights} missing weights"
-    )
-    if num_missing_weights > 0:
-        raise ValueError(f"Missing {num_missing_weights} weights")
+    Args:
+        stage_module (torch.nn.Module): The model stage to load the weights into.
+        distribution (str): The distribution name, e.g. "meta-llama/Meta-Llama-3-8B-Instruct".
+        device (torch.device): The device to load the weights onto.
+        model_config (ModelArgs): The model config.
+        chpt_from (str): The checkpoint format to load the weights from, e.g. "torchchat" or "hf".
+    """
+    if chpt_from == "hf":
+        # This format stands for: index file + multiple binary files
+        load_weights_from_hf_format(stage_module, distribution, device, model_config)
+    elif chpt_from == "torchchat":
+        # This format stands for:
+        # single binary file, OR
+        # multiple binary files without index files.
+        load_weights_from_torchchat_format(stage_module, distribution, device, model_config)
+    else:
+        raise ValueError(f"Unknown checkpoint format: {chpt_from}")


 def _encode_strings(
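
The "hf" branch consumes Hugging Face's sharded checkpoint layout, where a `model.safetensors.index.json` file maps each parameter name to the shard file that contains it. Below is a minimal standalone sketch of reading that layout; the actual loading logic lives in `torchchat.distributed.checkpoint_utils`, and `load_hf_sharded_state_dict` here is a hypothetical helper for illustration, not torchchat's API:

```python
import json
import os

from safetensors.torch import load_file


def load_hf_sharded_state_dict(checkpoint_dir: str, device: str = "cpu") -> dict:
    # The index file maps each parameter name to the shard file holding it.
    index_path = os.path.join(checkpoint_dir, "model.safetensors.index.json")
    with open(index_path) as f:
        weight_map = json.load(f)["weight_map"]
    state_dict = {}
    # Load each shard exactly once, even though many parameters point to it.
    for shard in sorted(set(weight_map.values())):
        state_dict.update(load_file(os.path.join(checkpoint_dir, shard), device=device))
    return state_dict
```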
@@ -306,7 +313,7 @@ def main(args): |
     logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}")

     distribution, model_dtype = NAME_TO_DISTRIBUTION_AND_DTYPE[model_name]
-    logger.info(f"Using HF model weights from {distribution} and dtype {model_dtype}")
+    logger.info(f"Using model weights from {distribution} and dtype {model_dtype}")

     # Model-level config
     model_config = ModelArgs.from_name(distribution)
@@ -368,7 +375,7 @@ def main(args): |
     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
     with CUDATrackTime() as timer:
-        _load_model_weights(model, distribution, device=device, model_config=config)
+        _load_model_weights(model, distribution, device, config, args.chpt_from)

     logger.info(
         f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
@@ -602,6 +609,13 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: |
         default=False,
         help="Whether to decode token into string in flight",
     )
+    parser.add_argument(
+        "--chpt-from",
+        type=str,
+        default="hf",  # TODO: change to torchchat once we support it well
+        help="Checkpoint format to load from",
+        choices=["hf", "torchchat"],
+    )
     args = parser.parse_args()

     main(args)
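
The new flag is validated by argparse's `choices`, so a mistyped format fails fast at launch instead of reaching `_load_model_weights`. A quick standalone check of that behavior (this mirrors the snippet above; it is not torchchat's actual parser wiring):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--chpt-from",
    type=str,
    default="hf",
    help="Checkpoint format to load from",
    choices=["hf", "torchchat"],
)

print(parser.parse_args([]).chpt_from)  # -> "hf" (the default)
print(parser.parse_args(["--chpt-from", "torchchat"]).chpt_from)  # -> "torchchat"
# parser.parse_args(["--chpt-from", "gguf"]) exits with an "invalid choice" error.
```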