@@ -14,7 +14,6 @@ limitations under the License.
1414==============================================================================*/
1515
1616#include " mlu_ops_api.h"
17-
1817namespace xllm ::kernel::mlu {
1918
2019torch::Tensor apply_top_k_top_p (const torch::Tensor& logits,
@@ -25,7 +24,7 @@ torch::Tensor apply_top_k_top_p(const torch::Tensor& logits,
2524 return logits;
2625 }
2726 torch::Tensor temperature, topk, topp;
28- if (!temperature .defined ()) {
27+ if (!temperature_list .defined ()) {
2928 temperature =
3029 torch::ones ({logits.size (0 )},
3130 torch::dtype (torch::kFloat32 ).device (logits.device ()));
@@ -57,14 +56,17 @@ torch::Tensor apply_top_k_top_p(const torch::Tensor& logits,
5756 {logits.size (0 )}, torch::dtype (torch::kInt32 ).device (logits.device ()));
5857
5958 // Special case handling
59+ // Create a variable to hold the logits to use (may be modified in special
60+ // case)
61+ torch::Tensor logits_for_kernel = logits;
6062 if (!topk_list.defined () && topp_list.defined ()) {
61- auto topk_result = torch::topk (logits, logits.size (1 ));
62- auto topk_logits = std::get<0 >(topk_result);
63+ auto topk_result = torch::topk (logits, logits.size (- 1 ));
64+ logits_for_kernel = std::get<0 >(topk_result);
6365 auto topk_indices = std::get<1 >(topk_result);
6466 index_out = topk_indices.to (torch::kInt32 );
6567 }
6668
67- tmo::torch_api::apply_topkp_v2 (logits .to (torch::kFloat32 ),
69+ tmo::torch_api::apply_topkp_v2 (logits_for_kernel .to (torch::kFloat32 ),
6870 index_in,
6971 temperature,
7072 /* min_topp=*/ torch::Tensor (),
0 commit comments