This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit cde5018 (2 parents: d289419 + 5c8a35c)

Update

[ghstack-poisoned]

1 file changed (+16, -2)

runner/run.cpp

Lines changed: 16 additions & 2 deletions
@@ -212,6 +212,7 @@ float* forward(Transformer* transformer, int token, int pos) {
                      .to(torch::dtype(torch::kFloat32))
                      .to(torch::kCPU);
   auto logits = result[0].data_ptr();
+  memcpy(s->logits, logits, p->vocab_size * sizeof(float));
 #else // __ET_MODEL__
   TensorPtr pos_managed = make_tensor_ptr({1}, pos_buffer, ScalarType::Long);
   TensorPtr tokens_managed = make_tensor_ptr({1, 1}, token_buffer, ScalarType::Long);
@@ -228,10 +229,23 @@ float* forward(Transformer* transformer, int token, int pos) {
     exit(EXIT_FAILURE);
   }
   std::vector<EValue> result = outputs_res.get();
-  auto logits = result[0].toTensor().const_data_ptr();
+  // HACK: the rest of this runner assumes that logits must be float,
+  // so we simply convert them rather than plumbing
+  // templating/switch-on-type through the rest of this file.
+  const auto& result_tensor = result[0].toTensor();
+  ET_SWITCH_REALHBBF16_TYPES(
+      result_tensor.scalar_type(),
+      unused,
+      "forward",
+      CTYPE,
+      [&]() {
+        const CTYPE* logits = result_tensor.const_data_ptr<CTYPE>();
+        std::transform(logits, logits + p->vocab_size, s->logits, [](auto x) {
+          return static_cast<float>(x);
+        });
+      });
 #endif
 
-  memcpy(s->logits, logits, p->vocab_size * sizeof(float));
   return s->logits;
 }
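For context, the ExecuTorch path previously memcpy'd the raw output buffer into s->logits, which is only correct when the exported model emits float32 logits. The new code instead switches on the output tensor's element type at runtime and converts each logit to float into the preallocated buffer. The following is a minimal standalone sketch of that dispatch-and-convert pattern in plain C++, without the ET_SWITCH_REALHBBF16_TYPES macro; ScalarType, copy_logits_as_float, and dispatch_copy are illustrative names invented for this sketch, not ExecuTorch APIs.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Illustrative stand-in for the runtime element type of the output tensor.
enum class ScalarType { Float, Double, Int8 };

// Convert `count` logits of element type SRC into a float buffer, mirroring
// the std::transform call in the diff above.
template <typename SRC>
void copy_logits_as_float(const void* src, float* dst, std::size_t count) {
  const SRC* typed = static_cast<const SRC*>(src);
  std::transform(typed, typed + count, dst,
                 [](SRC x) { return static_cast<float>(x); });
}

// Runtime dispatch, analogous in spirit to ET_SWITCH_REALHBBF16_TYPES: pick
// the concrete C type from the tensor's scalar type, then run the conversion.
void dispatch_copy(ScalarType st, const void* src, float* dst, std::size_t count) {
  switch (st) {
    case ScalarType::Float:  copy_logits_as_float<float>(src, dst, count);  break;
    case ScalarType::Double: copy_logits_as_float<double>(src, dst, count); break;
    case ScalarType::Int8:   copy_logits_as_float<std::int8_t>(src, dst, count); break;
  }
}

int main() {
  // Pretend the model produced double-precision logits for a 4-token vocab.
  std::vector<double> raw = {0.1, -1.5, 2.25, 0.0};
  std::vector<float> logits(raw.size());
  dispatch_copy(ScalarType::Double, raw.data(), logits.data(), raw.size());
  for (float x : logits) std::printf("%f\n", x);
  return 0;
}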
