This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit c3b5965

Use default max_seq_length of 128 when loading ExecuTorch models (#1190)
1 parent 28914fd commit c3b5965
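
For orientation, the change boils down to this: when torchchat loads an ExecuTorch (.pte) model, the text model's TransformerArgs are now overridden with a hard-coded max_seq_length of 128 rather than keeping whatever came from the params file. The sketch below is a minimal, self-contained illustration of that behavior; TransformerArgsSketch and load_executorch_args are illustrative stand-ins, not torchchat or ExecuTorch APIs.

from dataclasses import dataclass


# Illustrative stand-in for torchchat's TransformerArgs (not the real class).
@dataclass
class TransformerArgsSketch:
    max_seq_length: int = 2048


def load_executorch_args() -> TransformerArgsSketch:
    """Sketch of the new default applied when a .pte model is loaded."""
    args = TransformerArgsSketch()
    # The commit hard-codes 128 here; querying the exported program for its
    # real limit (via "get_max_seq_len") is left as a TODO upstream.
    args.max_seq_length = 128
    return args


print(load_executorch_args().max_seq_length)  # 128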

2 files changed: 18 additions & 9 deletions

torchchat/generate.py

Lines changed: 1 addition & 0 deletions
@@ -818,6 +818,7 @@ def chat(
                 if text_transformer_args is not None
                 else 2048
             ),
+            max_seq_length
         )

         max_seq_length = (
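
For context, the ternary in this hunk is where chat() picks the sequence-length cap, and the model.py change below makes an ExecuTorch model report max_seq_length = 128 through text_transformer_args. The enclosing expression that receives the newly added max_seq_length argument is only partially visible in the hunk, so the sketch below mirrors just the visible fallback logic, with an illustrative helper name.

from typing import Optional


def select_seq_len_cap(text_max_seq_length: Optional[int]) -> int:
    """Illustrative mirror of the ternary shown above."""
    return (
        text_max_seq_length
        if text_max_seq_length is not None
        else 2048  # fallback when no text transformer args are available
    )


print(select_seq_len_cap(128))   # ExecuTorch default after this commit -> 128
print(select_seq_len_cap(None))  # no text transformer args -> 2048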

torchchat/model.py

Lines changed: 17 additions & 9 deletions
@@ -127,10 +127,10 @@ def forward(
         )

         return self.decoder(decoder_input, input_pos=input_pos)
-
+
     def setup_caches(self, batch_size, max_seq_len) -> None:
         self.decoder.setup_caches(batch_size, max_seq_len)
-
+
     def _encoder_feature_select(self, encoder_output) -> Tensor:
         selected_image_feature = encoder_output[1][0].view(
             *encoder_output[1][0].shape[2:]
@@ -154,7 +154,7 @@ def _get_decoder_input(
         image_embeds = self.mm_projector(encoder_output)
         if post_tokens is None:
             return torch.cat((pre_img_embed, image_embeds), dim=1)
-
+
         post_img_embed = self.tok_embeddings(post_tokens)
         return torch.cat((pre_img_embed, image_embeds, post_img_embed), dim=1)

@@ -227,7 +227,7 @@ def _llava(cls):
             },
             fusion_class=ConcateFusion,
         )
-
+
     @classmethod
     def get_recipe(cls, model_type):
         match model_type:
@@ -338,7 +338,7 @@ def _sanity_check(
     def from_params(cls, params_path):
         with open(params_path, "r") as f:
             loaded_params = json.loads(f.read())
-
+
         if (model_type_name := loaded_params.get("model_type", None)) is None:
             # The model params is in the transformer_args format
             # set the model_type to TextOnly and reformat the params
@@ -460,14 +460,14 @@ def build_model(self) -> nn.Module:
                 modules[name] = module_class(**config_args)

         return recipe.fusion_class(**modules)
-
+
     def _replace_known_params(self, params):
         patterns = {"QuickGELUActivation()": QuickGELUActivation()}
         for key, value in params.items():
             if isinstance(value, Hashable) and value in patterns:
                 params[key] = patterns[value]
         return params
-
+
     @abstractmethod
     def forward(self, *args, **kwargs):
         raise NotImplementedError("forward method is not implemented")
@@ -939,7 +939,15 @@ def __init__(self, config, path) -> None:
             self.model_ = exec_lib._load_for_executorch(str(path))

             self.text_transformer_args = TransformerArgs.from_params(self.config.transformer_args["text"])
-
+            # TODO: attempt to use "get_max_seq_len" method on the model after
+            # ExecuTorch bug is fixed.
+            max_seq_len = 128
+            # try:
+            #     max_seq_len = self.model_.run_method("get_max_seq_len", [])
+            # except Exception as e:
+            #     pass
+            self.text_transformer_args.max_seq_length = max_seq_len
+
         def forward(self, x, input_pos):
             # model_.forward expects inputs to be wrapped in a tuple
             forward_inputs = (x.to(torch.long), input_pos.to(torch.long))
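
The commented-out block in this hunk records the intended follow-up: once the ExecuTorch bug is fixed, ask the exported program for its limit via "get_max_seq_len" and only fall back to 128 when that fails. Below is a minimal sketch of that fallback pattern, treating the loaded module as opaque; the run_method call is copied from the comments above and is not verified against the ExecuTorch API.

DEFAULT_MAX_SEQ_LEN = 128  # hard-coded default introduced by this commit


def resolve_max_seq_len(model_) -> int:
    """Fallback pattern sketched in the TODO above (illustrative)."""
    try:
        # Expected to work once the upstream fix lands; mirrors the
        # commented-out call and may not match the final ExecuTorch API.
        return int(model_.run_method("get_max_seq_len", []))
    except Exception:
        return DEFAULT_MAX_SEQ_LEN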
@@ -958,6 +966,6 @@ def forward(self, x, input_pos):

         def setup_caches(self, max_batch_size, max_seq_length):
             pass
-
+
 except:
     pass
