Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 3f2cf36

Browse files
committed
remove model.model in ckpt loading
1 parent 8d01d9b commit 3f2cf36

File tree

3 files changed

+13
-6
lines changed

3 files changed

+13
-6
lines changed

install/install_requirements.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ PYTORCH_NIGHTLY_VERSION=dev20240814
5353
VISION_NIGHTLY_VERSION=dev20240814
5454

5555
# Nightly version for torchtune
56-
TUNE_NIGHTLY_VERSION=dev20240916
56+
TUNE_NIGHTLY_VERSION=dev20240918
5757

5858

5959
# Uninstall triton, as nightly will depend on pytorch-triton, which is one and the same

torchchat/cli/builder.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -388,11 +388,9 @@ def _load_model_default(builder_args, only_config=False):
388388
builder_args.device
389389
):
390390
model = Model.from_params(builder_args.params_path)
391-
state_dict = flamingo_meta_to_tune(checkpoint)
392-
model.model.load_state_dict(state_dict)
393-
else:
394-
checkpoint = {"model." + k: v for k, v in checkpoint.items()}
395-
model.load_state_dict(checkpoint, assign=True, strict=True)
391+
checkpoint = flamingo_meta_to_tune(checkpoint)
392+
393+
model.load_state_dict(checkpoint, assign=True, strict=True)
396394

397395
return model
398396

torchchat/model.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,8 @@ def __init__(self, config: ModelArgs) -> None:
442442
# It should be assigned in the actual model implementation, if any.
443443
self.text_transformer_args = None
444444

445+
self._register_load_state_dict_pre_hook(self._load_model_state_dict)
446+
445447
def build_model(self) -> nn.Module:
446448
"""
447449
Builds a model based on the provided configuration.
@@ -468,6 +470,13 @@ def _replace_known_params(self, params):
468470
params[key] = patterns[value]
469471
return params
470472

473+
def _load_model_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
474+
# 修改 state dict 中的键值
475+
for key in list(state_dict.keys()):
476+
new_key = 'model.' + key
477+
state_dict[new_key] = state_dict.pop(key)
478+
return state_dict
479+
471480
@abstractmethod
472481
def forward(self, *args, **kwargs):
473482
raise NotImplementedError("forward method is not implemented")

0 commit comments

Comments
 (0)