@@ -127,10 +127,10 @@ def forward(
             )
 
         return self.decoder(decoder_input, input_pos=input_pos)
-      
+
     def setup_caches(self, batch_size, max_seq_len) -> None:
         self.decoder.setup_caches(batch_size, max_seq_len)
-      
+
     def _encoder_feature_select(self, encoder_output) -> Tensor:
         selected_image_feature = encoder_output[1][0].view(
             *encoder_output[1][0].shape[2:]
@@ -154,7 +154,7 @@ def _get_decoder_input(
             image_embeds = self.mm_projector(encoder_output)
             if post_tokens is None:
                 return torch.cat((pre_img_embed, image_embeds), dim=1)
-              
+
             post_img_embed = self.tok_embeddings(post_tokens)
             return torch.cat((pre_img_embed, image_embeds, post_img_embed), dim=1)
 
@@ -227,7 +227,7 @@ def _llava(cls):
             },
             fusion_class=ConcateFusion,
         )
-      
+
     @classmethod
     def get_recipe(cls, model_type):
         match model_type:
@@ -338,7 +338,7 @@ def _sanity_check(
     def from_params(cls, params_path):
         with open(params_path, "r") as f:
             loaded_params = json.loads(f.read())
-          
+
         if (model_type_name := loaded_params.get("model_type", None)) is None:
             # The model params is in the transformer_args format
             # set the model_type to TextOnly and reformat the params
@@ -460,14 +460,14 @@ def build_model(self) -> nn.Module:
                 modules[name] = module_class(**config_args)
 
         return recipe.fusion_class(**modules)
-      
+
     def _replace_known_params(self, params):
         patterns = {"QuickGELUActivation()": QuickGELUActivation()}
         for key, value in params.items():
             if isinstance(value, Hashable) and value in patterns:
                 params[key] = patterns[value]
         return params
-      
+
     @abstractmethod
     def forward(self, *args, **kwargs):
         raise NotImplementedError("forward method is not implemented")
@@ -939,7 +939,15 @@ def __init__(self, config, path) -> None:
             self.model_ = exec_lib._load_for_executorch(str(path))
 
             self.text_transformer_args = TransformerArgs.from_params(self.config.transformer_args["text"])
-            
+            # TODO: attempt to use "get_max_seq_len" method on the model after
+            # ExecuTorch bug is fixed.
+            max_seq_len = 128
+            # try:
+            #     max_seq_len = self.model_.run_method("get_max_seq_len", [])
+            # except Exception as e:
+            #     pass
+            self.text_transformer_args.max_seq_length = max_seq_len
+
         def forward(self, x, input_pos):
             # model_.forward expects inputs to be wrapped in a tuple
             forward_inputs = (x.to(torch.long), input_pos.to(torch.long))
@@ -958,6 +966,6 @@ def forward(self, x, input_pos):
 
         def setup_caches(self, max_batch_size, max_seq_length):
             pass
-          
+
 except:
     pass
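
The substantive change in this diff (the remaining hunks are trailing-whitespace cleanup) hardcodes max_seq_len = 128 because of the ExecuTorch bug noted in the TODO. Once that bug is fixed, the commented-out query could be restored with the hardcoded value kept as a fallback. A minimal sketch, assuming run_method("get_max_seq_len", []) behaves as the commented-out lines in this commit suggest; it is an illustration, not part of the commit:

    # Sketch: prefer the model-reported max sequence length, falling back to
    # the conservative default this commit hardcodes.
    max_seq_len = 128
    try:
        max_seq_len = self.model_.run_method("get_max_seq_len", [])
    except Exception:
        # The exported program may not expose "get_max_seq_len"; keep the fallback.
        pass
    self.text_transformer_args.max_seq_length = max_seq_len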