Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions torchchat/cli/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,11 +388,9 @@ def _load_model_default(builder_args: BuilderArgs) -> Model:
builder_args.device
):
model = Model.from_params(builder_args.params_path)
state_dict = flamingo_meta_to_tune(checkpoint)
model.model.load_state_dict(state_dict)
else:
checkpoint = {"model." + k: v for k, v in checkpoint.items()}
model.load_state_dict(checkpoint, assign=True, strict=True)
checkpoint = flamingo_meta_to_tune(checkpoint)

model.load_state_dict(checkpoint, assign=True, strict=True)

return model

Expand Down
31 changes: 31 additions & 0 deletions torchchat/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,8 @@ def __init__(self, config: ModelArgs) -> None:
# It should be assigned in the actual model implementation, if any.
self.text_transformer_args = None

self._register_load_state_dict_pre_hook(self._load_model_state_dict)

def build_model(self) -> nn.Module:
"""
Builds a model based on the provided configuration.
Expand All @@ -468,6 +470,35 @@ def _replace_known_params(self, params):
params[key] = patterns[value]
return params

def _load_model_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
"""
Updates the loaded internal model state dictionary to match the Model class structure.
Note that this is a temporary solution and will be removed once the model structure is finalized.
Args:
state_dict (dict): The state dictionary to load.
prefix (str): The prefix of the model.
local_metadata (dict): Local metadata.
strict (bool): Whether to strictly enforce that the keys in the state dictionary match the keys in the model.
missing_keys (list): List of missing keys.
unexpected_keys (list): List of unexpected keys.
error_msgs (list): List of error messages.
Returns:
dict: The updated state dictionary.
"""
for key in list(state_dict.keys()):
new_key = "model." + key
state_dict[new_key] = state_dict.pop(key)
return state_dict

@abstractmethod
def forward(self, *args, **kwargs):
raise NotImplementedError("forward method is not implemented")
Expand Down
5 changes: 4 additions & 1 deletion torchchat/utils/gguf_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,17 @@ def _convert_gguf_tensor_name_to_llama_nn(gguf_name: str) -> str:
result = copy.deepcopy(gguf_name)
for gguf_string, replacement in _name_replacements:
result = result.replace(gguf_string, replacement)
result = "model." + result
return result


def _fqn_lookup(fqn: str, module: torch.nn.Module) -> Any:
if fqn == "":
return module
atoms = fqn.split(".")
# Expose the root module to users.
# Note this is temporary, and will be removed once we removed Model.model
if isinstance(module, Model):
module = module.model
curr = module
for a in atoms:
curr = getattr(curr, a)
Expand Down
Loading