@@ -188,7 +188,9 @@ __launch_bounds__(TPB) __global__ void moeTopK(
188
188
It fuses the softmax, max and argmax into a single kernel.
189
189
190
190
Limitations:
191
- 1) This implementation is intended for when the number of experts is a small power of 2.
191
+ 1) This implementation is optimized for when the number of experts is a small power of 2.
192
+ Additionally, it also supports the case where the number of experts is a multiple of 64, which is still
193
+ faster than computing the softmax and topK separately (currently only tested on CUDA).
192
194
2) This implementation assumes k is small, but will work for any k.
193
195
*/
194
196
@@ -198,8 +200,6 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
198
200
int * source_rows, const int k, const int start_expert, const int end_expert)
199
201
{
200
202
// We begin by enforcing compile time assertions and setting up compile time constants.
201
- static_assert (VPT == (VPT & -VPT), " VPT must be power of 2" );
202
- static_assert (NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), " NUM_EXPERTS must be power of 2" );
203
203
static_assert (BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), " BYTES_PER_LDG must be power of 2" );
204
204
static_assert (BYTES_PER_LDG <= 16 , " BYTES_PER_LDG must be leq 16" );
205
205
@@ -407,12 +407,10 @@ struct TopkConstants
407
407
};
408
408
} // namespace detail
409
409
410
- template <int EXPERTS, int WARPS_PER_TB, int WARP_SIZE_PARAM, typename IndType>
410
+ template <int EXPERTS, int WARPS_PER_TB, int WARP_SIZE_PARAM, int MAX_BYTES_PER_LDG, typename IndType>
411
411
void topkGatingSoftmaxLauncherHelper (const float * input, const bool * finished, float * output, IndType* indices,
412
412
int * source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
413
413
{
414
- static constexpr std::size_t MAX_BYTES_PER_LDG = 16 ;
415
-
416
414
static constexpr int BYTES_PER_LDG = MIN (MAX_BYTES_PER_LDG, sizeof (float ) * EXPERTS);
417
415
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG, WARP_SIZE_PARAM>;
418
416
static constexpr int VPT = Constants::VPT;
@@ -425,21 +423,12 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
425
423
input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
426
424
}
427
425
428
- #define LAUNCH_SOFTMAX (NUM_EXPERTS, WARPS_PER_TB ) \
429
- switch (warpSize ) { \
430
- case 32 : \
431
- topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 32 >( \
432
- gating_output, nullptr , topk_weights, topk_indices, \
433
- token_expert_indices, num_tokens, topk, 0 , num_experts, stream); \
434
- break ; \
435
- case 64 : \
436
- topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 64 >( \
437
- gating_output, nullptr , topk_weights, topk_indices, \
438
- token_expert_indices, num_tokens, topk, 0 , num_experts, stream); \
439
- break ; \
440
- default : \
441
- TORCH_CHECK (false , " Unsupported warp size: " , warpSize ); \
442
- }
426
+ #define LAUNCH_SOFTMAX (NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES ) \
427
+ static_assert (WARP_SIZE == 32 || WARP_SIZE == 64 , \
428
+ " Unsupported warp size. Only 32 and 64 are supported." ); \
429
+ topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, WARP_SIZE, MAX_BYTES>( \
430
+ gating_output, nullptr , topk_weights, topk_indices, \
431
+ token_expert_indices, num_tokens, topk, 0 , num_experts, stream);
443
432
444
433
template <typename IndType>
445
434
void topkGatingSoftmaxKernelLauncher (
@@ -453,38 +442,62 @@ void topkGatingSoftmaxKernelLauncher(
453
442
const int topk,
454
443
cudaStream_t stream) {
455
444
static constexpr int WARPS_PER_TB = 4 ;
456
- auto warpSize = WARP_SIZE;
445
+ static constexpr int BYTES_PER_LDG_POWER_OF_2 = 16 ;
446
+ static constexpr int BYTES_PER_LDG_MULTIPLE_64 = 8 ;
457
447
switch (num_experts) {
458
448
case 1 :
459
- LAUNCH_SOFTMAX (1 , WARPS_PER_TB);
449
+ LAUNCH_SOFTMAX (1 , WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2 );
460
450
break ;
461
451
case 2 :
462
- LAUNCH_SOFTMAX (2 , WARPS_PER_TB);
452
+ LAUNCH_SOFTMAX (2 , WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2 );
463
453
break ;
464
454
case 4 :
465
- LAUNCH_SOFTMAX (4 , WARPS_PER_TB);
455
+ LAUNCH_SOFTMAX (4 , WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2 );
466
456
break ;
467
457
case 8 :
468
- LAUNCH_SOFTMAX (8 , WARPS_PER_TB);
458
+ LAUNCH_SOFTMAX (8 , WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2 );
469
459
break ;
470
460
case 16 :
471
- LAUNCH_SOFTMAX (16 , WARPS_PER_TB);
461
+ LAUNCH_SOFTMAX (16 , WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2 );
472
462
break ;
473
463
case 32 :
474
- LAUNCH_SOFTMAX (32 , WARPS_PER_TB);
464
+ LAUNCH_SOFTMAX (32 , WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2 );
475
465
break ;
476
466
case 64 :
477
- LAUNCH_SOFTMAX (64 , WARPS_PER_TB);
467
+ LAUNCH_SOFTMAX (64 , WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2 );
478
468
break ;
479
469
case 128 :
480
- LAUNCH_SOFTMAX (128 , WARPS_PER_TB);
470
+ LAUNCH_SOFTMAX (128 , WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2 );
481
471
break ;
482
472
case 256 :
483
- LAUNCH_SOFTMAX (256 , WARPS_PER_TB);
473
+ LAUNCH_SOFTMAX (256 , WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
474
+ break ;
475
+ case 512 :
476
+ LAUNCH_SOFTMAX (512 , WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
477
+ break ;
478
+ // (CUDA only) support multiples of 64 when num_experts is not power of 2.
479
+ // ROCm uses WARP_SIZE 64, so 8-byte loads won't fit for some values of num_experts;
480
+ // alternatively we could test 4-byte loads and enable that in the future.
481
+ #ifndef USE_ROCM
482
+ case 192 :
483
+ LAUNCH_SOFTMAX (192 , WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
484
484
break ;
485
+ case 320 :
486
+ LAUNCH_SOFTMAX (320 , WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
487
+ break ;
488
+ case 384 :
489
+ LAUNCH_SOFTMAX (384 , WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
490
+ break ;
491
+ case 448 :
492
+ LAUNCH_SOFTMAX (448 , WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
493
+ break ;
494
+ case 576 :
495
+ LAUNCH_SOFTMAX (576 , WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
496
+ break ;
497
+ #endif
485
498
default : {
486
499
TORCH_CHECK (softmax_workspace != nullptr ,
487
- " softmax_workspace must be provided for num_experts that are not a power of 2." );
500
+ " softmax_workspace must be provided for num_experts that are not a power of 2 or multiple of 64 ." );
488
501
static constexpr int TPB = 256 ;
489
502
moeSoftmax<TPB><<<num_tokens, TPB, 0 , stream>>> (
490
503
gating_output, nullptr , softmax_workspace, num_experts);
0 commit comments