Skip to content

Commit 9b45ec4

Browse files
fix: correct the implementation of when top_k is not defined for mlu device. (jd-opensource#343)
Co-authored-by: phantomlei <[email protected]>
1 parent 3830ad1 commit 9b45ec4

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

xllm/core/kernels/mlu/apply_topkp.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ limitations under the License.
1414
==============================================================================*/
1515

1616
#include "mlu_ops_api.h"
17-
1817
namespace xllm::kernel::mlu {
1918

2019
torch::Tensor apply_top_k_top_p(const torch::Tensor& logits,
@@ -25,7 +24,7 @@ torch::Tensor apply_top_k_top_p(const torch::Tensor& logits,
2524
return logits;
2625
}
2726
torch::Tensor temperature, topk, topp;
28-
if (!temperature.defined()) {
27+
if (!temperature_list.defined()) {
2928
temperature =
3029
torch::ones({logits.size(0)},
3130
torch::dtype(torch::kFloat32).device(logits.device()));
@@ -57,14 +56,17 @@ torch::Tensor apply_top_k_top_p(const torch::Tensor& logits,
5756
{logits.size(0)}, torch::dtype(torch::kInt32).device(logits.device()));
5857

5958
// Special case handling
59+
// Create a variable to hold the logits to use (may be modified in special
60+
// case)
61+
torch::Tensor logits_for_kernel = logits;
6062
if (!topk_list.defined() && topp_list.defined()) {
61-
auto topk_result = torch::topk(logits, logits.size(1));
62-
auto topk_logits = std::get<0>(topk_result);
63+
auto topk_result = torch::topk(logits, logits.size(-1));
64+
logits_for_kernel = std::get<0>(topk_result);
6365
auto topk_indices = std::get<1>(topk_result);
6466
index_out = topk_indices.to(torch::kInt32);
6567
}
6668

67-
tmo::torch_api::apply_topkp_v2(logits.to(torch::kFloat32),
69+
tmo::torch_api::apply_topkp_v2(logits_for_kernel.to(torch::kFloat32),
6870
index_in,
6971
temperature,
7072
/*min_topp=*/torch::Tensor(),

0 commit comments

Comments
 (0)