Merged
2 changes: 1 addition & 1 deletion .github/workflows/pull.yml
@@ -941,7 +941,7 @@ jobs:
       python torchchat.py export stories15M --output-pte-path ./model.pte
       ./cmake-out/et_run ./model.pte -z ./tokenizer.model -t 0 -i "${PRMT}"
-      for dtype in fp32 fp16; do # bf16 needs to be supported
+      for dtype in fp32 fp16 bf16; do
         echo "Testing export + runner with dtype=$dtype"
         python torchchat.py export stories15M --dtype $dtype --output-pte-path ./model.pte
         ./cmake-out/et_run ./model.pte -z ./tokenizer.model -t 0 -i "${PRMT}"
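The loop now exercises bf16 alongside fp32 and fp16. For a local spot-check of just the new case, the equivalent manual commands would be roughly as follows, assuming the ExecuTorch runner has already been built at ./cmake-out/et_run and the stories15M tokenizer is in place (the prompt variable is illustrative):

    # Hypothetical manual repro of the new bf16 CI case; paths and flags
    # mirror the workflow step above.
    PRMT="Once upon a time"
    python torchchat.py export stories15M --dtype bf16 --output-pte-path ./model.pte
    ./cmake-out/et_run ./model.pte -z ./tokenizer.model -t 0 -i "${PRMT}"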
18 changes: 16 additions & 2 deletions runner/run.cpp
@@ -212,6 +212,7 @@ float* forward(Transformer* transformer, int token, int pos) {
           .to(torch::dtype(torch::kFloat32))
           .to(torch::kCPU);
   auto logits = result[0].data_ptr();
+  memcpy(s->logits, logits, p->vocab_size * sizeof(float));
 #else // __ET_MODEL__
   TensorPtr pos_managed = make_tensor_ptr({1}, pos_buffer, ScalarType::Long);
   TensorPtr tokens_managed = make_tensor_ptr({1, 1}, token_buffer, ScalarType::Long);
@@ -228,10 +229,23 @@ float* forward(Transformer* transformer, int token, int pos) {
     exit(EXIT_FAILURE);
   }
   std::vector<EValue> result = outputs_res.get();
-  auto logits = result[0].toTensor().const_data_ptr();
+  // HACK: the rest of this runner assumes that logits must be float,
+  // so we simply convert them rather than plumbing
+  // templating/switch-on-type through the rest of this file.
+  const auto& result_tensor = result[0].toTensor();
+  ET_SWITCH_REALHBBF16_TYPES(
+      result_tensor.scalar_type(),
+      unused,
+      "forward",
+      CTYPE,
+      [&]() {
+        const CTYPE* logits = result_tensor.const_data_ptr<CTYPE>();
+        std::transform(logits, logits + p->vocab_size, s->logits, [](auto x) {
+          return static_cast<float>(x);
+        });
+      });
 #endif
 
-  memcpy(s->logits, logits, p->vocab_size * sizeof(float));
   return s->logits;
 }
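Going by its name, ET_SWITCH_REALHBBF16_TYPES dispatches on the tensor's runtime scalar type (the real types plus Half, Bool, and BFloat16) and instantiates the lambda with CTYPE bound to the matching element type. Below is a minimal standalone sketch of the same switch-on-dtype widening, with double standing in for bf16 since standard C++ has no bfloat16 scalar; the real runner uses ExecuTorch's element types and macro instead:

    // Standalone sketch of the conversion pattern above, minus the
    // ExecuTorch machinery. `double` is a stand-in for half/bf16.
    #include <algorithm>
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    enum class ScalarType { Float, Double };

    // Widen a raw logits buffer of element type CTYPE into float so that
    // downstream sampling code can stay float-only.
    template <typename CTYPE>
    void widen_to_float(const void* src, float* dst, std::size_t n) {
      const CTYPE* typed = static_cast<const CTYPE*>(src);
      std::transform(typed, typed + n, dst,
                     [](CTYPE x) { return static_cast<float>(x); });
    }

    void copy_logits_as_float(const void* src, ScalarType st, float* dst,
                              std::size_t n) {
      switch (st) {  // the ET_SWITCH_* macros expand to a switch like this
        case ScalarType::Float:
          widen_to_float<float>(src, dst, n);
          break;
        case ScalarType::Double:
          widen_to_float<double>(src, dst, n);
          break;
      }
    }

    int main() {
      std::vector<double> raw = {1.5, -2.25, 0.125};
      std::vector<float> out(raw.size());
      copy_logits_as_float(raw.data(), ScalarType::Double, out.data(), raw.size());
      std::printf("%g %g %g\n", out[0], out[1], out[2]);  // 1.5 -2.25 0.125
    }

The AOTI branch keeps its plain memcpy because .to(torch::kFloat32) already materializes float32 logits there.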

8 changes: 2 additions & 6 deletions torchchat/export.py
@@ -199,7 +199,7 @@ def forward(self, x, freqs_cis, mask, input_pos=None, cache_lane: int = 0):
             input_pos[-1].item(),
             seqlen,
         )
-        output = output.view(bsz, seqlen, self.dim).to(dtype=q.dtype)
+        output = output.view(bsz, seqlen, self.dim).to(dtype=x.dtype)
         return self.wo(output)
 
 def replace_attention_with_custom_sdpa_attention(module: nn.Module):
@@ -291,11 +291,7 @@ def export_for_et(model, device, output_path) -> str:
     model = model.to(dtype=target_precision)
     state_dict_dtype = target_precision
 
-    # Custom SDPA does not work with bfloat16 on CPU currently. (The op doesn't
-    # support anything but bfloat32, and our attempt to use it anyway by converting
-    # to and from float causes other errors.)
-    if target_precision != torch.bfloat16:
-        replace_attention_with_custom_sdpa_attention(model)
+    replace_attention_with_custom_sdpa_attention(model)
 
     with torch.nn.attention.sdpa_kernel(
         [torch.nn.attention.SDPBackend.MATH]
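
Two things change here: the custom SDPA attention is installed unconditionally, including for bf16 exports, and the attention output is cast back to x.dtype (the layer input's dtype) rather than q.dtype. The distinction matters once q/k/v are computed in a wider dtype than the input. A small illustration, using hypothetical projection weights rather than torchchat's actual attention module:

    # Sketch of why the cast target matters: if q/k/v end up in float32
    # while the input x is bf16, casting the output to q.dtype would leave
    # float32 activations inside a bf16 model. Names/shapes are illustrative.
    import torch
    import torch.nn.functional as F

    def attention_block(x, wq, wk, wv):
        # Compute attention in float32 regardless of x's dtype.
        q = (x @ wq).float()
        k = (x @ wk).float()
        v = (x @ wv).float()
        out = F.scaled_dot_product_attention(q, k, v)
        return out.to(dtype=x.dtype)  # with q.dtype this would stay float32

    x = torch.randn(1, 8, 16, dtype=torch.bfloat16)
    w = torch.eye(16, dtype=torch.bfloat16)
    print(attention_block(x, w, w, w).dtype)  # torch.bfloat16

Together with the x.dtype fix in the first hunk, this lets the custom SDPA path handle bf16, which is why the old guard in export_for_et can be dropped.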