@@ -423,12 +423,27 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
       input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
 }
 
+#ifndef USE_ROCM
 #define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES)                          \
-  static_assert(WARP_SIZE == 32 || WARP_SIZE == 64,                                   \
-                "Unsupported warp size. Only 32 and 64 are supported.");              \
+  static_assert(WARP_SIZE == 32,                                                      \
+                "Unsupported warp size. Only 32 is supported for CUDA");              \
   topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, WARP_SIZE, MAX_BYTES>(   \
       gating_output, nullptr, topk_weights, topk_indices,                             \
       token_expert_indices, num_tokens, topk, 0, num_experts, stream);
+#else
+#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES)                          \
+  if (WARP_SIZE == 64) {                                                              \
+    topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 64, MAX_BYTES>(        \
+        gating_output, nullptr, topk_weights, topk_indices,                           \
+        token_expert_indices, num_tokens, topk, 0, num_experts, stream);              \
+  } else if (WARP_SIZE == 32) {                                                       \
+    topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 32, MAX_BYTES>(        \
+        gating_output, nullptr, topk_weights, topk_indices,                           \
+        token_expert_indices, num_tokens, topk, 0, num_experts, stream);              \
+  } else {                                                                            \
+    assert(false && "Unsupported warp size. Only 32 and 64 are supported for ROCm");  \
+  }
+#endif
 
 template <typename IndType>
 void topkGatingSoftmaxKernelLauncher(
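Note on the split above: the two platforms learn the warp size at different times. On CUDA, WARP_SIZE is the compile-time constant 32, so an unsupported value can be rejected with a static_assert. On ROCm, WARP_SIZE is assumed here to map to HIP's warpSize, which varies by target (32 on RDNA, 64 on CDNA) and is not a host-side compile-time constant, so the macro compiles both template instantiations and branches at runtime. A minimal standalone sketch of the same dispatch pattern, with a hypothetical launchTopk<kWarpSize>() standing in for topkGatingSoftmaxLauncherHelper:

// Sketch of the compile-time vs. runtime warp-size dispatch used above.
// Hypothetical names: launchTopk<kWarpSize>() stands in for the real
// templated launcher, warp_size for WARP_SIZE (assumed to come from HIP's
// warpSize on ROCm, where it is not a host-side compile-time constant).
#include <cassert>
#include <cstdio>

template <int kWarpSize>
void launchTopk() {  // stand-in for the real templated launcher
  std::printf("launched with warp size %d\n", kWarpSize);
}

void dispatch(int warp_size) {
#ifndef USE_ROCM
  // CUDA: warp size is always 32, so only one instantiation is needed
  // and the check happens at compile time (the static_assert in the macro).
  (void)warp_size;
  launchTopk<32>();
#else
  // ROCm: both instantiations are compiled; the branch resolves at runtime.
  if (warp_size == 64) {
    launchTopk<64>();
  } else if (warp_size == 32) {
    launchTopk<32>();
  } else {
    assert(false && "Unsupported warp size. Only 32 and 64 are supported for ROCm");
  }
#endif
}

The cost of the ROCm form is one extra template instantiation and a trivially predictable branch; the payoff is a single binary that covers both wavefront sizes.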
@@ -443,7 +458,9 @@ void topkGatingSoftmaxKernelLauncher(
     cudaStream_t stream) {
   static constexpr int WARPS_PER_TB = 4;
   static constexpr int BYTES_PER_LDG_POWER_OF_2 = 16;
+#ifndef USE_ROCM
   static constexpr int BYTES_PER_LDG_MULTIPLE_64 = 8;
+#endif
   switch (num_experts) {
     case 1:
       LAUNCH_SOFTMAX(1, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
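For context on the two constants: MAX_BYTES sets the per-thread vector-load width inside the kernel. With float gating logits (per the launcher signature above), the 16-byte setting moves four floats per load and the 8-byte setting moves two. BYTES_PER_LDG_MULTIPLE_64 is guarded presumably because only the CUDA-side expansion still references it after this change, and an unused constant would otherwise warn under ROCm. A quick sanity sketch of that arithmetic (the k-prefixed names are illustrative, not from the patch):

// Sanity sketch of the load-width arithmetic, assuming float gating
// inputs as in the launcher signature above.
#include <cstddef>

constexpr std::size_t kBytesPerLdgPowerOf2 = 16;   // mirrors BYTES_PER_LDG_POWER_OF_2
constexpr std::size_t kBytesPerLdgMultiple64 = 8;  // mirrors BYTES_PER_LDG_MULTIPLE_64

constexpr std::size_t kEltsPerLdgWide = kBytesPerLdgPowerOf2 / sizeof(float);      // 4 floats
constexpr std::size_t kEltsPerLdgNarrow = kBytesPerLdgMultiple64 / sizeof(float);  // 2 floats

static_assert(kEltsPerLdgWide == 4, "16-byte loads move four floats (float4-width)");
static_assert(kEltsPerLdgNarrow == 2, "8-byte loads move two floats (float2-width)");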