This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit f343b43

Revert "export.py: fix custom SDPA type conversion logic & re-enable for bflo…" (#1197)
This reverts commit 3aba730.
1 parent: 3aba730 · commit: f343b43

1 file changed (+6, -2 lines)


torchchat/export.py

Lines changed: 6 additions & 2 deletions
@@ -199,7 +199,7 @@ def forward(self, x, freqs_cis, mask, input_pos=None, cache_lane: int = 0):
             input_pos[-1].item(),
             seqlen,
         )
-        output = output.view(bsz, seqlen, self.dim).to(dtype=x.dtype)
+        output = output.view(bsz, seqlen, self.dim).to(dtype=q.dtype)
         return self.wo(output)

 def replace_attention_with_custom_sdpa_attention(module: nn.Module):
@@ -291,7 +291,11 @@ def export_for_et(model, device, output_path) -> str:
         model = model.to(dtype=target_precision)
         state_dict_dtype = target_precision

-    replace_attention_with_custom_sdpa_attention(model)
+    # Custom SDPA does not work with bfloat16 on CPU currently. (The op doesn't
+    # support anything but bfloat32, and our attempt to use it anyway by converting
+    # to and from float causes other errors.)
+    if target_precision != torch.bfloat16:
+        replace_attention_with_custom_sdpa_attention(model)

     with torch.nn.attention.sdpa_kernel(
         [torch.nn.attention.SDPBackend.MATH]
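
For orientation, a minimal, hypothetical sketch of the guard this revert restores. Only the helper name replace_attention_with_custom_sdpa_attention, the variable target_precision, and the torch.bfloat16 check come from the diff; the stub body and the wrapper function are illustrative assumptions, not the actual torchchat implementation.

import torch
import torch.nn as nn

def replace_attention_with_custom_sdpa_attention(module: nn.Module) -> None:
    # Stub standing in for the real helper defined in torchchat/export.py,
    # which swaps the model's attention modules for a custom-SDPA variant.
    pass

def prepare_attention_for_export(model: nn.Module, target_precision: torch.dtype) -> nn.Module:
    # Hypothetical wrapper mirroring the restored control flow: the custom SDPA
    # op does not support bfloat16 on CPU, so the module swap is skipped for
    # bfloat16 exports and the default attention implementation is kept.
    if target_precision != torch.bfloat16:
        replace_attention_with_custom_sdpa_attention(model)
    return model

For example, prepare_attention_for_export(model, torch.bfloat16) leaves the model untouched, while a float32 or float16 export goes through the custom-SDPA swap.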
