Skip to content

Commit 4c558cf

Browse files
authored
[Perf] Support topk softmax fused kernel for broader num_experts (#22211)
Signed-off-by: Shixian Cui <[email protected]> Co-authored-by: Shixian Cui <[email protected]>
1 parent 77a6bf0 commit 4c558cf

File tree

2 files changed

+46
-33
lines changed

2 files changed

+46
-33
lines changed

csrc/moe/topk_softmax_kernels.cu

Lines changed: 45 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,9 @@ __launch_bounds__(TPB) __global__ void moeTopK(
188188
It fuses the softmax, max and argmax into a single kernel.
189189
190190
Limitations:
191-
1) This implementation is intended for when the number of experts is a small power of 2.
191+
1) This implementation is optimized for when the number of experts is a small power of 2.
192+
Additionally it also supports when number of experts is multiple of 64 which is still
193+
faster than computing the softmax and topK separately (only tested on CUDA so far).
192194
2) This implementation assumes k is small, but will work for any k.
193195
*/
194196

@@ -198,8 +200,6 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
198200
int* source_rows, const int k, const int start_expert, const int end_expert)
199201
{
200202
// We begin by enforcing compile time assertions and setting up compile time constants.
201-
static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
202-
static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2");
203203
static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2");
204204
static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");
205205

@@ -407,12 +407,10 @@ struct TopkConstants
407407
};
408408
} // namespace detail
409409

410-
template <int EXPERTS, int WARPS_PER_TB, int WARP_SIZE_PARAM, typename IndType>
410+
template <int EXPERTS, int WARPS_PER_TB, int WARP_SIZE_PARAM, int MAX_BYTES_PER_LDG, typename IndType>
411411
void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices,
412412
int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
413413
{
414-
static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
415-
416414
static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
417415
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG, WARP_SIZE_PARAM>;
418416
static constexpr int VPT = Constants::VPT;
@@ -425,21 +423,12 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
425423
input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
426424
}
427425

428-
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB) \
429-
switch (warpSize) { \
430-
case 32: \
431-
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 32>( \
432-
gating_output, nullptr, topk_weights, topk_indices, \
433-
token_expert_indices, num_tokens, topk, 0, num_experts, stream); \
434-
break; \
435-
case 64: \
436-
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 64>( \
437-
gating_output, nullptr, topk_weights, topk_indices, \
438-
token_expert_indices, num_tokens, topk, 0, num_experts, stream); \
439-
break; \
440-
default: \
441-
TORCH_CHECK(false, "Unsupported warp size: ", warpSize); \
442-
}
426+
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \
427+
static_assert(WARP_SIZE == 32 || WARP_SIZE == 64, \
428+
"Unsupported warp size. Only 32 and 64 are supported."); \
429+
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, WARP_SIZE, MAX_BYTES>( \
430+
gating_output, nullptr, topk_weights, topk_indices, \
431+
token_expert_indices, num_tokens, topk, 0, num_experts, stream);
443432

444433
template <typename IndType>
445434
void topkGatingSoftmaxKernelLauncher(
@@ -453,38 +442,62 @@ void topkGatingSoftmaxKernelLauncher(
453442
const int topk,
454443
cudaStream_t stream) {
455444
static constexpr int WARPS_PER_TB = 4;
456-
auto warpSize = WARP_SIZE;
445+
static constexpr int BYTES_PER_LDG_POWER_OF_2 = 16;
446+
static constexpr int BYTES_PER_LDG_MULTIPLE_64 = 8;
457447
switch (num_experts) {
458448
case 1:
459-
LAUNCH_SOFTMAX(1, WARPS_PER_TB);
449+
LAUNCH_SOFTMAX(1, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
460450
break;
461451
case 2:
462-
LAUNCH_SOFTMAX(2, WARPS_PER_TB);
452+
LAUNCH_SOFTMAX(2, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
463453
break;
464454
case 4:
465-
LAUNCH_SOFTMAX(4, WARPS_PER_TB);
455+
LAUNCH_SOFTMAX(4, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
466456
break;
467457
case 8:
468-
LAUNCH_SOFTMAX(8, WARPS_PER_TB);
458+
LAUNCH_SOFTMAX(8, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
469459
break;
470460
case 16:
471-
LAUNCH_SOFTMAX(16, WARPS_PER_TB);
461+
LAUNCH_SOFTMAX(16, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
472462
break;
473463
case 32:
474-
LAUNCH_SOFTMAX(32, WARPS_PER_TB);
464+
LAUNCH_SOFTMAX(32, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
475465
break;
476466
case 64:
477-
LAUNCH_SOFTMAX(64, WARPS_PER_TB);
467+
LAUNCH_SOFTMAX(64, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
478468
break;
479469
case 128:
480-
LAUNCH_SOFTMAX(128, WARPS_PER_TB);
470+
LAUNCH_SOFTMAX(128, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
481471
break;
482472
case 256:
483-
LAUNCH_SOFTMAX(256, WARPS_PER_TB);
473+
LAUNCH_SOFTMAX(256, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
474+
break;
475+
case 512:
476+
LAUNCH_SOFTMAX(512, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
477+
break;
478+
// (CUDA only) support multiples of 64 when num_experts is not a power of 2.
479+
// ROCm uses WARP_SIZE 64, so 8-byte loads won't fit for some values of num_experts;
480+
// alternatively we can test 4-byte loads and enable them in the future.
481+
#ifndef USE_ROCM
482+
case 192:
483+
LAUNCH_SOFTMAX(192, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
484484
break;
485+
case 320:
486+
LAUNCH_SOFTMAX(320, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
487+
break;
488+
case 384:
489+
LAUNCH_SOFTMAX(384, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
490+
break;
491+
case 448:
492+
LAUNCH_SOFTMAX(448, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
493+
break;
494+
case 576:
495+
LAUNCH_SOFTMAX(576, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
496+
break;
497+
#endif
485498
default: {
486499
TORCH_CHECK(softmax_workspace != nullptr,
487-
"softmax_workspace must be provided for num_experts that are not a power of 2.");
500+
"softmax_workspace must be provided for num_experts that are not a power of 2 or multiple of 64.");
488501
static constexpr int TPB = 256;
489502
moeSoftmax<TPB><<<num_tokens, TPB, 0, stream>>>(
490503
gating_output, nullptr, softmax_workspace, num_experts);

tests/kernels/moe/test_moe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from vllm.platforms import current_platform
3737
from vllm.scalar_type import ScalarType, scalar_types
3838

39-
NUM_EXPERTS = [8, 64]
39+
NUM_EXPERTS = [8, 64, 192]
4040
EP_SIZE = [1, 4]
4141
TOP_KS = [2, 6]
4242

0 commit comments

Comments
 (0)