Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions torchchat/cli/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,11 +388,9 @@ def _load_model_default(builder_args, only_config=False):
builder_args.device
):
model = Model.from_params(builder_args.params_path)
state_dict = flamingo_meta_to_tune(checkpoint)
model.model.load_state_dict(state_dict)
else:
checkpoint = {"model." + k: v for k, v in checkpoint.items()}
model.load_state_dict(checkpoint, assign=True, strict=True)
checkpoint = flamingo_meta_to_tune(checkpoint)

model.load_state_dict(checkpoint, assign=True, strict=True)

return model

Expand Down
9 changes: 9 additions & 0 deletions torchchat/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,8 @@ def __init__(self, config: ModelArgs) -> None:
# It should be assigned in the actual model implementation, if any.
self.text_transformer_args = None

self._register_load_state_dict_pre_hook(self._load_model_state_dict)

def build_model(self) -> nn.Module:
"""
Builds a model based on the provided configuration.
Expand All @@ -468,6 +470,13 @@ def _replace_known_params(self, params):
params[key] = patterns[value]
return params

def _load_model_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment on what the hook does

# 修改 state dict 中的键值
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Multilingual comments

for key in list(state_dict.keys()):
new_key = 'model.' + key
state_dict[new_key] = state_dict.pop(key)
return state_dict

@abstractmethod
def forward(self, *args, **kwargs):
raise NotImplementedError("forward method is not implemented")
Expand Down
1 change: 0 additions & 1 deletion torchchat/utils/gguf_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def _convert_gguf_tensor_name_to_llama_nn(gguf_name: str) -> str:
result = copy.deepcopy(gguf_name)
for gguf_string, replacement in _name_replacements:
result = result.replace(gguf_string, replacement)
result = "model." + result
return result


Expand Down
Loading