diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 5dbf48529..cd0d05056 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -388,11 +388,9 @@ def _load_model_default(builder_args: BuilderArgs) -> Model: builder_args.device ): model = Model.from_params(builder_args.params_path) - state_dict = flamingo_meta_to_tune(checkpoint) - model.model.load_state_dict(state_dict) - else: - checkpoint = {"model." + k: v for k, v in checkpoint.items()} - model.load_state_dict(checkpoint, assign=True, strict=True) + checkpoint = flamingo_meta_to_tune(checkpoint) + + model.load_state_dict(checkpoint, assign=True, strict=True) return model diff --git a/torchchat/model.py b/torchchat/model.py index 79bd1f188..089fef9aa 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -442,6 +442,8 @@ def __init__(self, config: ModelArgs) -> None: # It should be assigned in the actual model implementation, if any. self.text_transformer_args = None + self._register_load_state_dict_pre_hook(self._load_model_state_dict) + def build_model(self) -> nn.Module: """ Builds a model based on the provided configuration. @@ -468,6 +470,35 @@ def _replace_known_params(self, params): params[key] = patterns[value] return params + def _load_model_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + """ + Updates the loaded internal model state dictionary to match the Model class structure. + Note that this is a temporary solution and will be removed once the model structure is finalized. + Args: + state_dict (dict): The state dictionary to load. + prefix (str): The prefix of the model. + local_metadata (dict): Local metadata. + strict (bool): Whether to strictly enforce that the keys in the state dictionary match the keys in the model. + missing_keys (list): List of missing keys. + unexpected_keys (list): List of unexpected keys. + error_msgs (list): List of error messages. + Returns: + dict: The updated state dictionary. + """ + for key in list(state_dict.keys()): + new_key = "model." + key + state_dict[new_key] = state_dict.pop(key) + return state_dict + @abstractmethod def forward(self, *args, **kwargs): raise NotImplementedError("forward method is not implemented") diff --git a/torchchat/utils/gguf_loader.py b/torchchat/utils/gguf_loader.py index 309ff807c..772c6f7dd 100644 --- a/torchchat/utils/gguf_loader.py +++ b/torchchat/utils/gguf_loader.py @@ -47,7 +47,6 @@ def _convert_gguf_tensor_name_to_llama_nn(gguf_name: str) -> str: result = copy.deepcopy(gguf_name) for gguf_string, replacement in _name_replacements: result = result.replace(gguf_string, replacement) - result = "model." + result return result @@ -55,6 +54,10 @@ def _fqn_lookup(fqn: str, module: torch.nn.Module) -> Any: if fqn == "": return module atoms = fqn.split(".") + # Expose the root module to users. + # Note this is temporary, and will be removed once we removed Model.model + if isinstance(module, Model): + module = module.model curr = module for a in atoms: curr = getattr(curr, a)