diff --git a/torchchat/generate.py b/torchchat/generate.py
index 9e60f9494..14c4832e3 100644
--- a/torchchat/generate.py
+++ b/torchchat/generate.py
@@ -818,6 +818,7 @@ def chat(
                 if text_transformer_args is not None
                 else 2048
             ),
+            max_seq_length
         )
 
         max_seq_length = (
diff --git a/torchchat/model.py b/torchchat/model.py
index 3300ebee9..336eb864c 100644
--- a/torchchat/model.py
+++ b/torchchat/model.py
@@ -127,10 +127,10 @@ def forward(
             )
         return self.decoder(decoder_input, input_pos=input_pos)
-    
+
     def setup_caches(self, batch_size, max_seq_len) -> None:
         self.decoder.setup_caches(batch_size, max_seq_len)
-    
+
     def _encoder_feature_select(self, encoder_output) -> Tensor:
         selected_image_feature = encoder_output[1][0].view(
             *encoder_output[1][0].shape[2:]
@@ -154,7 +154,7 @@ def _get_decoder_input(
         image_embeds = self.mm_projector(encoder_output)
         if post_tokens is None:
             return torch.cat((pre_img_embed, image_embeds), dim=1)
-    
+
         post_img_embed = self.tok_embeddings(post_tokens)
         return torch.cat((pre_img_embed, image_embeds, post_img_embed), dim=1)
 
@@ -227,7 +227,7 @@ def _llava(cls):
             },
             fusion_class=ConcateFusion,
         )
-    
+
     @classmethod
     def get_recipe(cls, model_type):
         match model_type:
@@ -338,7 +338,7 @@ def _sanity_check(
     def from_params(cls, params_path):
         with open(params_path, "r") as f:
             loaded_params = json.loads(f.read())
-    
+
         if (model_type_name := loaded_params.get("model_type", None)) is None:
             # The model params is in the transformer_args format
             # set the model_type to TextOnly and reformat the params
@@ -460,14 +460,14 @@ def build_model(self) -> nn.Module:
                 modules[name] = module_class(**config_args)
 
         return recipe.fusion_class(**modules)
-    
+
     def _replace_known_params(self, params):
         patterns = {"QuickGELUActivation()": QuickGELUActivation()}
         for key, value in params.items():
             if isinstance(value, Hashable) and value in patterns:
                 params[key] = patterns[value]
         return params
-    
+
     @abstractmethod
     def forward(self, *args, **kwargs):
         raise NotImplementedError("forward method is not implemented")
@@ -939,7 +939,15 @@ def __init__(self, config, path) -> None:
             self.model_ = exec_lib._load_for_executorch(str(path))
 
             self.text_transformer_args = TransformerArgs.from_params(self.config.transformer_args["text"])
-    
+            # TODO: attempt to use "get_max_seq_len" method on the model after
+            # ExecuTorch bug is fixed.
+            max_seq_len = 128
+            # try:
+            #     max_seq_len = self.model_.run_method("get_max_seq_len", [])
+            # except Exception as e:
+            #     pass
+            self.text_transformer_args.max_seq_length = max_seq_len
+
         def forward(self, x, input_pos):
             # model_.forward expects inputs to be wrapped in a tuple
             forward_inputs = (x.to(torch.long), input_pos.to(torch.long))
@@ -958,6 +966,6 @@ def forward(self, x, input_pos):
 
         def setup_caches(self, max_batch_size, max_seq_length):
             pass
-    
+
 except:
     pass