
Commit 0d1e560

Add load support for torchchat checkpoint

1 parent 8990f41

3 files changed: +127, -26 lines changed


dist_run.py

Lines changed: 32 additions & 9 deletions
@@ -25,6 +25,7 @@
 from torchchat.distributed.safetensor_utils import (
     get_hf_config_file,
     load_weights_from_hf_format,
+    load_weights_from_torchchat_format,
 )
 from torchchat.distributed.utils import (
     bytes_to_readable,
@@ -57,10 +58,6 @@
     "llama3": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16),
 }

-# This format stands for: index file + multiple safetensor.
-USE_HF_CHECKPOINT_FORMAT = True
-# TODO: add support for single bin format.
-

 def _init_distributed():
     dist.init_process_group("nccl")
@@ -113,14 +110,33 @@ def _build_chat_tokenizer(
     return tokenizer


-def _load_model_weights(stage_module, distribution, device, model_config):
+def _load_model_weights(
+    stage_module: torch.nn.Module,
+    distribution: str,
+    device: torch.device,
+    model_config: ModelArgs,
+    chpt_from: str,
+):
     """Load the weights from the safetensor file(s) into the model stage.
     Model config is needed b/c we permute wq and wk weights based on attn heads.
+
+    Args:
+        stage_module (torch.nn.Module): The model stage to load the weights into.
+        distribution (str): The distribution name, e.g. "meta-llama/Meta-Llama-3-8B-Instruct".
+        device (torch.device): The device to load the weights onto.
+        model_config (ModelArgs): The model config.
+        chpt_from (str): The checkpoint format to load the weights from, e.g. "torchchat" or "hf".
     """
-    if USE_HF_CHECKPOINT_FORMAT:
+    if chpt_from == "hf":
+        # This format stands for: index file + multiple binary files
         load_weights_from_hf_format(stage_module, distribution, device, model_config)
-    else:
+    elif chpt_from == "torchchat":
+        # This format stands for:
+        # single binary file, OR
+        # multiple binary files without index files.
         load_weights_from_torchchat_format(stage_module, distribution, device, model_config)
+    else:
+        raise ValueError(f"Unknown checkpoint format: {chpt_from}")


 def _encode_strings(
@@ -277,7 +293,7 @@ def main(args):
     logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}")

     distribution, model_dtype = NAME_TO_DISTRIBUTION_AND_DTYPE[model_name]
-    logger.info(f"Using HF model weights from {distribution} and dtype {model_dtype}")
+    logger.info(f"Using model weights from {distribution} and dtype {model_dtype}")

     # Model-level config
     model_config = ModelArgs.from_name(distribution)
@@ -339,7 +355,7 @@ def main(args):
     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
     with CUDATrackTime() as timer:
-        _load_model_weights(model, distribution, device=device, model_config=config)
+        _load_model_weights(model, distribution, device, config, args.chpt_from)

     logger.info(
         f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
@@ -570,6 +586,13 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         default=False,
         help="Whether to decode token into string in flight",
     )
+    parser.add_argument(
+        "--chpt-from",
+        type=str,
+        default="hf",  # TODO: change to torchchat once we support it well
+        help="Checkpoint format to load from",
+        choices=["hf", "torchchat"],
+    )
     args = parser.parse_args()

     main(args)
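
For context, a minimal runnable sketch of how the new --chpt-from flag flows from argparse into the dispatch in _load_model_weights. The stub loaders below stand in for the real functions in torchchat/distributed/safetensor_utils.py; the call shapes follow the diff above:

import argparse

# Stubs standing in for the real torchchat loaders.
def load_weights_from_hf_format(stage_module, distribution, device, model_config):
    print("HF format: index file + multiple binary files")

def load_weights_from_torchchat_format(stage_module, distribution, device, model_config):
    print("torchchat format: single binary file, or multiple binary files without an index")

def _load_model_weights(stage_module, distribution, device, model_config, chpt_from):
    # Same dispatch shape as the commit: unknown formats fail fast.
    if chpt_from == "hf":
        load_weights_from_hf_format(stage_module, distribution, device, model_config)
    elif chpt_from == "torchchat":
        load_weights_from_torchchat_format(stage_module, distribution, device, model_config)
    else:
        raise ValueError(f"Unknown checkpoint format: {chpt_from}")

parser = argparse.ArgumentParser()
parser.add_argument("--chpt-from", type=str, default="hf", choices=["hf", "torchchat"])
args = parser.parse_args(["--chpt-from", "torchchat"])
_load_model_weights(None, "meta-llama/Meta-Llama-3-8B-Instruct", "cuda", None, args.chpt_from)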

torchchat/cli/builder.py

Lines changed: 11 additions & 5 deletions
@@ -335,11 +335,7 @@ def _load_model_gguf(builder_args: BuilderArgs) -> Model:
     return model


-def _load_model_default(builder_args: BuilderArgs) -> Model:
-    assert not builder_args.gguf_path
-
-    model: Model = _init_model_on_meta_device(builder_args)
-
+def _load_checkpoint(builder_args: BuilderArgs):
     if builder_args.params_table and builder_args.params_table.endswith("Tune"):
         print("Loading Tune checkpoint")
         meta_checkpoint = torch.load(
@@ -377,6 +373,16 @@ def _load_model_default(builder_args: BuilderArgs) -> Model:
         mmap=True,
         weights_only=True,
     )
+    return checkpoint
+
+
+def _load_model_default(builder_args: BuilderArgs) -> Model:
+    assert not builder_args.gguf_path
+
+    model: Model = _init_model_on_meta_device(builder_args)
+
+    # Load checkpoint from filesystem
+    checkpoint = _load_checkpoint(builder_args)

     if "model" in checkpoint and "stories" in str(builder_args.checkpoint_path):
         checkpoint = checkpoint["model"]
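
The extraction above makes checkpoint loading reusable outside _load_model_default. A minimal sketch of calling it directly, assuming only the three BuilderArgs fields that the distributed loader (next file) passes; the checkpoint path is hypothetical, adjust it to your cache layout:

from pathlib import Path

from torchchat.cli.builder import BuilderArgs, _load_checkpoint

builder_args = BuilderArgs(
    device="cpu",
    checkpoint_path=Path(
        "~/.torchchat/model-cache/meta-llama/Meta-Llama-3-8B-Instruct/model.pth"
    ).expanduser(),  # hypothetical single-file checkpoint location
    checkpoint_dir=None,
)
checkpoint = _load_checkpoint(builder_args)  # a plain dict of parameter tensors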

torchchat/distributed/safetensor_utils.py

Lines changed: 84 additions & 12 deletions
@@ -12,9 +12,11 @@
 import json
 from torch.nn import Module
 from typing import Any, Dict, Tuple, Set, Optional
+from pathlib import Path

 from torch.distributed._tensor import DTensor
 from torchchat.distributed.dtensor_utils import convert_to_dtensor
+from torchchat.cli.builder import BuilderArgs, _load_checkpoint


 _DEFAULT_SAFETENSOR_FILE_NAME = "model.safetensors.index.json"
@@ -182,10 +184,10 @@ def load_safetensor_weights(
            update_state_dict(
                stage_state_dict,
                checkpoint,
-                new_to_old_keymap,
-                updated_states,
                device,
-                model_config,
+                model_config=model_config,
+                new_to_old_keymap=new_to_old_keymap,
+                updated_states=updated_states,
            )
        except FileNotFoundError:
            logger.error(f"File not found: {full_path}")
@@ -264,24 +266,36 @@ def permute_weight_to_attn_heads(w, n_heads, head_dim, model_dim):
 def update_state_dict(
     state_dict: Dict[str, torch.Tensor],
     checkpoint: Dict[str, torch.Tensor],
-    new_to_old_keymap: Dict[str, str],
-    updated_states: Set[str],
     device: torch.device,
     model_config: Optional[Dict] = None,
+    new_to_old_keymap: Optional[Dict[str, str]] = None,
+    updated_states: Optional[Set[str]] = None,
 ):
+    """
+    Update the state dict with the checkpoint tensors.
+    Note:
+    - For HF format, `new_to_old_keymap` is a mapping from the new key to the
+      old key.
+    - For torchchat format, `new_to_old_keymap` is None (because FQN conversion
+      has been done by the torchchat download script).
+    """
     # for handling attn head permuting
     num_heads = model_config.n_heads
     dim = model_config.dim
     num_local_heads = model_config.n_local_heads
     head_dim = model_config.head_dim

     for param in state_dict.keys():
-        # TODO: clean this up together with `purge_fqn_prefix` when we switch
-        # from creating Transformer to creating model
-        model_param = (
-            "output.weight" if param == "output.weight" else f"model.{param}"
-        )
-        old_param = new_to_old_keymap.get(model_param)
+        if new_to_old_keymap is not None:
+            # TODO: clean the following manual prefix together with
+            # `purge_fqn_prefix` when we switch from creating Transformer to
+            # creating model
+            model_param = (
+                "output.weight" if param == "output.weight" else f"model.{param}"
+            )
+            old_param = new_to_old_keymap[model_param]
+        else:
+            old_param = param

         if old_param not in checkpoint:
             # Maybe this param is in other files
@@ -309,7 +323,9 @@ def update_state_dict(

         # Update model state dict with checkpoint tensor
         state_dict[param] = checkpoint_tensor
-        updated_states.add(param)
+
+        if updated_states is not None:
+            updated_states.add(param)


 def format_tensor_info(tensor: torch.Tensor) -> str:
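
With new_to_old_keymap and updated_states now optional keyword arguments, the two checkpoint formats call update_state_dict differently. A schematic comparison, taken from the call sites in this commit:

# HF format: checkpoint keys must be remapped, so a keymap is passed.
update_state_dict(
    stage_state_dict,
    checkpoint,
    device,
    model_config=model_config,
    new_to_old_keymap=new_to_old_keymap,
    updated_states=updated_states,
)

# torchchat format: FQNs were already converted by the download script,
# so no keymap; checkpoint keys are used as-is (old_param = param).
update_state_dict(
    stage_state_dict,
    checkpoint,
    device,
    model_config=model_config,
    updated_states=updated_states,
)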
@@ -378,3 +394,59 @@ def load_weights_from_hf_format(stage_module, distribution, device, model_config
     )
     if num_missing_weights > 0:
         raise ValueError(f"Missing {num_missing_weights} weights")
+
+
+# HACK: assuming single file for torchchat's converted checkpoints. We should
+# remove this after converging to torchchat's model building process.
+# In particular,
+#     builder_args = BuilderArgs.from_args(args)
+# will tell us if there is a single file or a directory.
+TORCHCHAT_SINGLE_FILE_CHECKPOINT = True
+
+def load_weights_from_torchchat_format(stage_module, distribution, device, model_config):
+    """
+    Load the weights from torchchat format (single binary file), and fill into
+    `stage_module`. Model config is needed b/c we permute wq and wk weights
+    based on attn heads.
+    """
+    stage_state_dict = stage_module.state_dict()
+    # TODO: clean this up together with `purge_fqn_prefix` when we switch from creating Transformer to creating model
+    stage_state_dict = purge_fqn_prefix(stage_state_dict, "model.")
+
+    # Load checkpoint from torchchat cache
+    default_cache_dir = Path(
+        os.getenv("TORCHCHAT_MODELDIR", "~/.torchchat/model-cache")
+    ).expanduser()
+    # Distribution is like "meta-llama/Meta-Llama-3-8B-Instruct".
+    # Join it with the default cache dir to get the checkpoint dir.
+    checkpoint_dir = default_cache_dir / distribution
+    # Provide a path in the single-file case, a dir in the multi-file case. See
+    # `_load_checkpoint`.
+    if TORCHCHAT_SINGLE_FILE_CHECKPOINT:
+        checkpoint_path = checkpoint_dir / "model.pth"
+        checkpoint_dir = None
+    else:
+        checkpoint_path = None
+    # First, construct BuilderArgs
+    args_dict = {
+        "device": device,
+        "checkpoint_dir": checkpoint_dir,
+        "checkpoint_path": checkpoint_path,
+    }
+    builder_args = BuilderArgs(**args_dict)
+    # Then, load the checkpoint using torchchat util
+    checkpoint = _load_checkpoint(builder_args)
+
+    updated_states: Set[str] = set()
+    # This step converts full tensors into DTensors
+    update_state_dict(
+        stage_state_dict,
+        checkpoint,
+        device,
+        model_config=model_config,
+        updated_states=updated_states,
+    )
+
+    # Fill state dict into stage module
+    stage_module.load_state_dict(stage_state_dict, strict=False, assign=True)
+    logger.info(f"Successfully loaded {len(updated_states)} weights into stage module")
