
Commit 9c86055 (parent: 9ef2171)

[Distributed] Fix correctness issue in TC load path

3 files changed: +72, -64 lines


dist_run.py (26 additions, 9 deletions)
@@ -125,18 +125,35 @@ def _load_model_weights(
         distribution (str): The distribution name, e.g. "meta-llama/Meta-Llama-3-8B-Instruct".
         device (torch.device): The device to load the weights onto.
         model_config (ModelArgs): The model config.
-        chpt_from (str): The checkpoint format to load the weights from, e.g. "torchchat" or "hf".
+        chpt_from (str): The checkpoint format to load the weights from, e.g. "tc" or "hf".
     """
-    if chpt_from == "hf":
+    # Valid chpt_from values: hf, tc,
+    # or if you want to load from a specific dir, e.g.
+    # tc:meta-llama/Meta-Llama-3-8B-Instruct-int8_wo
+    str_lst = chpt_from.split(":")
+    assert len(str_lst) == 1 or len(str_lst) == 2, "Invalid --chpt_from format"
+    chpt_format = str_lst[0]  # hf or tc
+    chpt_distribution = str_lst[1] if len(str_lst) == 2 else distribution
+
+    stage_state_dict = stage_module.state_dict()
+    if chpt_format == "hf":
         # This format stands for: index file + multiple binary files
-        load_weights_from_hf_format(stage_module, distribution, device, model_config)
-    elif chpt_from == "torchchat":
+        stage_state_dict = load_weights_from_hf_format(
+            stage_state_dict, chpt_distribution, device, model_config
+        )
+    elif chpt_format == "tc":
         # This format stands for:
         # single binary file, OR
         # multiple binary files without index files.
-        load_weights_from_torchchat_format(stage_module, distribution, device, model_config)
+        stage_state_dict = load_weights_from_torchchat_format(
+            stage_state_dict, chpt_distribution, device, model_config
+        )
     else:
-        raise ValueError(f"Unknown checkpoint format: {chpt_from}")
+        raise ValueError(f"Unknown checkpoint format: {chpt_format}")
+
+    # Fill state dict into stage module
+    stage_module.load_state_dict(stage_state_dict, strict=False, assign=True)
+    logger.info(f"Successfully loaded weights into stage module")


 def _encode_strings(
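
The hunk above changes `_load_model_weights` from loading weights directly into the module to building up a plain state dict and installing it once at the end. The new `--chpt-from` value format can be exercised on its own; below is a minimal standalone sketch of the parsing step (the helper name `parse_chpt_from` is ours, not the commit's):

    def parse_chpt_from(chpt_from: str, default_distribution: str) -> tuple:
        """Split a --chpt-from value into (format, distribution)."""
        str_lst = chpt_from.split(":")
        assert len(str_lst) in (1, 2), "Invalid --chpt_from format"
        chpt_format = str_lst[0]  # "hf" or "tc"
        chpt_distribution = str_lst[1] if len(str_lst) == 2 else default_distribution
        return chpt_format, chpt_distribution

    # "tc" keeps the default distribution; "tc:<distribution>" overrides it:
    assert parse_chpt_from("tc", "meta-llama/Meta-Llama-3-8B-Instruct") == (
        "tc", "meta-llama/Meta-Llama-3-8B-Instruct"
    )
    assert parse_chpt_from("tc:meta-llama/Meta-Llama-3-8B-Instruct-int8_wo", "x") == (
        "tc", "meta-llama/Meta-Llama-3-8B-Instruct-int8_wo"
    )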
@@ -589,9 +606,9 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     parser.add_argument(
         "--chpt-from",
         type=str,
-        default="hf",  # TODO: change to torchchat once we support it well
-        help="Checkpoint format to load from",
-        choices=["hf", "torchchat"],
+        default="tc",
+        help="Checkpoint to load from, e.g. `hf` or `tc`, or "
+        "`tc:meta-llama/Meta-Llama-3-8B-Instruct-int8_wo`",
     )
     args = parser.parse_args()

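Dropping `choices=["hf", "torchchat"]` is what makes the free-form `tc:<distribution>` value possible, since a fixed choices list would reject it. A quick standalone check of the new flag (parser reduced to this one argument; the sample value comes from the help text above):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--chpt-from",
        type=str,
        default="tc",
        help="Checkpoint to load from, e.g. `hf` or `tc`, or "
        "`tc:meta-llama/Meta-Llama-3-8B-Instruct-int8_wo`",
    )
    args = parser.parse_args(
        ["--chpt-from", "tc:meta-llama/Meta-Llama-3-8B-Instruct-int8_wo"]
    )
    assert args.chpt_from.split(":")[0] == "tc"  # argparse maps --chpt-from to chpt_from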
torchchat.py (6 additions, 0 deletions)
@@ -49,6 +49,7 @@
     "where": "Return directory containing downloaded model artifacts",
     "server": "[WIP] Starts a locally hosted REST server for model interaction",
     "eval": "Evaluate a model via lm-eval",
+    "save_quant": "Save a quantized model",
 }
 for verb, description in VERB_HELP.items():
     subparser = subparsers.add_parser(verb, help=description)
@@ -115,5 +116,10 @@
         from torchchat.cli.download import remove_main

         remove_main(args)
+    elif args.command == "save_quant":
+        # check_args(args, "save_quant")
+        from torchchat.save_quant import main as save_quant_main
+
+        save_quant_main(args)
     else:
         parser.print_help()
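
`save_quant` follows the same dispatch pattern as the other verbs: register the verb in `VERB_HELP`, then import its `main` lazily only when that subcommand is chosen, so other commands don't pay the import cost. A reduced standalone sketch of the pattern (verb set trimmed; the print stands in for the real `save_quant_main`):

    import argparse

    VERB_HELP = {
        "eval": "Evaluate a model via lm-eval",
        "save_quant": "Save a quantized model",
    }

    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(dest="command")
    for verb, description in VERB_HELP.items():
        subparsers.add_parser(verb, help=description)

    args = parser.parse_args(["save_quant"])
    if args.command == "save_quant":
        # The real code does: from torchchat.save_quant import main as save_quant_main
        print(f"would dispatch save_quant with {args!r}")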

torchchat/distributed/checkpoint_utils.py (40 additions, 55 deletions)
@@ -118,14 +118,12 @@ def remap_weight_keys(dictionary):


 def load_weights_per_map(
-    stage_module: Module,
+    stage_state_dict,
     weight_map: Dict[str, str],
     file_location: str,
     new_to_old_keymap: Dict[str, str],
     device: torch.device,
     is_safetensor: bool,
-    purge_model_prefix: bool = True,
-    ignore_cache_layers: bool = True,
     model_config: Optional[Dict] = None,
 ) -> Tuple[int, int]:
     """
@@ -138,18 +136,11 @@ def load_weights_per_map(
         new_to_old_keymap (Dict[str, str]): Mapping of new parameter names to old ones.
         device (torch.device): The device to load tensors onto.
         is_safetensor (bool): Whether the files are safetensors.
-        purge_model_prefix (bool): Whether to remove 'model.' prefix from keys.
-        ignore_cache_layers (bool): Whether to ignore cache layers when reporting missing keys.
         model_config (Optional[Dict]): Model configuration.

     Returns:
         Tuple[int, int]: Number of updated weights and number of missing weights.
     """
-    stage_state_dict = stage_module.state_dict()
-    if purge_model_prefix:
-        stage_state_dict = purge_fqn_prefix(stage_state_dict, "model.")
-        weight_map = purge_fqn_prefix(weight_map, "model.")
-
     needed_files = get_needed_files(stage_state_dict, weight_map)
     updated_states: Set[str] = set()

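With `purge_model_prefix` gone, prefix handling moves out of `load_weights_per_map` and into the format-specific loaders (see the `load_weights_from_hf_format` hunk below), so this function now receives an already-purged state dict. `purge_fqn_prefix` itself is not shown in this diff; judging from its call sites, it plausibly strips a leading prefix from every key, along these lines (an assumption, not the repository's implementation):

    from typing import Dict, TypeVar

    V = TypeVar("V")

    def purge_fqn_prefix(mapping: Dict[str, V], prefix: str) -> Dict[str, V]:
        # Drop `prefix` (e.g. "model.") from any fully-qualified name that has it.
        return {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in mapping.items()
        }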
@@ -175,27 +166,9 @@
         logger.error(f"Error during checkpoint processing:")
         raise e

-    missing_keys = handle_missing_keys(
-        stage_state_dict, updated_states, ignore_cache_layers
+    check_for_missing_keys(
+        stage_state_dict, updated_states, ignore_cache_layers=True
     )
-    # log_loading_status(missing_keys, updated_states)
-    if missing_keys:
-        logger.warning(
-            f"Partially updated state dict. Missing {len(missing_keys)} keys: {missing_keys}"
-        )
-    else:
-        logger.info("Fully updated state dict.")
-
-    logger.info(f"Loading {len(updated_states)} weights into stage dict")
-    # precount, premap = record_module_dtypes(stage_module)
-    stage_module.load_state_dict(stage_state_dict, strict=False, assign=True)
-    # postcount, postmap = record_module_dtypes(stage_module)
-    # logger.info(f"{precount=}, {postcount=}")
-    # logger.info(f"{premap=}, {postmap=}")
-
-    logger.info(f"Successfully loaded {len(updated_states)} weights into stage module")
-
-    return len(updated_states), len(missing_keys)


 # TODO: clean this up together with `purge_fqn_prefix` when we switch
@@ -287,14 +260,15 @@ def update_state_dict(
         checkpoint_tensor = checkpoint[old_param]
         model_tensor = state_dict[param]

-        if "wq" in param:
-            checkpoint_tensor = permute_weight_to_attn_heads(
-                checkpoint_tensor, num_heads, head_dim, dim
-            )
-        elif "wk" in param:
-            checkpoint_tensor = permute_weight_to_attn_heads(
-                checkpoint_tensor, num_local_heads, head_dim, dim
-            )
+        if new_to_old_keymap is not None:
+            if "wq" in param:
+                checkpoint_tensor = permute_weight_to_attn_heads(
+                    checkpoint_tensor, num_heads, head_dim, dim
+                )
+            elif "wk" in param:
+                checkpoint_tensor = permute_weight_to_attn_heads(
+                    checkpoint_tensor, num_local_heads, head_dim, dim
+                )

         # Move checkpoint tensor to desired device
         checkpoint_tensor = checkpoint_tensor.to(device)
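
This guard appears to be the correctness fix named in the commit title: `new_to_old_keymap` is populated on the HF load path but not on the torchchat ("tc") path, so the wq/wk permutation now runs only for HF checkpoints, whose rotary layout differs from the one the model expects. Permuting a tc checkpoint a second time would scramble the attention weights. `permute_weight_to_attn_heads` is defined elsewhere in the repo; the standard Llama HF-to-native permutation looks like this (a sketch under that assumption, not a verbatim copy):

    import torch

    def permute_weight_to_attn_heads(w, n_heads, head_dim, dim):
        # Reorder rows so interleaved rotary (real, imag) pairs become the
        # half-split layout used by the native attention implementation.
        return (
            w.view(n_heads, 2, head_dim // 2, dim)
            .transpose(1, 2)
            .reshape(n_heads * head_dim, dim)
        )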
@@ -324,10 +298,10 @@ def clean_cache_keys(input_set: Set[str]) -> Set[str]:
     }


-def handle_missing_keys(
+def check_for_missing_keys(
     state_dict: Dict[str, torch.Tensor],
     updated_states: Set[str],
-    ignore_cache_layers: bool,
+    ignore_cache_layers: bool = True,
 ) -> Set[str]:
     """This function handles 'expected' missing keys from the checkpoint update set.
     This is used for ignoring cache, rope freqs, and mask layers that are generated, rather than persisted
@@ -342,7 +316,13 @@
         logger.info(
             f"Ignoring {start_len - after_len} missing cache, freqs, mask layers"
         )
-    return missing_keys
+
+    if len(missing_keys) > 0:
+        from itertools import islice
+        raise RuntimeError(
+            f"Missing {len(missing_keys)} weights, for example: "
+            f"{list(islice(missing_keys, 10))}"
+        )


 def log_loading_status(missing_keys: Set[str], updated_states: Set[str]):
@@ -355,10 +335,10 @@ def log_loading_status(missing_keys: Set[str], updated_states: Set[str]):
     logger.info(f"Successfully loaded {len(updated_states)} weights into stage module")


-def load_weights_from_hf_format(stage_module, distribution, device, model_config):
+def load_weights_from_hf_format(stage_state_dict, distribution, device, model_config):
     """
     Load the weights from Hugging Face format (index file + multiple safetensor
-    files), and fill into `stage_module`. Model config is needed b/c we permute
+    files), and fill into `stage_state_dict`. Model config is needed b/c we permute
     wq and wk weights based on attn heads.
     """
     # Get the weight map for a given HF model id
@@ -382,21 +362,21 @@ def load_weights_from_hf_format(stage_module, distribution, device, model_config
     weight_dir = os.path.dirname(index_file)
     logger.info(f"Loading weights from: {weight_dir}")

+    # TODO: clean this up together with `purge_fqn_prefix` when we switch
+    stage_state_dict = purge_fqn_prefix(stage_state_dict, "model.")
+    weight_map = purge_fqn_prefix(weight_map, "model.")
+
     # Load the weights into the stage module
-    num_loaded_weights, num_missing_weights = load_weights_per_map(
-        stage_module,
+    load_weights_per_map(
+        stage_state_dict,
         weight_map,
         weight_dir,
         new_to_old_keymap,
         device,
         is_safetensor,
         model_config=model_config,
     )
-    logger.info(
-        f"Success - Loaded {num_loaded_weights} weights, {num_missing_weights} missing weights"
-    )
-    if num_missing_weights > 0:
-        raise ValueError(f"Missing {num_missing_weights} weights")
+    return stage_state_dict


 # HACK: assuming single file for torchchat's converted checkpoints. We should
@@ -406,13 +386,12 @@ def load_weights_from_hf_format(stage_module, distribution, device, model_config
 # will tell us if there is a single file or a directory.
 TORCHCHCAT_SINGLE_FILE_CHECKPOINT = True

-def load_weights_from_torchchat_format(stage_module, distribution, device, model_config):
+def load_weights_from_torchchat_format(stage_state_dict, distribution, device, model_config):
     """
     Load the weights from torchchat format (single binary file), and fill into
     `stage_module`. Model config is needed b/c we permute wq and wk weights
     based on attn heads.
     """
-    stage_state_dict = stage_module.state_dict()
     # TODO: clean this up together with `purge_fqn_prefix` when we switch
     stage_state_dict = purge_fqn_prefix(stage_state_dict, "model.")

@@ -437,6 +416,10 @@ def load_weights_from_torchchat_format(stage_module, distribution, device, model
         "checkpoint_path": checkpoint_path,
     }
     builder_args = BuilderArgs(**args_dict)
+    logger.info(
+        "Loading checkpoint from: "
+        f"{builder_args.checkpoint_dir or builder_args.checkpoint_path}"
+    )
     # Then, load the checkpoint using torchchat util
     checkpoint = _load_checkpoint(builder_args)

@@ -450,6 +433,8 @@ def load_weights_from_torchchat_format(stage_module, distribution, device, model
         updated_states=updated_states,
     )

-    # Fill state dict into stage module
-    stage_module.load_state_dict(stage_state_dict, strict=False, assign=True)
-    logger.info(f"Successfully loaded {len(updated_states)} weights into stage module")
+    check_for_missing_keys(
+        stage_state_dict, updated_states, ignore_cache_layers=True
+    )
+
+    return stage_state_dict
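
The net effect across both loaders: they now take and return a plain state dict, missing-key checking happens inside them via `check_for_missing_keys`, and the single `load_state_dict` call lives in `dist_run.py`. The resulting caller contract, sketched end to end (names are from the diff; the distribution string and surrounding objects are placeholders):

    stage_state_dict = stage_module.state_dict()
    stage_state_dict = load_weights_from_torchchat_format(
        stage_state_dict, "meta-llama/Meta-Llama-3-8B-Instruct-int8_wo",
        device, model_config,
    )
    # assign=True swaps the checkpoint tensors into the module rather than
    # copying into preallocated parameters, so the checkpoint's dtypes
    # (e.g. int8 weight-only quantized tensors) survive the load.
    stage_module.load_state_dict(stage_state_dict, strict=False, assign=True)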
