Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 2f6b296

Browse files
committed
Add purge_fqn_prefix
1 parent 33d8d71 commit 2f6b296

File tree

1 file changed

+11
-13
lines changed

1 file changed

+11
-13
lines changed

torchchat/distributed/safetensor_utils.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import os
1212
import json
1313
from torch.nn import Module
14-
from typing import Dict, Tuple, Set, Optional
14+
from typing import Any, Dict, Tuple, Set, Optional
1515

1616
from torch.distributed._tensor import DTensor
1717
from torchchat.distributed.dtensor_utils import convert_to_dtensor
@@ -165,17 +165,18 @@ def load_safetensor_weights(
165165
Returns:
166166
Tuple[int, int]: Number of updated weights and number of missing weights.
167167
"""
168-
stage_state_dict, weight_map = prepare_state_dict(
169-
stage_module, weight_map, purge_model_prefix
170-
)
168+
stage_state_dict = stage_module.state_dict()
169+
stage_state_dict = purge_fqn_prefix(stage_state_dict, "model.")
170+
weight_map = purge_fqn_prefix(weight_map, "model.")
171+
171172
needed_files = get_needed_files(stage_state_dict, weight_map)
172173
updated_states: Set[str] = set()
173174

174175
for file in needed_files:
175176
full_path = os.path.join(file_location, file)
176177
# logger.info(f"Loading checkpoint file: {full_path}")
177178
try:
178-
checkpoint = load_checkpoint(full_path, "cpu") # device)
179+
checkpoint = load_safetensor_file(full_path, "cpu") # device)
179180

180181
update_state_dict(
181182
stage_state_dict,
@@ -215,14 +216,11 @@ def load_safetensor_weights(
215216
return len(updated_states), len(missing_keys)
216217

217218

218-
def prepare_state_dict(
219-
module: Module, weight_map: Dict[str, str], purge_model_prefix: bool
219+
def purge_fqn_prefix(
    any_dict: Dict[str, Any],
    prefix: str,
) -> Dict[str, Any]:
    """Return a copy of *any_dict* with *prefix* stripped from each key.

    Generic over values: used both for module state dicts (str -> Tensor)
    and for weight maps (str -> checkpoint filename), hence the
    ``Dict[str, Any]`` return type (the original ``Dict[str, torch.Tensor]``
    annotation was wrong for the weight-map call site).

    Keys that do not start with *prefix* are kept unchanged
    (``str.removeprefix`` semantics). Values are passed through untouched.

    Args:
        any_dict: Mapping whose keys are fully-qualified names (FQNs).
        prefix: Prefix to strip from each key, e.g. ``"model."``.

    Returns:
        A new dict with the same values and prefix-stripped keys.
    """
    return {k.removeprefix(prefix): v for k, v in any_dict.items()}
226224

227225

228226
def get_needed_files(
@@ -242,7 +240,7 @@ def get_needed_files(
242240
return needed_files
243241

244242

245-
def load_checkpoint(full_path: str, device: torch.device) -> Dict[str, torch.Tensor]:
243+
def load_safetensor_file(full_path: str, device: torch.device) -> Dict[str, torch.Tensor]:
246244
tensors = {}
247245
with safe_open(full_path, framework="pt", device=device) as f:
248246
for k in f.keys():

0 commit comments

Comments
 (0)