@@ -304,6 +304,7 @@ def __init__(self, config: ModelArgs) -> None:
304304 super ().__init__ ()
305305 self .config = config
306306 self .model = self .build_model ()
307+ self .text_transformer_args = None
307308
308309 def build_model (self ) -> nn .Module :
309310 """
@@ -331,11 +332,6 @@ def forward(self, *args, **kwargs):
331332 @abstractmethod
332333 def setup_caches (self , * args , ** kwargs ):
333334 raise NotImplementedError ("setup_caches method is not implemented" )
334-
335- @property
336- @abstractmethod
337- def text_transformer_args (self ):
338- raise NotImplementedError ("no text_transformer_args is created" )
339335
340336 @classmethod
341337 def _get_model_instance (cls , config : ModelArgs ):
@@ -371,15 +367,15 @@ def from_gguf(cls, gguf_path: str, **kwargs):
371367
372368
373369class TextOnlyModel (Model ):
370+ def __init__ (self , config : ModelArgs ) -> None :
371+ super ().__init__ (config )
372+ self .text_transformer_args = self .model .config
373+
374374 def forward (self , tokens : Tensor , input_pos : Optional [Tensor ] = None ) -> Tensor :
375375 return self .model (tokens , input_pos )
376376
377377 def setup_caches (self , max_batch_size , max_seq_length ):
378378 self .model .setup_caches (max_batch_size , max_seq_length )
379-
380- @property
381- def text_transformer_args (self ):
382- return self .model .model .config
383379
384380
385381class Llama31Model (Model ):
@@ -391,11 +387,6 @@ def setup_caches(self, max_batch_size, dtype):
391387
392388 def reset_caches (self ):
393389 self .model .reset_caches ()
394-
395- @property
396- def text_transformer_args (self ):
397- # TODO: add support for llama3_1
398- return None
399390
400391
401392class FlamingoModel (Model ):
@@ -416,11 +407,7 @@ def setup_caches(self, max_batch_size, dtype):
416407
417408 def reset_caches (self ):
418409 self .model .reset_caches ()
419-
420- @property
421- def text_transformer_args (self ):
422- # TODO: add support for flamingo
423- return None
410+
424411
425412
426413MODEL_TYPE_TO_CLASS = {
@@ -813,7 +800,7 @@ def __init__(self, config, path) -> None:
813800 self .config = config
814801 self .model_ = exec_lib ._load_for_executorch (str (path ))
815802
816- self .text_transformer_config = TransformerArgs .from_params (self .config .transformer_args ["text" ])
803+ self .text_transformer_args = TransformerArgs .from_params (self .config .transformer_args ["text" ])
817804
818805 def forward (self , x , input_pos ):
819806 # model_.forward expects inputs to be wrapped in a tuple
0 commit comments