Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 1ded204

Browse files
authored
Enable sdpa backends for server export in export.py
FLASH worked for dso models, so try this with methodical tests
1 parent f4ae60f commit 1ded204

File tree

1 file changed

+17
-15
lines changed

1 file changed

+17
-15
lines changed

torchchat/export.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -450,13 +450,14 @@ def main(args):
450450
print(
451451
"WARNING!! The path of compiling a dso is deprecated. Please use --output-aoti-package-path to create a .pt2 artifact instead."
452452
)
453-
export_for_server(
454-
model_to_dso,
455-
builder_args.device,
456-
output_dso_path,
457-
builder_args.dynamic_shapes,
458-
package=False,
459-
)
453+
with torch.nn.attention.sdpa_kernel([self.builder_args.attention_backend]):
454+
export_for_server(
455+
model_to_dso,
456+
builder_args.device,
457+
output_dso_path,
458+
builder_args.dynamic_shapes,
459+
package=False,
460+
)
460461

461462
if output_aoti_package_path:
462463
output_aoti_package_path = str(os.path.abspath(output_aoti_package_path))
@@ -472,11 +473,12 @@ def main(args):
472473
print(
473474
"Exporting model using AOT Inductor to " f"{output_aoti_package_path}."
474475
)
475-
export_for_server(
476-
model_to_aoti_package,
477-
builder_args.device,
478-
output_aoti_package_path,
479-
builder_args.dynamic_shapes,
480-
package=True,
481-
metadata=metadata,
482-
)
476+
with torch.nn.attention.sdpa_kernel([self.builder_args.attention_backend]):
477+
export_for_server(
478+
model_to_aoti_package,
479+
builder_args.device,
480+
output_aoti_package_path,
481+
builder_args.dynamic_shapes,
482+
package=True,
483+
metadata=metadata,
484+
)

0 commit comments

Comments
 (0)