34 changes: 18 additions & 16 deletions torchchat/export.py
@@ -28,7 +28,7 @@


"""
Export for Server
Export for Server
"""


@@ -78,7 +78,7 @@ def export_for_server(
"""
Export for ExecuTorch

TODO (https://github.com/pytorch/torchchat/issues/1058): Replace
TODO (https://github.com/pytorch/torchchat/issues/1058): Replace
replace_attention_with_custom_sdpa_attention with ET's implementation
"""

@@ -94,6 +94,9 @@ def export_for_server(
     from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
         XnnpackDynamicallyQuantizedPartitioner,
     )
+    from executorch.backends.xnnpack.passes.convert_to_linear import (
+        ConvertToLinearPass,
+    )
     from executorch.exir import EdgeProgramManager, to_edge
 
     from executorch.exir.capture._config import (
@@ -274,22 +277,20 @@ def export_for_et(model, device, output_path) -> str:
         _skip_type_promotion=bool(target_precision == torch.float16),
     )
 
-    if target_precision == torch.float16 or target_precision == torch.bfloat16:
[Comment from a Contributor on the removed line above: "Wow, I'm shocked that we let this live as long as it has. Thanks for the fix."]

-        if state_dict_dtype != torch.float16:
-            print("model.to torch.float16")
-            model = model.to(dtype=torch.float16)
-            state_dict_dtype = torch.float16
-    elif target_precision == torch.float32:
-        if state_dict_dtype != torch.float32:
-            print("model.to torch.float32")
-            model = model.to(dtype=torch.float32)
-    elif target_precision == torch.bfloat16:
-        print("model.to torch.bfloat16")
-        model = model.to(dtype=torch.bfloat16)
-    else:
+    if target_precision not in (torch.float16, torch.float32, torch.bfloat16):
         raise ValueError(f"Unsupported dtype for ET export: {target_precision}")
 
-    replace_attention_with_custom_sdpa_attention(model)
+    if state_dict_dtype != target_precision:
+        print(f"model.to {target_precision}")
+        model = model.to(dtype=target_precision)
+        state_dict_dtype = target_precision
+
+    # Custom SDPA does not work with bfloat16 on CPU currently. (The op doesn't
+    # support anything but float32, and our attempt to use it anyway by converting
+    # to and from float causes other errors.)
+    if target_precision != torch.bfloat16:
+        replace_attention_with_custom_sdpa_attention(model)
 
     with torch.nn.attention.sdpa_kernel(
         [torch.nn.attention.SDPBackend.MATH]
     ), torch.no_grad():
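
The refactor above collapses three near-duplicate dtype branches into a single validate-then-convert path, and makes the custom SDPA swap conditional instead of unconditional. Below is a minimal runnable sketch of the consolidated dtype logic, factored into a standalone helper for illustration; the helper name normalize_model_dtype and the trick of reading the current dtype off the first parameter are assumptions of this sketch, not code from this PR.

import torch
import torch.nn as nn

# Dtypes the ET export path accepts, per the check added above.
_SUPPORTED_DTYPES = (torch.float16, torch.float32, torch.bfloat16)

def normalize_model_dtype(model: nn.Module, target_precision: torch.dtype) -> nn.Module:
    # Hypothetical helper: validate once, then convert through a single
    # code path instead of one branch per dtype.
    if target_precision not in _SUPPORTED_DTYPES:
        raise ValueError(f"Unsupported dtype for ET export: {target_precision}")
    state_dict_dtype = next(model.parameters()).dtype
    if state_dict_dtype != target_precision:
        print(f"model.to {target_precision}")
        model = model.to(dtype=target_precision)
    return model

# Example: a toy module normalized to half precision.
m = normalize_model_dtype(nn.Linear(4, 4), torch.float16)
assert next(m.parameters()).dtype == torch.float16

The bfloat16 guard then composes with this: after normalization, the custom SDPA replacement is simply skipped when the target is bfloat16, since per the new comment the custom op only handles float32 on CPU.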
@@ -306,6 +307,7 @@ def export_for_et(model, device, output_path) -> str:
             ExecutorchBackendConfig(
                 extract_delegate_segments=True,
                 passes=[
+                    ConvertToLinearPass(),
                     QuantFusionPass(),
                 ],
                 sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
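
ConvertToLinearPass is deliberately placed ahead of QuantFusionPass: backend-config passes are applied in the order listed, and as its name suggests, ConvertToLinearPass rewrites decomposed matmul patterns back into linear nodes that later stages can match. A toy, self-contained sketch of that ordering contract follows; every name in it (Graph, convert_to_linear, quant_fusion, run_passes) is illustrative, not an ExecuTorch API, and the specific fusion pattern is an assumption made for the example.

from typing import Callable, List

# Toy stand-in for a compiler pass pipeline: each pass rewrites the program,
# and order matters because later passes match patterns earlier ones produce.
Graph = List[str]
GraphPass = Callable[[Graph], Graph]

def convert_to_linear(graph: Graph) -> Graph:
    # Rewrite a decomposed "mm" op back into "linear".
    return ["linear" if op == "mm" else op for op in graph]

def quant_fusion(graph: Graph) -> Graph:
    # Fuse a dequant/linear/quant sandwich into one quantized op; this
    # only fires if convert_to_linear has already run.
    out, i = [], 0
    while i < len(graph):
        if graph[i : i + 3] == ["dequant", "linear", "quant"]:
            out.append("quantized_linear")
            i += 3
        else:
            out.append(graph[i])
            i += 1
    return out

def run_passes(graph: Graph, passes: List[GraphPass]) -> Graph:
    for p in passes:
        graph = p(graph)
    return graph

print(run_passes(["dequant", "mm", "quant"], [convert_to_linear, quant_fusion]))
# ['quantized_linear'] -- swap the pass order and the fusion never fires.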