@@ -24,8 +24,7 @@
 # TODO - these are not distributed specific, consider moving to new package
 from torchchat.distributed.safetensor_utils import (
     get_hf_config_file,
-    get_hf_weight_map_and_path,
-    load_safetensor_weights,
+    load_weights_from_hf_format,
 )
 from torchchat.distributed.utils import (
     bytes_to_readable,
@@ -58,6 +57,10 @@
     "llama3": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16),
 }
 
+# This format stands for: index file + multiple safetensor.
+USE_HF_CHECKPOINT_FORMAT = True
+# TODO: add support for single bin format.
+
 
 def _init_distributed():
     dist.init_process_group("nccl")
@@ -114,22 +117,10 @@ def _load_model_weights(stage_module, distribution, device, model_config):
     """Load the weights from the safetensor file(s) into the model stage.
     Model config is needed b/c we permute wq and wk weights based on attn heads.
     """
-
-    weight_map, weight_path, key_map = get_hf_weight_map_and_path(distribution)
-
-    num_loaded_weights, num_missing_weights = load_safetensor_weights(
-        stage_module,
-        weight_map,
-        weight_path,
-        key_map,
-        device,
-        model_config=model_config,
-    )
-    logger.info(
-        f"Success - Loaded {num_loaded_weights} weights, {num_missing_weights} missing weights"
-    )
-    if num_missing_weights > 0:
-        raise ValueError(f"Missing {num_missing_weights} weights")
+    if USE_HF_CHECKPOINT_FORMAT:
+        load_weights_from_hf_format(stage_module, distribution, device, model_config)
+    else:
+        load_weights_from_torchchat_format(stage_module, distribution, device, model_config)
 
 
 def _encode_strings(
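For reference, a minimal sketch of what the new load_weights_from_hf_format helper presumably wraps, reconstructed from the call-site code removed above; the actual implementation in torchchat.distributed.safetensor_utils is not shown in this diff and may differ.

import logging

logger = logging.getLogger(__name__)

# Hypothetical sketch: assumes the new helper simply bundles the two calls that
# this diff removes from _load_model_weights (get_hf_weight_map_and_path and
# load_safetensor_weights, both from torchchat.distributed.safetensor_utils).
def load_weights_from_hf_format(stage_module, distribution, device, model_config):
    # Read the HF index file to map each weight name to its safetensor shard,
    # plus the key remapping between HF and torchchat parameter names.
    weight_map, weight_path, key_map = get_hf_weight_map_and_path(distribution)

    # Load the shards into this pipeline stage; model_config is needed because
    # wq/wk weights are permuted based on the number of attention heads.
    num_loaded_weights, num_missing_weights = load_safetensor_weights(
        stage_module,
        weight_map,
        weight_path,
        key_map,
        device,
        model_config=model_config,
    )
    logger.info(
        f"Success - Loaded {num_loaded_weights} weights, "
        f"{num_missing_weights} missing weights"
    )
    if num_missing_weights > 0:
        raise ValueError(f"Missing {num_missing_weights} weights")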