diff --git a/torchchat/export.py b/torchchat/export.py
index c024e9deb..3867ef319 100644
--- a/torchchat/export.py
+++ b/torchchat/export.py
@@ -199,7 +199,7 @@ def forward(self, x, freqs_cis, mask, input_pos=None, cache_lane: int = 0):
             input_pos[-1].item(),
             seqlen,
         )
-        output = output.view(bsz, seqlen, self.dim).to(dtype=x.dtype)
+        output = output.view(bsz, seqlen, self.dim).to(dtype=q.dtype)
         return self.wo(output)
 
 def replace_attention_with_custom_sdpa_attention(module: nn.Module):
@@ -291,7 +291,11 @@ def export_for_et(model, device, output_path) -> str:
         model = model.to(dtype=target_precision)
         state_dict_dtype = target_precision
 
-        replace_attention_with_custom_sdpa_attention(model)
+        # Custom SDPA does not work with bfloat16 on CPU currently. (The op doesn't
+        # support anything but float32, and our attempt to use it anyway by converting
+        # to and from float causes other errors.)
+        if target_precision != torch.bfloat16:
+            replace_attention_with_custom_sdpa_attention(model)
 
     with torch.nn.attention.sdpa_kernel(
         [torch.nn.attention.SDPBackend.MATH]
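
For context on the first hunk: the attention output is now cast back to the dtype of the query tensor rather than the layer input. Below is a minimal, hypothetical sketch (not torchchat's code; the function name and shapes are invented here) of that dtype round-trip pattern, assuming the attention kernel itself only runs in float32:

import torch
import torch.nn.functional as F

def sdpa_float32_roundtrip(q, k, v, mask=None):
    # Illustrative only: run attention in float32 (standing in for a kernel
    # without half/bfloat16 support), then cast the result back to the query's
    # dtype, mirroring the `.to(dtype=q.dtype)` change above.
    out = F.scaled_dot_product_attention(q.float(), k.float(), v.float(), attn_mask=mask)
    return out.to(dtype=q.dtype)

q = k = v = torch.randn(1, 8, 16, 64, dtype=torch.float16)
assert sdpa_float32_roundtrip(q, k, v).dtype == torch.float16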