This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 3aba730

export.py: fix custom SDPA type conversion logic & re-enable for bfloat16 (#1193)

1 parent 6d2ef4a

1 file changed: torchchat/export.py (2 additions, 6 deletions)
@@ -199,7 +199,7 @@ def forward(self, x, freqs_cis, mask, input_pos=None, cache_lane: int = 0):
             input_pos[-1].item(),
             seqlen,
         )
-        output = output.view(bsz, seqlen, self.dim).to(dtype=q.dtype)
+        output = output.view(bsz, seqlen, self.dim).to(dtype=x.dtype)
         return self.wo(output)
 
 def replace_attention_with_custom_sdpa_attention(module: nn.Module):
@@ -291,11 +291,7 @@ def export_for_et(model, device, output_path) -> str:
         model = model.to(dtype=target_precision)
         state_dict_dtype = target_precision
 
-    # Custom SDPA does not work with bfloat16 on CPU currently. (The op doesn't
-    # support anything but bfloat32, and our attempt to use it anyway by converting
-    # to and from float causes other errors.)
-    if target_precision != torch.bfloat16:
-        replace_attention_with_custom_sdpa_attention(model)
+    replace_attention_with_custom_sdpa_attention(model)
 
     with torch.nn.attention.sdpa_kernel(
         [torch.nn.attention.SDPBackend.MATH]
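
The sketch below illustrates the idea behind the first hunk, under the assumption (suggested by the removed comment) that the custom SDPA path upcasts q/k/v to float32 before calling the op. Once q has been upcast, `q.dtype` no longer reflects the model's precision, so converting the output back with `q.dtype` leaves it in float32; converting with `x.dtype` restores the intended precision (e.g. bfloat16). `ToySDPAAttention` and the `wq`/`wk`/`wv` projections are hypothetical stand-ins, not the torchchat module; only `wo` and the fixed conversion line mirror the diff.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToySDPAAttention(nn.Module):
    """Hypothetical attention block whose SDPA op only runs in float32."""

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wk = nn.Linear(dim, dim, bias=False)
        self.wv = nn.Linear(dim, dim, bias=False)
        self.wo = nn.Linear(dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bsz, seqlen, _ = x.shape
        q, k, v = self.wq(x), self.wk(x), self.wv(x)

        # The float32-only op forces an upcast, so q.dtype is now float32
        # even when the model itself runs in bfloat16.
        q, k, v = q.float(), k.float(), v.float()
        output = F.scaled_dot_product_attention(q, k, v)

        # Fixed conversion: cast back to the dtype of the layer input x,
        # which still carries the target precision. The pre-fix
        # `.to(dtype=q.dtype)` would leave the output in float32 here.
        output = output.view(bsz, seqlen, self.dim).to(dtype=x.dtype)
        return self.wo(output)


# Quick check in bfloat16: the output precision matches the input precision.
model = ToySDPAAttention(dim=64).to(torch.bfloat16)
x = torch.randn(1, 8, 64, dtype=torch.bfloat16)
assert model(x).dtype == torch.bfloat16
```

With the pre-fix `.to(dtype=q.dtype)`, a bfloat16 `wo` projection would receive a float32 tensor, which is the kind of downstream dtype error the removed comment alludes to. With the conversion fixed, the bfloat16 guard around `replace_attention_with_custom_sdpa_attention` is no longer needed, which is what the second hunk removes.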
