@@ -127,10 +127,10 @@ def forward(
         )
 
         return self.decoder(decoder_input, input_pos=input_pos)
-
+
     def setup_caches(self, batch_size, max_seq_len) -> None:
         self.decoder.setup_caches(batch_size, max_seq_len)
-
+
     def _encoder_feature_select(self, encoder_output) -> Tensor:
         selected_image_feature = encoder_output[1][0].view(
             *encoder_output[1][0].shape[2:]
@@ -154,7 +154,7 @@ def _get_decoder_input(
         image_embeds = self.mm_projector(encoder_output)
         if post_tokens is None:
             return torch.cat((pre_img_embed, image_embeds), dim=1)
-
+
         post_img_embed = self.tok_embeddings(post_tokens)
         return torch.cat((pre_img_embed, image_embeds, post_img_embed), dim=1)
 
@@ -227,7 +227,7 @@ def _llava(cls):
             },
             fusion_class=ConcateFusion,
         )
-
+
     @classmethod
     def get_recipe(cls, model_type):
         match model_type:
@@ -338,7 +338,7 @@ def _sanity_check(
     def from_params(cls, params_path):
         with open(params_path, "r") as f:
             loaded_params = json.loads(f.read())
-
+
         if (model_type_name := loaded_params.get("model_type", None)) is None:
             # The model params is in the transformer_args format
             # set the model_type to TextOnly and reformat the params
@@ -460,14 +460,14 @@ def build_model(self) -> nn.Module:
             modules[name] = module_class(**config_args)
 
         return recipe.fusion_class(**modules)
-
+
     def _replace_known_params(self, params):
         patterns = {"QuickGELUActivation()": QuickGELUActivation()}
         for key, value in params.items():
             if isinstance(value, Hashable) and value in patterns:
                 params[key] = patterns[value]
         return params
-
+
     @abstractmethod
     def forward(self, *args, **kwargs):
         raise NotImplementedError("forward method is not implemented")
@@ -939,7 +939,15 @@ def __init__(self, config, path) -> None:
             self.model_ = exec_lib._load_for_executorch(str(path))
 
             self.text_transformer_args = TransformerArgs.from_params(self.config.transformer_args["text"])
-
+            # TODO: attempt to use the "get_max_seq_len" method on the model once the
+            # ExecuTorch bug is fixed.
+            max_seq_len = 128
+            # try:
+            #     max_seq_len = self.model_.run_method("get_max_seq_len", [])
+            # except Exception as e:
+            #     pass
+            self.text_transformer_args.max_seq_length = max_seq_len
+
         def forward(self, x, input_pos):
             # model_.forward expects inputs to be wrapped in a tuple
             forward_inputs = (x.to(torch.long), input_pos.to(torch.long))
@@ -958,6 +966,6 @@ def forward(self, x, input_pos):
 
         def setup_caches(self, max_batch_size, max_seq_length):
             pass
-
+
 except:
     pass
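
The added lines hard-code `max_seq_len = 128` and leave the intended runtime query commented out pending an ExecuTorch fix. A minimal sketch of how that lookup might be re-enabled later, assuming the loaded module exposes `run_method("get_max_seq_len", ...)` as the commented-out lines suggest; the helper name and its fallback handling are illustrative, not part of the patch:

```python
def resolve_max_seq_len(pte_module, default=128):
    """Best-effort lookup of the exported model's max sequence length.

    Mirrors the commented-out lines in the patch: assumes the loaded ExecuTorch
    module exposes a "get_max_seq_len" method reachable via run_method(); when
    the call is unavailable (or hits the ExecuTorch bug the TODO mentions),
    fall back to the hard-coded default of 128.
    """
    try:
        return pte_module.run_method("get_max_seq_len", [])
    except Exception:
        return default
```

With such a helper, the hard-coded assignment in `__init__` could become `self.text_transformer_args.max_seq_length = resolve_max_seq_len(self.model_)`, keeping 128 as the fallback until the upstream fix lands.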