Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit d18da59

Browse files
committed
expose internal model attribute:
1 parent bd8ff07 commit d18da59

File tree

1 file changed

+33
-4
lines changed

1 file changed

+33
-4
lines changed

torchchat/model.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -470,13 +470,42 @@ def _replace_known_params(self, params):
470470
params[key] = patterns[value]
471471
return params
472472

473-
def _load_model_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
474-
# 修改 state dict 中的键值
473+
def _load_model_state_dict(
474+
self,
475+
state_dict,
476+
prefix,
477+
local_metadata,
478+
strict,
479+
missing_keys,
480+
unexpected_keys,
481+
error_msgs,
482+
):
483+
# update key names to match the new model
475484
for key in list(state_dict.keys()):
476-
new_key = 'model.' + key
485+
new_key = "model." + key
477486
state_dict[new_key] = state_dict.pop(key)
478487
return state_dict
479-
488+
489+
def __getattr__(self, name):
490+
"""
491+
Rewrite __getattr__ to search attribute in Model and its model attribute.
492+
Note that this is a temporary solution to expose internal model attributes to the user.
493+
494+
:param name: The name of the attribute to get.
495+
:return: The attribute value if found, otherwise raise AttributeError.
496+
"""
497+
try:
498+
return super().__getattribute__(name)
499+
except AttributeError:
500+
pass
501+
502+
try:
503+
return getattr(self.model, name)
504+
except AttributeError:
505+
raise AttributeError(
506+
f"'{type(self).__name__}' object has no attribute '{name}'"
507+
)
508+
480509
@abstractmethod
481510
def forward(self, *args, **kwargs):
482511
raise NotImplementedError("forward method is not implemented")

0 commit comments

Comments
 (0)