diff --git a/torchchat/export.py b/torchchat/export.py index 6b06f1df1..e387d3dac 100644 --- a/torchchat/export.py +++ b/torchchat/export.py @@ -364,13 +364,17 @@ def main(args): except: tokenizer = None - if ( - output_dso_path is not None - and builder_args.max_seq_length is None - and not builder_args.dynamic_shapes - ): - print("Setting max_seq_length to 300 for DSO export.") - builder_args.max_seq_length = 300 + if builder_args.max_seq_length is None: + if ( + output_dso_path is not None + and not builder_args.dynamic_shapes + ): + print("Setting max_seq_length to 300 for DSO export.") + builder_args.max_seq_length = 300 + elif output_pte_path is not None: + # The value of 128 was chosen to match the ExecuTorch Llama example setup. + print("Setting max_seq_length to 128 for ExecuTorch export.") + builder_args.max_seq_length = 128 model = _initialize_model( builder_args,