Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 28914fd

Browse files
authored
add ConvertToLinear, disable custom SDPA for bfloat16 (#1189)
1 parent 2281c37 commit 28914fd

File tree

1 file changed

+18
-16
lines changed

1 file changed

+18
-16
lines changed

torchchat/export.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929

3030
"""
31-
Export for Server
31+
Export for Server
3232
"""
3333

3434

@@ -78,7 +78,7 @@ def export_for_server(
7878
"""
7979
Export for ExecuTorch
8080
81-
TODO (https://github.com/pytorch/torchchat/issues/1058): Replace
81+
TODO (https://github.com/pytorch/torchchat/issues/1058): Replace
8282
replace_attention_with_custom_sdpa_attention with ET's implementation
8383
"""
8484

@@ -94,6 +94,9 @@ def export_for_server(
9494
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
9595
XnnpackDynamicallyQuantizedPartitioner,
9696
)
97+
from executorch.backends.xnnpack.passes.convert_to_linear import (
98+
ConvertToLinearPass,
99+
)
97100
from executorch.exir import EdgeProgramManager, to_edge
98101

99102
from executorch.exir.capture._config import (
@@ -274,22 +277,20 @@ def export_for_et(model, device, output_path) -> str:
274277
_skip_type_promotion=bool(target_precision == torch.float16),
275278
)
276279

277-
if target_precision == torch.float16 or target_precision == torch.bfloat16:
278-
if state_dict_dtype != torch.float16:
279-
print("model.to torch.float16")
280-
model = model.to(dtype=torch.float16)
281-
state_dict_dtype = torch.float16
282-
elif target_precision == torch.float32:
283-
if state_dict_dtype != torch.float32:
284-
print("model.to torch.float32")
285-
model = model.to(dtype=torch.float32)
286-
elif target_precision == torch.bfloat16:
287-
print("model.to torch.bfloat16")
288-
model = model.to(dtype=torch.bfloat16)
289-
else:
280+
if target_precision not in (torch.float16, torch.float32, torch.bfloat16):
290281
raise ValueError(f"Unsupported dtype for ET export: {target_precision}")
291282

292-
replace_attention_with_custom_sdpa_attention(model)
283+
if state_dict_dtype != target_precision:
284+
print(f"model.to {target_precision}")
285+
model = model.to(dtype=target_precision)
286+
state_dict_dtype = target_precision
287+
288+
# Custom SDPA does not work with bfloat16 on CPU currently. (The op doesn't
289+
# support anything but bfloat32, and our attempt to use it anyway by converting
290+
# to and from float causes other errors.)
291+
if target_precision != torch.bfloat16:
292+
replace_attention_with_custom_sdpa_attention(model)
293+
293294
with torch.nn.attention.sdpa_kernel(
294295
[torch.nn.attention.SDPBackend.MATH]
295296
), torch.no_grad():
@@ -306,6 +307,7 @@ def export_for_et(model, device, output_path) -> str:
306307
ExecutorchBackendConfig(
307308
extract_delegate_segments=True,
308309
passes=[
310+
ConvertToLinearPass(),
309311
QuantFusionPass(),
310312
],
311313
sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),

0 commit comments

Comments
 (0)