Skip to content

Commit c6cd5ca

Browse files
authored
[ROCm][Bugfix] Fix compilation error in topk softmax fused kernel (#22819)
Signed-off-by: kliuae <[email protected]>
1 parent df0e0f0 commit c6cd5ca

File tree

1 file changed

+19
-2
lines changed

1 file changed

+19
-2
lines changed

csrc/moe/topk_softmax_kernels.cu

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -423,12 +423,27 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
423423
input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
424424
}
425425

426+
#ifndef USE_ROCM
426427
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \
427-
static_assert(WARP_SIZE == 32 || WARP_SIZE == 64, \
428-
"Unsupported warp size. Only 32 and 64 are supported."); \
428+
static_assert(WARP_SIZE == 32, \
429+
"Unsupported warp size. Only 32 is supported for CUDA"); \
429430
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, WARP_SIZE, MAX_BYTES>( \
430431
gating_output, nullptr, topk_weights, topk_indices, \
431432
token_expert_indices, num_tokens, topk, 0, num_experts, stream);
433+
#else
434+
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \
435+
if (WARP_SIZE == 64) { \
436+
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 64, MAX_BYTES>( \
437+
gating_output, nullptr, topk_weights, topk_indices, \
438+
token_expert_indices, num_tokens, topk, 0, num_experts, stream); \
439+
} else if (WARP_SIZE == 32) { \
440+
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 32, MAX_BYTES>( \
441+
gating_output, nullptr, topk_weights, topk_indices, \
442+
token_expert_indices, num_tokens, topk, 0, num_experts, stream); \
443+
} else { \
444+
assert(false && "Unsupported warp size. Only 32 and 64 are supported for ROCm"); \
445+
}
446+
#endif
432447

433448
template <typename IndType>
434449
void topkGatingSoftmaxKernelLauncher(
@@ -443,7 +458,9 @@ void topkGatingSoftmaxKernelLauncher(
443458
cudaStream_t stream) {
444459
static constexpr int WARPS_PER_TB = 4;
445460
static constexpr int BYTES_PER_LDG_POWER_OF_2 = 16;
461+
#ifndef USE_ROCM
446462
static constexpr int BYTES_PER_LDG_MULTIPLE_64 = 8;
463+
#endif
447464
switch (num_experts) {
448465
case 1:
449466
LAUNCH_SOFTMAX(1, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);

0 commit comments

Comments
 (0)