This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 7ddfdf8

Use default max_seq_length of 128 when loading ExecuTorch models
We just defaulted this on export; we should also default on load. ExecuTorch has an API to get the max sequence length, but the Python API currently segfaults when the method is missing.

ghstack-source-id: ba4f498
Pull Request resolved: #1184
1 parent 87e3c6e commit 7ddfdf8
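
The long-term intent is spelled out in the commented-out block added to torchchat/model.py below: query the loaded ExecuTorch program for its max sequence length and only fall back to 128, the export-time default. A minimal sketch of that fallback, assuming the ExecuTorch Python bindings are fixed so that a missing method raises a catchable error instead of segfaulting; resolve_max_seq_len is a hypothetical helper, not part of torchchat:

# Sketch only: a hypothetical helper mirroring the fallback this commit wants
# to reach eventually. "get_max_seq_len" and run_method() are the names used
# in the commented-out code in torchchat/model.py; 128 matches the export-time
# default.
DEFAULT_MAX_SEQ_LEN = 128

def resolve_max_seq_len(et_model) -> int:
    # et_model is assumed to be the module returned by ExecuTorch's
    # _load_for_executorch(). Today torchchat skips this lookup entirely,
    # because a missing method segfaults in the Python API rather than
    # raising; once that is fixed, the try/except below becomes viable.
    try:
        result = et_model.run_method("get_max_seq_len", [])
        # run_method returns the method's outputs; unwrap a single scalar.
        return int(result[0] if isinstance(result, (list, tuple)) else result)
    except Exception:
        return DEFAULT_MAX_SEQ_LEN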

2 files changed (+18 additions, -9 deletions)


torchchat/generate.py

Lines changed: 1 addition & 0 deletions
@@ -818,6 +818,7 @@ def chat(
                     if text_transformer_args is not None
                     else 2048
                 ),
+                max_seq_length
             )

         max_seq_length = (
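
The generate.py hunk adds the loaded model's max_seq_length as an extra term in the expression that picks the sequence length for chat, so an ExecuTorch model's new default of 128 now flows into generation. An illustrative sketch of how such a cap combines the limits; the function, parameter names, and the min() framing are assumptions about the surrounding code, not the actual generate.py implementation:

def effective_max_seq_len(prompt_len: int, max_new_tokens: int,
                          block_size: int, loaded_max_seq_len: int) -> int:
    # Take the most restrictive of: what this generation could need, the
    # model's block size (2048 fallback when unknown), and the max sequence
    # length reported by the loaded model -- the new term in this diff.
    return min(prompt_len + max_new_tokens, block_size, loaded_max_seq_len)

# With a 16-token prompt and 200 new tokens against a PTE model that defaults
# to 128, generation is clamped to 128 rather than 216.
print(effective_max_seq_len(16, 200, 2048, 128))  # -> 128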

torchchat/model.py

Lines changed: 17 additions & 9 deletions
@@ -127,10 +127,10 @@ def forward(
         )

         return self.decoder(decoder_input, input_pos=input_pos)
-
+
     def setup_caches(self, batch_size, max_seq_len) -> None:
         self.decoder.setup_caches(batch_size, max_seq_len)
-
+
     def _encoder_feature_select(self, encoder_output) -> Tensor:
         selected_image_feature = encoder_output[1][0].view(
             *encoder_output[1][0].shape[2:]
@@ -154,7 +154,7 @@ def _get_decoder_input(
         image_embeds = self.mm_projector(encoder_output)
         if post_tokens is None:
             return torch.cat((pre_img_embed, image_embeds), dim=1)
-
+
         post_img_embed = self.tok_embeddings(post_tokens)
         return torch.cat((pre_img_embed, image_embeds, post_img_embed), dim=1)

@@ -227,7 +227,7 @@ def _llava(cls):
             },
             fusion_class=ConcateFusion,
         )
-
+
     @classmethod
     def get_recipe(cls, model_type):
         match model_type:
@@ -338,7 +338,7 @@ def _sanity_check(
     def from_params(cls, params_path):
         with open(params_path, "r") as f:
             loaded_params = json.loads(f.read())
-
+
         if (model_type_name := loaded_params.get("model_type", None)) is None:
             # The model params is in the transformer_args format
             # set the model_type to TextOnly and reformat the params
@@ -460,14 +460,14 @@ def build_model(self) -> nn.Module:
             modules[name] = module_class(**config_args)

         return recipe.fusion_class(**modules)
-
+
     def _replace_known_params(self, params):
         patterns = {"QuickGELUActivation()": QuickGELUActivation()}
         for key, value in params.items():
             if isinstance(value, Hashable) and value in patterns:
                 params[key] = patterns[value]
         return params
-
+
     @abstractmethod
     def forward(self, *args, **kwargs):
         raise NotImplementedError("forward method is not implemented")
@@ -939,7 +939,15 @@ def __init__(self, config, path) -> None:
             self.model_ = exec_lib._load_for_executorch(str(path))

             self.text_transformer_args = TransformerArgs.from_params(self.config.transformer_args["text"])
-
+            # TODO: attempt to use "get_max_seq_len" method on the model after
+            # ExecuTorch bug is fixed.
+            max_seq_len = 128
+            # try:
+            #     max_seq_len = self.model_.run_method("get_max_seq_len", [])
+            # except Exception as e:
+            #     pass
+            self.text_transformer_args.max_seq_length = max_seq_len
+
         def forward(self, x, input_pos):
             # model_.forward expects inputs to be wrapped in a tuple
             forward_inputs = (x.to(torch.long), input_pos.to(torch.long))
@@ -958,6 +966,6 @@ def forward(self, x, input_pos):

         def setup_caches(self, max_batch_size, max_seq_length):
             pass
-
+
 except:
     pass
