
Commit 1eff939

make text_transformer_args a real attribute

1 parent 304fece

File tree

2 files changed: +14 -29 lines changed

torchchat/model.py

Lines changed: 7 additions & 20 deletions
@@ -304,6 +304,7 @@ def __init__(self, config: ModelArgs) -> None:
         super().__init__()
         self.config = config
         self.model = self.build_model()
+        self.text_transformer_args = None

     def build_model(self) -> nn.Module:
         """
@@ -331,11 +332,6 @@ def forward(self, *args, **kwargs):
     @abstractmethod
     def setup_caches(self, *args, **kwargs):
         raise NotImplementedError("setup_caches method is not implemented")
-
-    @property
-    @abstractmethod
-    def text_transformer_args(self):
-        raise NotImplementedError("no text_transformer_args is created")

     @classmethod
     def _get_model_instance(cls, config: ModelArgs):
@@ -371,15 +367,15 @@ def from_gguf(cls, gguf_path: str, **kwargs):


 class TextOnlyModel(Model):
+    def __init__(self, config: ModelArgs) -> None:
+        super().__init__(config)
+        self.text_transformer_args = self.model.config
+
     def forward(self, tokens: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
         return self.model(tokens, input_pos)

     def setup_caches(self, max_batch_size, max_seq_length):
         self.model.setup_caches(max_batch_size, max_seq_length)
-
-    @property
-    def text_transformer_args(self):
-        return self.model.model.config


 class Llama31Model(Model):
@@ -391,11 +387,6 @@ def setup_caches(self, max_batch_size, dtype):

     def reset_caches(self):
         self.model.reset_caches()
-
-    @property
-    def text_transformer_args(self):
-        # TODO: add support for llama3_1
-        return None


 class FlamingoModel(Model):
@@ -416,11 +407,7 @@ def setup_caches(self, max_batch_size, dtype):

     def reset_caches(self):
         self.model.reset_caches()
-
-    @property
-    def text_transformer_args(self):
-        # TODO: add support for flamingo
-        return None
+


 MODEL_TYPE_TO_CLASS = {
@@ -813,7 +800,7 @@ def __init__(self, config, path) -> None:
         self.config = config
         self.model_ = exec_lib._load_for_executorch(str(path))

-        self.text_transformer_config = TransformerArgs.from_params(self.config.transformer_args["text"])
+        self.text_transformer_args = TransformerArgs.from_params(self.config.transformer_args["text"])

     def forward(self, x, input_pos):
         # model_.forward expects inputs to be wrapped in a tuple
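In essence, the commit replaces an abstract @property on the Model base class with a plain instance attribute: the base __init__ now defaults it to None, and subclasses that expose a text transformer simply assign to it. A minimal sketch of the pattern, using simplified stand-in classes rather than the actual torchchat ones:

    # Before: every subclass had to override the abstract property, even
    # model types (Llama31Model, FlamingoModel) that only returned None.
    # After: one real attribute on the base class, overwritten where needed.
    class Base:
        def __init__(self, config) -> None:
            self.config = config
            # Stays None for model types without text transformer args.
            self.text_transformer_args = None

    class TextOnly(Base):
        def __init__(self, config) -> None:
            super().__init__(config)
            # Subclass assigns in __init__ instead of overriding a property.
            self.text_transformer_args = config

    print(TextOnly(config={"dim": 4096}).text_transformer_args)  # {'dim': 4096}
    print(Base(config=None).text_transformer_args)               # None

One practical upside of the attribute over the abstract property: every caller reads model.text_transformer_args uniformly, and the executorch path (last hunk above) can set it in __init__ under the same name instead of the stray text_transformer_config.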

torchchat/utils/gguf_loader.py

Lines changed: 7 additions & 9 deletions
@@ -544,15 +544,13 @@ def load_model(gguf_file: str) -> torch.nn.Module:
     model_args = ModelArgs(
         {
             "text": {
-                "config": {
-                    "dim": metadata[f"{arch}.embedding_length"],
-                    "n_layers": metadata[f"{arch}.block_count"],
-                    "n_heads": metadata[f"{arch}.attention.head_count"],
-                    "n_local_heads": metadata[f"{arch}.attention.head_count_kv"],
-                    "vocab_size": len(metadata["tokenizer.ggml.tokens"]),
-                    "norm_eps": metadata[f"{arch}.attention.layer_norm_rms_epsilon"],
-                    "hidden_dim": metadata[f"{arch}.feed_forward_length"],
-                }
+                "dim": metadata[f"{arch}.embedding_length"],
+                "n_layers": metadata[f"{arch}.block_count"],
+                "n_heads": metadata[f"{arch}.attention.head_count"],
+                "n_local_heads": metadata[f"{arch}.attention.head_count_kv"],
+                "vocab_size": len(metadata["tokenizer.ggml.tokens"]),
+                "norm_eps": metadata[f"{arch}.attention.layer_norm_rms_epsilon"],
+                "hidden_dim": metadata[f"{arch}.feed_forward_length"],
             }
         }
     )
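The gguf_loader change flattens the dict handed to ModelArgs: the transformer arguments now sit directly under the "text" key instead of under a nested "config" sub-dict. An illustrative sketch of the resulting shape, with made-up values for a hypothetical llama-style GGUF header (the metadata keys come from the diff; the numbers are placeholders, not real model values):

    arch = "llama"
    metadata = {
        f"{arch}.embedding_length": 4096,
        f"{arch}.block_count": 32,
        f"{arch}.attention.head_count": 32,
        f"{arch}.attention.head_count_kv": 8,
        "tokenizer.ggml.tokens": ["<tok>"] * 32000,
        f"{arch}.attention.layer_norm_rms_epsilon": 1e-5,
        f"{arch}.feed_forward_length": 11008,
    }
    text_args = {
        "dim": metadata[f"{arch}.embedding_length"],
        "n_layers": metadata[f"{arch}.block_count"],
        "n_heads": metadata[f"{arch}.attention.head_count"],
        "n_local_heads": metadata[f"{arch}.attention.head_count_kv"],
        "vocab_size": len(metadata["tokenizer.ggml.tokens"]),
        "norm_eps": metadata[f"{arch}.attention.layer_norm_rms_epsilon"],
        "hidden_dim": metadata[f"{arch}.feed_forward_length"],
    }
    # After this commit the loader passes {"text": text_args} to ModelArgs;
    # before, it passed {"text": {"config": text_args}}.
    assert "config" not in text_args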
