diff --git a/torchchat/export.py b/torchchat/export.py
index e387d3dac..592f65665 100644
--- a/torchchat/export.py
+++ b/torchchat/export.py
@@ -193,7 +193,7 @@ def forward(self, x, freqs_cis, mask, input_pos=None):
             input_pos[-1].item(),
             seqlen,
         )
-        output = output.view(bsz, seqlen, self.dim).to(dtype=q.dtype)
+        output = output.view(bsz, seqlen, self.dim).to(dtype=x.dtype)
         return self.wo(output)
 
 def replace_attention_with_custom_sdpa_attention(module: nn.Module):
@@ -285,11 +285,7 @@ def export_for_et(model, device, output_path) -> str:
         model = model.to(dtype=target_precision)
         state_dict_dtype = target_precision
 
-    # Custom SDPA does not work with bfloat16 on CPU currently. (The op doesn't
-    # support anything but bfloat32, and our attempt to use it anyway by converting
-    # to and from float causes other errors.)
-    if target_precision != torch.bfloat16:
-        replace_attention_with_custom_sdpa_attention(model)
+    replace_attention_with_custom_sdpa_attention(model)
 
     with torch.nn.attention.sdpa_kernel(
         [torch.nn.attention.SDPBackend.MATH]