Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 8cd0936

Browse files
committed
bring TransformerArgs back to Transformer
1 parent 141fea0 commit 8cd0936

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

torchchat/model.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -210,7 +210,7 @@ def from_params(cls, params_path):
210210
# The model params is in the transformer_args format
211211
# set the model_type to TextOnly and reformat the params
212212
model_type = ModelType.TextOnly
213-
transformer_args = {"text": {"config": loaded_params}}
213+
transformer_args = {"text": loaded_params}
214214
else:
215215
model_type = ModelType(model_type_name)
216216
transformer_args = {
@@ -317,7 +317,10 @@ def build_model(self) -> nn.Module:
317317
modules = {}
318318
for name, module_class in recipe.modules.items():
319319
config_args = self.config.transformer_args[name]
320-
modules[name] = module_class(**config_args)
320+
if module_class == Transformer:
321+
modules[name] = module_class(TransformerArgs.from_params(config_args))
322+
else:
323+
modules[name] = module_class(**config_args)
321324

322325
return recipe.fusion_class(**modules)
323326

@@ -424,9 +427,8 @@ def get_text_transformer_args(self):
424427

425428

426429
class Transformer(nn.Module):
427-
def __init__(self, config: Dict[str, Any]) -> None:
430+
def __init__(self, config: TransformerArgs) -> None:
428431
super().__init__()
429-
config = TransformerArgs.from_params(config)
430432
self.config = config
431433
layers_per_stage = config.n_layers // config.n_stages
432434

0 commit comments

Comments
 (0)