This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 67f678b

[AOTI] Set sdpa_kernel context when exporting (#1013)
Summary: This improves average tokens/sec from 33.43 to 72.63 on A100 for AOTI. ``` python3 torchchat.py export llama3 --quantize '{"precision": {"dtype":"bfloat16"}, "executor":{"accelerator":"cuda"}}' --output-dso-path /tmp/model16.so && python3 torchchat.py generate llama3 --dso-path /tmp/model16.so --prompt "Once upon a time," --max-new-tokens 256 --device cuda --num-samples 3 ```
1 parent 46e3ab7 commit 67f678b

File tree

1 file changed: +7 −6 lines changed


export.py

Lines changed: 7 additions & 6 deletions
```diff
@@ -68,12 +68,13 @@ def export_for_server(
     )
     dynamic_shapes = None

-    so = torch._export.aot_compile(
-        model,
-        args=input,
-        options={"aot_inductor.output_path": output_path},
-        dynamic_shapes=dynamic_shapes,
-    )
+    with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
+        so = torch._export.aot_compile(
+            model,
+            args=input,
+            options={"aot_inductor.output_path": output_path},
+            dynamic_shapes=dynamic_shapes,
+        )
     print(f"The generated DSO model can be found at: {so}")
     return so
```

0 commit comments