diff --git a/runner/run.cpp b/runner/run.cpp
index b90ca4e81..e161c029e 100644
--- a/runner/run.cpp
+++ b/runner/run.cpp
@@ -212,6 +212,7 @@ float* forward(Transformer* transformer, int token, int pos) {
           .to(torch::dtype(torch::kFloat32))
           .to(torch::kCPU);
   auto logits = result[0].data_ptr();
+  memcpy(s->logits, logits, p->vocab_size * sizeof(float));
 #else // __ET_MODEL__
   TensorPtr pos_managed = make_tensor_ptr({1}, pos_buffer, ScalarType::Long);
   TensorPtr tokens_managed = make_tensor_ptr({1, 1}, token_buffer, ScalarType::Long);
@@ -228,10 +229,23 @@ float* forward(Transformer* transformer, int token, int pos) {
     exit(EXIT_FAILURE);
   }
   std::vector<EValue> result = outputs_res.get();
-  auto logits = result[0].toTensor().const_data_ptr();
+  // HACK: the rest of this runner assumes that logits must be float,
+  // so we simply convert them rather than plumbing
+  // templating/switch-on-type through the rest of this file.
+  const auto& result_tensor = result[0].toTensor();
+  ET_SWITCH_REALHBBF16_TYPES(
+      result_tensor.scalar_type(),
+      unused,
+      "forward",
+      CTYPE,
+      [&]() {
+        const CTYPE* logits = result_tensor.const_data_ptr<CTYPE>();
+        std::transform(logits, logits + p->vocab_size, s->logits, [](auto x) {
+          return static_cast<float>(x);
+        });
+      });
 #endif

-  memcpy(s->logits, logits, p->vocab_size * sizeof(float));
   return s->logits;
 }

diff --git a/torchchat/export.py b/torchchat/export.py
index 3867ef319..c024e9deb 100644
--- a/torchchat/export.py
+++ b/torchchat/export.py
@@ -199,7 +199,7 @@ def forward(self, x, freqs_cis, mask, input_pos=None, cache_lane: int = 0):
             input_pos[-1].item(),
             seqlen,
         )
-        output = output.view(bsz, seqlen, self.dim).to(dtype=q.dtype)
+        output = output.view(bsz, seqlen, self.dim).to(dtype=x.dtype)
         return self.wo(output)

 def replace_attention_with_custom_sdpa_attention(module: nn.Module):
@@ -291,11 +291,7 @@ def export_for_et(model, device, output_path) -> str:
         model = model.to(dtype=target_precision)
         state_dict_dtype = target_precision

-    # Custom SDPA does not work with bfloat16 on CPU currently. (The op doesn't
-    # support anything but bfloat32, and our attempt to use it anyway by converting
-    # to and from float causes other errors.)
-    if target_precision != torch.bfloat16:
-        replace_attention_with_custom_sdpa_attention(model)
+    replace_attention_with_custom_sdpa_attention(model)

     with torch.nn.attention.sdpa_kernel(
         [torch.nn.attention.SDPBackend.MATH]
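
For context (not part of the patch): a minimal standalone sketch of the element-wise conversion that the ET_SWITCH_REALHBBF16_TYPES dispatch performs above, with the runtime type switch replaced by a compile-time template parameter. The names copy_logits_as_float and model_logits, and the use of double as the stand-in element type, are illustrative assumptions, not identifiers from the runner.

#include <algorithm>
#include <cstddef>
#include <vector>

// Whatever element type the exported model produces (float, half, bfloat16, ...),
// convert the logits element-wise into the float buffer the rest of the runner
// expects. In the patch, ET_SWITCH_REALHBBF16_TYPES selects CTYPE at runtime from
// the output tensor's scalar_type(); here it is a template parameter instead.
template <typename CTYPE>
void copy_logits_as_float(const CTYPE* src, float* dst, size_t vocab_size) {
  std::transform(src, src + vocab_size, dst,
                 [](auto x) { return static_cast<float>(x); });
}

int main() {
  // Illustrative stand-in for a non-float logits tensor (double here, since
  // bfloat16 has no standard C++ type); the real source is the ExecuTorch
  // output tensor's const_data_ptr<CTYPE>().
  std::vector<double> model_logits = {0.25, -1.5, 3.0};
  std::vector<float> s_logits(model_logits.size());
  copy_logits_as_float(model_logits.data(), s_logits.data(), model_logits.size());
  return 0;
}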