Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 46e3ab7

Browse files
authored
[AOTI] Add a --dynamic-shapes option to export (#1011)
* [AOTI] Change export to use static shapes. Summary: The inputs to model forward have static shapes, so this changes the export call to make sure more Inductor optimizations take effect downstream. This change by itself improves average tokens/sec from 29.60 to 33.43 on A100. Some following PRs will provide further perf gains. * Add a --dynamic-shapes option for export * Actually add --dynamic-shapes to the CLI * Access args.dynamic_shapes correctly
1 parent dea8d60 commit 46e3ab7

File tree

4 files changed

+32
-11
lines changed

4 files changed

+32
-11
lines changed

.ci/scripts/validate.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ function eval_model_sanity_check() {
285285
echo "******** INT4 group-wise quantized (AOTI) *******"
286286
echo "*************************************************"
287287
if [ "$DTYPE" != "float16" ]; then
288-
python3 -W ignore export.py --dtype ${DTYPE} --quant "$QUANT_OPTIONS" --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1
288+
python3 -W ignore export.py --dtype ${DTYPE} --quant "$QUANT_OPTIONS" --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --dynamic-shapes --device "$TARGET_DEVICE" || exit 1
289289
python3 -W ignore eval.py --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" --limit 5 > "$MODEL_DIR/output_eval_aoti" || exit 1
290290
cat "$MODEL_DIR/output_eval_aoti"
291291
fi;

build/builder.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ class BuilderArgs:
4444
use_distributed: bool = False
4545
is_chat_model: bool = False
4646
prefill_possible: bool = False
47+
dynamic_shapes: bool = False
4748

4849
def __post_init__(self):
4950
if self.device is None:
@@ -157,6 +158,7 @@ def from_args(cls, args): # -> BuilderArgs:
157158
setup_caches=(output_dso_path or output_pte_path),
158159
use_distributed=args.distributed,
159160
is_chat_model=is_chat_model,
161+
dynamic_shapes=getattr(args, "dynamic_shapes", False),
160162
)
161163

162164
@classmethod

cli.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,11 @@ def _add_export_output_path_args(parser) -> None:
185185
default=None,
186186
help="Output to the specified AOT Inductor .dso model file",
187187
)
188+
parser.add_argument(
189+
"--dynamic-shapes",
190+
action="store_true",
191+
help="Call torch.export with dynamic shapes",
192+
)
188193

189194

190195
# Add CLI Args representing user provided exported model files

export.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,10 @@
3737

3838

3939
def export_for_server(
40-
model: nn.Module, device: Optional[str] = "cpu", output_path: str = "model.dso"
40+
model: nn.Module,
41+
device: Optional[str] = "cpu",
42+
output_path: str = "model.dso",
43+
dynamic_shapes: bool = False,
4144
) -> str:
4245
"""
4346
Export the model using AOT Compile to get a .dso for server use cases.
@@ -49,16 +52,22 @@ def export_for_server(
4952
Returns:
5053
The path to the exported model.
5154
"""
52-
input = (
53-
torch.tensor([[1, 9038, 2501, 263, 931]], dtype=torch.int, device=device),
54-
torch.tensor([0, 1, 2, 3, 4], dtype=torch.int, device=device),
55-
)
55+
if dynamic_shapes:
56+
input = (
57+
torch.tensor([[1, 9038, 2501, 263, 931]], dtype=torch.int, device=device),
58+
torch.tensor([0, 1, 2, 3, 4], dtype=torch.int, device=device),
59+
)
5660

57-
seq = Dim("seq", min=1, max=model.config.max_seq_length)
58-
# Specify that the first dimension of each input is that batch size
59-
dynamic_shapes = {"idx": {1: seq}, "input_pos": {0: seq}}
61+
seq = Dim("seq", min=1, max=model.config.max_seq_length)
62+
# Specify that the first dimension of each input is that batch size
63+
dynamic_shapes = {"idx": {1: seq}, "input_pos": {0: seq}}
64+
else:
65+
input = (
66+
torch.tensor([[1]], dtype=torch.int, device=device),
67+
torch.tensor([0], dtype=torch.int, device=device),
68+
)
69+
dynamic_shapes = None
6070

61-
model.to(device)
6271
so = torch._export.aot_compile(
6372
model,
6473
args=input,
@@ -143,7 +152,12 @@ def main(args):
143152
if output_dso_path:
144153
output_dso_path = str(os.path.abspath(output_dso_path))
145154
print(f"Exporting model using AOT Inductor to {output_dso_path}")
146-
export_for_server(model_to_dso, builder_args.device, output_dso_path)
155+
export_for_server(
156+
model_to_dso,
157+
builder_args.device,
158+
output_dso_path,
159+
builder_args.dynamic_shapes,
160+
)
147161

148162

149163
if __name__ == "__main__":

0 commit comments

Comments (0)