Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 33d8d71

Browse files
committed
Create load path from HF format
1 parent 766bee9 commit 33d8d71

File tree

2 files changed

+33
-18
lines changed

2 files changed

+33
-18
lines changed

dist_run.py

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@
2424
# TODO - these are not distributed specific, consider moving to new package
2525
from torchchat.distributed.safetensor_utils import (
2626
get_hf_config_file,
27-
get_hf_weight_map_and_path,
28-
load_safetensor_weights,
27+
load_weights_from_hf_format,
2928
)
3029
from torchchat.distributed.utils import (
3130
bytes_to_readable,
@@ -58,6 +57,10 @@
5857
"llama3": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16),
5958
}
6059

60+
# This format stands for: index file + multiple safetensor.
61+
USE_HF_CHECKPOINT_FORMAT = True
62+
# TODO: add support for single bin format.
63+
6164

6265
def _init_distributed():
6366
dist.init_process_group("nccl")
@@ -114,22 +117,10 @@ def _load_model_weights(stage_module, distribution, device, model_config):
114117
"""Load the weights from the safetensor file(s) into the model stage.
115118
Model config is needed b/c we permute wq and wk weights based on attn heads.
116119
"""
117-
118-
weight_map, weight_path, key_map = get_hf_weight_map_and_path(distribution)
119-
120-
num_loaded_weights, num_missing_weights = load_safetensor_weights(
121-
stage_module,
122-
weight_map,
123-
weight_path,
124-
key_map,
125-
device,
126-
model_config=model_config,
127-
)
128-
logger.info(
129-
f"Success - Loaded {num_loaded_weights} weights, {num_missing_weights} missing weights"
130-
)
131-
if num_missing_weights > 0:
132-
raise ValueError(f"Missing {num_missing_weights} weights")
120+
if USE_HF_CHECKPOINT_FORMAT:
121+
load_weights_from_hf_format(stage_module, distribution, device, model_config)
122+
else:
123+
load_weights_from_torchchat_format(stage_module, distribution, device, model_config)
133124

134125

135126
def _encode_strings(

torchchat/distributed/safetensor_utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,3 +366,27 @@ def log_loading_status(missing_keys: Set[str], updated_states: Set[str]):
366366
else:
367367
logger.info("Fully updated state dict.")
368368
logger.info(f"Successfully loaded {len(updated_states)} weights into stage module")
369+
370+
371+
def load_weights_from_hf_format(stage_module, distribution, device, model_config):
372+
"""
373+
Load the weights from Hugging Face format (index file + multiple safetensor
374+
files), and fill into `stage_module`. Model config is needed b/c we permute
375+
wq and wk weights based on attn heads.
376+
"""
377+
378+
weight_map, weight_path, key_map = get_hf_weight_map_and_path(distribution)
379+
380+
num_loaded_weights, num_missing_weights = load_safetensor_weights(
381+
stage_module,
382+
weight_map,
383+
weight_path,
384+
key_map,
385+
device,
386+
model_config=model_config,
387+
)
388+
logger.info(
389+
f"Success - Loaded {num_loaded_weights} weights, {num_missing_weights} missing weights"
390+
)
391+
if num_missing_weights > 0:
392+
raise ValueError(f"Missing {num_missing_weights} weights")

0 commit comments

Comments (0)