
Commit 1392868

export.py: fix custom SDPA type conversion logic & re-enable for bfloat16
q.dtype is always torch.float at this point; the code was clearly meant to use the input dtype instead.

ghstack-source-id: a132d7b
Pull Request resolved: #1171
Parent: ecd78a5 · Commit: 1392868
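For context, here is a minimal, self-contained sketch of the dtype round-trip this commit fixes. The function and tensor names are illustrative, not the actual torchchat attention code, and it assumes the custom SDPA op only accepts float32 inputs, so q/k/v are upcast to float before the call:

import torch

# Illustrative sketch (not the torchchat code): the custom op is assumed to
# require float32, so q/k/v are upcast before the call.
def sdpa_output_dtype_sketch(x: torch.Tensor) -> torch.Tensor:
    q = x.to(torch.float)  # q.dtype is now torch.float regardless of x.dtype
    k = x.to(torch.float)
    v = x.to(torch.float)
    # Stand-in for the custom SDPA op; its output dtype follows q/k/v (float32).
    output = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    # Buggy version: `output.to(dtype=q.dtype)` is a no-op, since q is float.
    # Fixed version: cast back to the input's dtype, so bfloat16 callers get
    # bfloat16 back.
    return output.to(dtype=x.dtype)

x = torch.randn(1, 4, 8, 16, dtype=torch.bfloat16)
assert sdpa_output_dtype_sketch(x).dtype == torch.bfloat16

With the output cast to the input dtype, the float upcast around the custom op becomes transparent to bfloat16 callers, which is why the second hunk below can drop the bfloat16 guard and re-enable the custom SDPA path for that precision.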


torchchat/export.py (2 additions, 6 deletions)
@@ -193,7 +193,7 @@ def forward(self, x, freqs_cis, mask, input_pos=None):
             input_pos[-1].item(),
             seqlen,
         )
-        output = output.view(bsz, seqlen, self.dim).to(dtype=q.dtype)
+        output = output.view(bsz, seqlen, self.dim).to(dtype=x.dtype)
         return self.wo(output)
 
 def replace_attention_with_custom_sdpa_attention(module: nn.Module):
@@ -285,11 +285,7 @@ def export_for_et(model, device, output_path) -> str:
         model = model.to(dtype=target_precision)
         state_dict_dtype = target_precision
 
-    # Custom SDPA does not work with bfloat16 on CPU currently. (The op doesn't
-    # support anything but bfloat32, and our attempt to use it anyway by converting
-    # to and from float causes other errors.)
-    if target_precision != torch.bfloat16:
-        replace_attention_with_custom_sdpa_attention(model)
+    replace_attention_with_custom_sdpa_attention(model)
 
     with torch.nn.attention.sdpa_kernel(
         [torch.nn.attention.SDPBackend.MATH]
