
Commit 8990f41

Remove weight map and file from update_state_dict
1 parent 2f6b296 commit 8990f41

File tree

1 file changed: +41 -51 lines changed


torchchat/distributed/safetensor_utils.py

Lines changed: 41 additions & 51 deletions
@@ -166,8 +166,9 @@ def load_safetensor_weights(
         Tuple[int, int]: Number of updated weights and number of missing weights.
     """
     stage_state_dict = stage_module.state_dict()
-    stage_state_dict = purge_fqn_prefix(stage_state_dict, "model.")
-    weight_map = purge_fqn_prefix(weight_map, "model.")
+    if purge_model_prefix:
+        stage_state_dict = purge_fqn_prefix(stage_state_dict, "model.")
+        weight_map = purge_fqn_prefix(weight_map, "model.")
 
     needed_files = get_needed_files(stage_state_dict, weight_map)
     updated_states: Set[str] = set()
@@ -181,9 +182,7 @@ def load_safetensor_weights(
             update_state_dict(
                 stage_state_dict,
                 checkpoint,
-                weight_map,
                 new_to_old_keymap,
-                file,
                 updated_states,
                 device,
                 model_config,
@@ -216,10 +215,13 @@ def load_safetensor_weights(
     return len(updated_states), len(missing_keys)
 
 
+# TODO: clean this up together with `purge_fqn_prefix` when we switch
+# from creating Transformer to creating model
 def purge_fqn_prefix(
     any_dict: Dict[str, Any],
     prefix: str,
-) -> Dict[str, torch.Tensor]:
+) -> Dict[str, Any]:
+    """Remove a prefix from all keys in a dictionary."""
     return {k.removeprefix(prefix): v for k, v in any_dict.items()}
 
 
@@ -262,64 +264,52 @@ def permute_weight_to_attn_heads(w, n_heads, head_dim, model_dim):
 def update_state_dict(
     state_dict: Dict[str, torch.Tensor],
     checkpoint: Dict[str, torch.Tensor],
-    weight_map: Dict[str, str],
     new_to_old_keymap: Dict[str, str],
-    file: str,
     updated_states: Set[str],
     device: torch.device,
     model_config: Optional[Dict] = None,
 ):
-    count_dtensors_loaded = 0
     # for handling attn head permuting
     num_heads = model_config.n_heads
    dim = model_config.dim
     num_local_heads = model_config.n_local_heads
     head_dim = model_config.head_dim
 
-    for param, file_with_param in weight_map.items():
-        if file_with_param == file and param in state_dict:
-            model_param = (
-                "output.weight" if param == "output.weight" else f"model.{param}"
+    for param in state_dict.keys():
+        # TODO: clean this up together with `purge_fqn_prefix` when we switch
+        # from creating Transformer to creating model
+        model_param = (
+            "output.weight" if param == "output.weight" else f"model.{param}"
+        )
+        old_param = new_to_old_keymap.get(model_param)
+
+        if old_param not in checkpoint:
+            # Maybe this param is in other files
+            continue
+
+        checkpoint_tensor = checkpoint[old_param]
+        model_tensor = state_dict[param]
+
+        if "wq" in param:
+            checkpoint_tensor = permute_weight_to_attn_heads(
+                checkpoint_tensor, num_heads, head_dim, dim
+            )
+        elif "wk" in param:
+            checkpoint_tensor = permute_weight_to_attn_heads(
+                checkpoint_tensor, num_local_heads, head_dim, dim
            )
-            old_param = new_to_old_keymap.get(model_param)
-
-            if old_param not in checkpoint:
-                logger.warning(f"Missing {old_param} in checkpoint")
-                continue
-
-            checkpoint_tensor = checkpoint[old_param]
-            model_tensor = state_dict[param]
-
-            if "wq" in param:
-                checkpoint_tensor = permute_weight_to_attn_heads(
-                    checkpoint_tensor, num_heads, head_dim, dim
-                )
-            elif "wk" in param:
-                checkpoint_tensor = permute_weight_to_attn_heads(
-                    checkpoint_tensor, num_local_heads, head_dim, dim
-                )
-
-            # Move checkpoint tensor to desired device
-            checkpoint_tensor = checkpoint_tensor.to(device)
-
-            # here we need to check if the tensor is a DTensor and if so, adjust the
-            # shape and placement to match the model DTensor.
-            if isinstance(model_tensor, DTensor):
-                state_dict[param] = convert_to_dtensor(checkpoint_tensor, model_tensor)
-                count_dtensors_loaded += 1
-            else:
-                # regular tensor, just update directly
-                state_dict[param] = checkpoint_tensor
-
-            # ensure matching dtypes
-            state_dict[param] = state_dict[param].to(checkpoint_tensor.dtype)
-
-            assert state_dict[param].dtype == checkpoint_tensor.dtype
-
-            # log_tensor_info(param, state_dict[param])
-            # logger.info(f"Loaded {param} from {file}")
-            updated_states.add(param)
-    # logger.info(f"Count of loaded DTensors: {count_dtensors_loaded}")
+
+        # Move checkpoint tensor to desired device
+        checkpoint_tensor = checkpoint_tensor.to(device)
+
+        # here we need to check if the tensor is a DTensor and if so, adjust the
+        # shape and placement to match the model DTensor.
+        if isinstance(model_tensor, DTensor):
+            checkpoint_tensor = convert_to_dtensor(checkpoint_tensor, model_tensor)
+
+        # Update model state dict with checkpoint tensor
+        state_dict[param] = checkpoint_tensor
+        updated_states.add(param)
 
 
 def format_tensor_info(tensor: torch.Tensor) -> str:
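
For context, the sketch below (not part of the commit) shows how the slimmed-down update_state_dict could be driven by a caller after this change. The names purge_fqn_prefix and update_state_dict come from the diff above; the load_stage_weights wrapper, the explicit files list, the import path, and the use of safetensors.torch.load_file are illustrative assumptions, not code from the repository.

from typing import Any, Dict, List, Set

import torch
from safetensors.torch import load_file

# Assumed import path; both helpers are defined in
# torchchat/distributed/safetensor_utils.py per the diff above.
from torchchat.distributed.safetensor_utils import purge_fqn_prefix, update_state_dict


def load_stage_weights(
    stage_module: torch.nn.Module,
    new_to_old_keymap: Dict[str, str],
    files: List[str],
    model_config: Any,
    device: torch.device,
) -> Set[str]:
    # Strip the "model." prefix so stage-module keys match checkpoint naming,
    # mirroring the purge_model_prefix branch in load_safetensor_weights.
    state_dict = purge_fqn_prefix(stage_module.state_dict(), "model.")

    updated: Set[str] = set()
    for file in files:
        # Load one safetensors shard to CPU; update_state_dict moves tensors
        # to the target device itself.
        checkpoint: Dict[str, torch.Tensor] = load_file(file, device="cpu")
        # New signature: no weight_map or file argument. Params absent from this
        # shard are skipped and can be filled in by a later shard.
        update_state_dict(
            state_dict,
            checkpoint,
            new_to_old_keymap,
            updated,
            device,
            model_config,
        )
    return updated

Note that with this change the per-file weight_map filtering (and the "Missing ... in checkpoint" warning) is gone from update_state_dict: a parameter not found in the current shard is simply skipped on the assumption that another file supplies it, and load_safetensor_weights still returns the count of keys that remain missing at the end.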
