From fce0c799c6cbdf61ac9dd8b6326134ddba1cd78c Mon Sep 17 00:00:00 2001
From: Scott Wolchok
Date: Tue, 24 Sep 2024 16:29:15 -0700
Subject: [PATCH] Update

[ghstack-poisoned]
---
 runner/run.cpp | 18 ++++++++++++++++--
 1 file changed, 16 insertions(+), 2 deletions(-)

diff --git a/runner/run.cpp b/runner/run.cpp
index b90ca4e81..e161c029e 100644
--- a/runner/run.cpp
+++ b/runner/run.cpp
@@ -212,6 +212,7 @@ float* forward(Transformer* transformer, int token, int pos) {
                              .to(torch::dtype(torch::kFloat32))
                              .to(torch::kCPU);
   auto logits = result[0].data_ptr();
+  memcpy(s->logits, logits, p->vocab_size * sizeof(float));
 #else // __ET_MODEL__
   TensorPtr pos_managed = make_tensor_ptr({1}, pos_buffer, ScalarType::Long);
   TensorPtr tokens_managed = make_tensor_ptr({1, 1}, token_buffer, ScalarType::Long);
@@ -228,10 +229,23 @@ float* forward(Transformer* transformer, int token, int pos) {
     exit(EXIT_FAILURE);
   }
   std::vector<EValue> result = outputs_res.get();
-  auto logits = result[0].toTensor().const_data_ptr<float>();
+  // HACK: the rest of this runner assumes that logits must be float,
+  // so we simply convert them rather than plumbing
+  // templating/switch-on-type through the rest of this file.
+  const auto& result_tensor = result[0].toTensor();
+  ET_SWITCH_REALHBBF16_TYPES(
+      result_tensor.scalar_type(),
+      unused,
+      "forward",
+      CTYPE,
+      [&]() {
+        const CTYPE* logits = result_tensor.const_data_ptr<CTYPE>();
+        std::transform(logits, logits + p->vocab_size, s->logits, [](auto x) {
+          return static_cast<float>(x);
+        });
+      });
 #endif
 
-  memcpy(s->logits, logits, p->vocab_size * sizeof(float));
   return s->logits;
 }
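
Note on the ExecuTorch hunk above: ET_SWITCH_REALHBBF16_TYPES dispatches on the tensor's runtime scalar type and instantiates the lambda with the matching C type bound to CTYPE, so the std::transform just widens each logit to float into s->logits. The following is a minimal, self-contained sketch of that dispatch-then-convert pattern, under stated assumptions: the ScalarType subset, widen_to_float(), and convert_logits() names are illustrative only, not APIs of this runner or of ExecuTorch, and the real macro covers the full set of real/half/bfloat16 dtypes.

// Sketch only: hand-rolled analogue of "switch on dtype, then widen to float".
// ScalarType here is a made-up subset; the real ExecuTorch enum is larger.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <vector>

enum class ScalarType { Float, Double, Int };  // illustrative subset

// Reinterpret the source buffer as CTYPE and convert element-wise to float,
// mirroring the std::transform lambda in the patch.
template <typename CTYPE>
void widen_to_float(const void* src, float* dst, std::size_t n) {
  const CTYPE* typed = static_cast<const CTYPE*>(src);
  std::transform(typed, typed + n, dst, [](CTYPE x) {
    return static_cast<float>(x);
  });
}

// Rough stand-in for what the ET_SWITCH_* call does: pick the C type that
// matches the runtime dtype and run the conversion with it.
void convert_logits(ScalarType dtype, const void* src, float* dst, std::size_t n) {
  switch (dtype) {
    case ScalarType::Float:
      widen_to_float<float>(src, dst, n);
      break;
    case ScalarType::Double:
      widen_to_float<double>(src, dst, n);
      break;
    case ScalarType::Int:
      widen_to_float<int32_t>(src, dst, n);
      break;
    default:
      std::fprintf(stderr, "unsupported logits dtype\n");
      std::exit(EXIT_FAILURE);
  }
}

int main() {
  std::vector<double> raw = {0.5, -1.25, 3.0};  // pretend model output
  std::vector<float> logits(raw.size());
  convert_logits(ScalarType::Double, raw.data(), logits.data(), raw.size());
  std::printf("%f %f %f\n", logits[0], logits[1], logits[2]);
  return 0;
}

Converting once at the forward() boundary keeps the rest of the runner float-only, at the cost of one O(vocab_size) pass per decoding step.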