@@ -212,6 +212,7 @@ float* forward(Transformer* transformer, int token, int pos) {
                             .to(torch::dtype(torch::kFloat32))
                             .to(torch::kCPU);
   auto logits = result[0].data_ptr();
+  memcpy(s->logits, logits, p->vocab_size * sizeof(float));
 #else // __ET_MODEL__
   TensorPtr pos_managed = make_tensor_ptr({1}, pos_buffer, ScalarType::Long);
   TensorPtr tokens_managed = make_tensor_ptr({1, 1}, token_buffer, ScalarType::Long);
@@ -228,10 +229,23 @@ float* forward(Transformer* transformer, int token, int pos) {
     exit(EXIT_FAILURE);
   }
   std::vector<EValue> result = outputs_res.get();
-  auto logits = result[0].toTensor().const_data_ptr();
+  // HACK: the rest of this runner assumes that logits must be float,
+  // so we simply convert them rather than plumbing
+  // templating/switch-on-type through the rest of this file.
+  const auto& result_tensor = result[0].toTensor();
+  ET_SWITCH_REALHBBF16_TYPES(
+      result_tensor.scalar_type(),
+      unused,
+      "forward",
+      CTYPE,
+      [&]() {
+        const CTYPE* logits = result_tensor.const_data_ptr<CTYPE>();
+        std::transform(logits, logits + p->vocab_size, s->logits, [](auto x) {
+          return static_cast<float>(x);
+        });
+      });
 #endif

-  memcpy(s->logits, logits, p->vocab_size * sizeof(float));
   return s->logits;
 }
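For readers unfamiliar with the `ET_SWITCH_*` macro family: `ET_SWITCH_REALHBBF16_TYPES` switches on the tensor's runtime `ScalarType` and invokes the lambda with `CTYPE` bound to the matching C++ element type, so the conversion loop above is compiled once per supported dtype. Below is a minimal hand-rolled sketch of that dispatch pattern; the toy `ScalarType` enum and the `copy_as_float`/`dispatch_copy` helpers are illustrative stand-ins covering only two dtypes, not the real macro expansion or ExecuTorch API.

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

// Toy stand-in for ExecuTorch's ScalarType; the real macro also covers
// half, bfloat16, and the integral "REALHB" types.
enum class ScalarType { Float, Double };

// The templated body the switch stamps out once per dtype: read the
// source buffer as CTYPE and widen each element to float.
template <typename CTYPE>
void copy_as_float(const void* src, float* dst, int n) {
  const CTYPE* in = static_cast<const CTYPE*>(src);
  std::transform(in, in + n, dst,
                 [](CTYPE x) { return static_cast<float>(x); });
}

// Hand-rolled equivalent of the switch the macro generates.
void dispatch_copy(ScalarType t, const void* src, float* dst, int n) {
  switch (t) {
    case ScalarType::Float:  copy_as_float<float>(src, dst, n);  break;
    case ScalarType::Double: copy_as_float<double>(src, dst, n); break;
  }
}

int main() {
  std::vector<double> logits = {1.5, -2.25, 0.75};
  std::vector<float> out(logits.size());
  dispatch_copy(ScalarType::Double, logits.data(), out.data(),
                static_cast<int>(logits.size()));
  for (float v : out) std::printf("%g\n", v);
}
```

Converting at this boundary keeps the change local, as the HACK comment notes: every later consumer of `s->logits` continues to see plain `float` regardless of the dtype the exported model emits.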