From 027b0d8a76eb8a34251a821e9ac5128d99d7f3e4 Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Tue, 6 Jan 2026 20:09:12 +0000 Subject: [PATCH 1/6] optimize cutlass moe Signed-off-by: yewentao256 --- csrc/ops.h | 5 + .../quantization/w8a8/cutlass/moe/moe_data.cu | 93 +++++++++++++++++++ .../w8a8/cutlass/scaled_mm_entry.cu | 24 +++++ csrc/torch_bindings.cpp | 11 +++ vllm/_custom_ops.py | 19 ++++ .../layers/fused_moe/cutlass_moe.py | 59 +++--------- 6 files changed, 166 insertions(+), 45 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index 37e3aaf7499d..b07ed1cfe6f4 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -265,6 +265,11 @@ void get_cutlass_moe_mm_problem_sizes( const int64_t k, const std::optional& blockscale_offsets, std::optional force_swap_ab = std::nullopt); +void get_cutlass_moe_mm_problem_sizes_from_expert_offsets( + const torch::Tensor& expert_first_token_offset, + torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, + const int64_t n, const int64_t k, const bool swap_ab); + void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, diff --git a/csrc/quantization/w8a8/cutlass/moe/moe_data.cu b/csrc/quantization/w8a8/cutlass/moe/moe_data.cu index 99fec8fd6feb..f58968249991 100644 --- a/csrc/quantization/w8a8/cutlass/moe/moe_data.cu +++ b/csrc/quantization/w8a8/cutlass/moe/moe_data.cu @@ -153,6 +153,99 @@ void get_cutlass_moe_mm_problem_sizes_caller( may_swap_ab); } +template +__global__ void compute_problem_sizes_from_expert_offsets( + const int64_t* __restrict__ expert_first_token_offset, + int32_t* __restrict__ problem_sizes1, int32_t* __restrict__ problem_sizes2, + const int num_experts, const int n, const int k) { + int const expert_id = blockIdx.x * blockDim.x + threadIdx.x; + if (expert_id >= num_experts) { + return; + } + + int64_t const m64 = expert_first_token_offset[expert_id + 1] - + expert_first_token_offset[expert_id]; + int32_t const m = static_cast(m64); + + int32_t* ps1 = problem_sizes1 + expert_id * 3; + int32_t* ps2 = problem_sizes2 + expert_id * 3; + + if constexpr (!SWAP_AB) { + // [M, 2*N, K] + ps1[0] = m; + ps1[1] = 2 * n; + ps1[2] = k; + // [M, K, N] + ps2[0] = m; + ps2[1] = k; + ps2[2] = n; + } else { + // swap logical M/N in the problem shape + // [2*N, M, K] + ps1[0] = 2 * n; + ps1[1] = m; + ps1[2] = k; + // [K, M, N] + ps2[0] = k; + ps2[1] = m; + ps2[2] = n; + } +} + +void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller( + const torch::Tensor& expert_first_token_offset, + torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, + const int64_t n, const int64_t k, const bool swap_ab) { + TORCH_CHECK(expert_first_token_offset.is_cuda(), + "expert_first_token_offset must be a CUDA tensor"); + TORCH_CHECK(expert_first_token_offset.dtype() == torch::kInt64, + "expert_first_token_offset must be int64"); + + TORCH_CHECK(problem_sizes1.is_cuda() && problem_sizes2.is_cuda(), + "problem_sizes must be CUDA tensors"); + TORCH_CHECK(problem_sizes1.dtype() == torch::kInt32 && + problem_sizes2.dtype() == torch::kInt32, + "problem_sizes must be int32"); + TORCH_CHECK(problem_sizes1.is_contiguous() && problem_sizes2.is_contiguous(), + "problem_sizes must be contiguous"); + TORCH_CHECK(problem_sizes1.dim() == 2 && problem_sizes2.dim() == 2, + "problem_sizes must be 2D tensors"); + TORCH_CHECK(problem_sizes1.size(1) == 3 && problem_sizes2.size(1) == 3, + "problem_sizes second dim must be 3"); + TORCH_CHECK(problem_sizes1.sizes() == problem_sizes2.sizes(), + "problem_sizes1 and problem_sizes2 must have same shape"); + + int64_t const num_experts64 = problem_sizes1.size(0); + TORCH_CHECK(expert_first_token_offset.numel() == num_experts64 + 1, + "expert_first_token_offset must have num_experts + 1 elements"); + TORCH_CHECK(num_experts64 <= INT32_MAX, "num_experts must fit in int32"); + TORCH_CHECK(n <= INT32_MAX && k <= INT32_MAX, "n and k must fit in int32"); + + int const num_experts = static_cast(num_experts64); + auto stream = at::cuda::getCurrentCUDAStream( + expert_first_token_offset.device().index()); + + int const threads = 256; + int const blocks = (num_experts + threads - 1) / threads; + + auto const* offsets_ptr = + static_cast(expert_first_token_offset.data_ptr()); + auto* ps1_ptr = static_cast(problem_sizes1.data_ptr()); + auto* ps2_ptr = static_cast(problem_sizes2.data_ptr()); + + if (swap_ab) { + compute_problem_sizes_from_expert_offsets + <<>>(offsets_ptr, ps1_ptr, ps2_ptr, + num_experts, static_cast(n), + static_cast(k)); + } else { + compute_problem_sizes_from_expert_offsets + <<>>(offsets_ptr, ps1_ptr, ps2_ptr, + num_experts, static_cast(n), + static_cast(k)); + } +} + void get_cutlass_moe_mm_data_caller( const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, diff --git a/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu b/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu index 5de21cfbbaaf..077966a1d92a 100644 --- a/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu +++ b/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu @@ -83,6 +83,11 @@ void get_cutlass_moe_mm_problem_sizes_caller( const int64_t k, const std::optional& blockscale_offsets, std::optional force_swap_ab = std::nullopt); +void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller( + const torch::Tensor& expert_first_token_offset, + torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, + const int64_t n, const int64_t k, const bool swap_ab); + void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, @@ -322,6 +327,25 @@ void get_cutlass_moe_mm_problem_sizes( version_num, ". Required capability: 90, 100, or 120"); } +void get_cutlass_moe_mm_problem_sizes_from_expert_offsets( + const torch::Tensor& expert_first_token_offset, + torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, + const int64_t n, const int64_t k, const bool swap_ab) { + int32_t version_num = get_sm_version_num(); +#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ + (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \ + (defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120) + get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller( + expert_first_token_offset, problem_sizes1, problem_sizes2, n, k, swap_ab); + return; +#endif + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "No compiled get_cutlass_moe_mm_problem_sizes_from_expert_offsets: " + "no cutlass_scaled_mm kernel for CUDA device capability: ", + version_num, ". Required capability: 90, 100, or 120"); +} + void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 6f2c8e915b5c..c8cc9204fff6 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -487,6 +487,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("get_cutlass_moe_mm_problem_sizes", torch::kCUDA, &get_cutlass_moe_mm_problem_sizes); + // compute per-expert problem sizes from expert_first_token_offset + // produced by vLLM's moe_permute kernel + ops.def( + "get_cutlass_moe_mm_problem_sizes_from_expert_offsets(" + " Tensor expert_first_token_offset, " + " Tensor! problem_sizes1, " + " Tensor! problem_sizes2, " + " int n, int k, bool swap_ab) -> ()"); + ops.impl("get_cutlass_moe_mm_problem_sizes_from_expert_offsets", torch::kCUDA, + &get_cutlass_moe_mm_problem_sizes_from_expert_offsets); + // A function that computes data required to run fused MoE with w8a8 grouped // GEMM and PPLX. It takes expert_num_tokens and non_zero_expert_idxs // as an input, and computes expert_offsets (token start indices of each diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 0d6d545fed51..5e6e5bcdb570 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1075,6 +1075,25 @@ def get_cutlass_moe_mm_problem_sizes( ) +def get_cutlass_moe_mm_problem_sizes_from_expert_offsets( + expert_first_token_offset: torch.Tensor, + problem_sizes1: torch.Tensor, + problem_sizes2: torch.Tensor, + n: int, + k: int, + swap_ab: bool, +): + """Compute per-expert (M, N, K) problem sizes from expert_first_token_offset""" + return torch.ops._C.get_cutlass_moe_mm_problem_sizes_from_expert_offsets( + expert_first_token_offset, + problem_sizes1, + problem_sizes2, + n, + k, + swap_ab, + ) + + def shuffle_rows(input_tensor: torch.Tensor, dst2src_map: torch.Tensor): """ Shuffle and expand the input tensor according to the dst2src_map and store the result in output_tensor. diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index c585cbc1ab5d..dbf2c08c0282 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -108,15 +108,7 @@ def run_cutlass_moe_fp8( assert global_num_experts != -1 assert a1q_scale is not None - if expert_map is not None: - "Translate info from expert_map to topk_ids" - local_topk_ids = torch.where( - expert_map[topk_ids] != -1, expert_map[topk_ids], -1 - ) - else: - local_topk_ids = topk_ids - - topk = local_topk_ids.size(1) + topk = topk_ids.size(1) local_E = w1.size(0) if use_batched_format: @@ -164,12 +156,8 @@ def run_cutlass_moe_fp8( # during offset calculations expert_offsets = expert_offsets.to(torch.int64) else: - problem_sizes1 = torch.empty( - (global_num_experts, 3), dtype=torch.int32, device=device - ) - problem_sizes2 = torch.empty( - (global_num_experts, 3), dtype=torch.int32, device=device - ) + problem_sizes1 = torch.empty((local_E, 3), dtype=torch.int32, device=device) + problem_sizes2 = torch.empty((local_E, 3), dtype=torch.int32, device=device) num_expert = global_num_experts if expert_map is None else expert_map.size(0) # permuted a1q reuses workspace2 @@ -182,11 +170,11 @@ def run_cutlass_moe_fp8( expert_map, permuted_hidden_states=a1q_perm, ) - expert_offsets = expert_first_token_offset[:-1] - - ops.get_cutlass_moe_mm_problem_sizes( - local_topk_ids, problem_sizes1, problem_sizes2, global_num_experts, N, K + swap_ab = a1q.size(0) <= 64 + ops.get_cutlass_moe_mm_problem_sizes_from_expert_offsets( + expert_first_token_offset, problem_sizes1, problem_sizes2, N, K, swap_ab ) + expert_offsets = expert_first_token_offset[:-1] if not per_act_token and (expert_map is not None or use_batched_format): # this is necessary to avoid imprecise scale calculation caused by @@ -961,15 +949,7 @@ def run_cutlass_moe_w4a8_fp8( f"w1 hidden size mismatch: got {w1.size(2) * 8}, expected {K=}" ) - # Translate info from expert_map to topk_ids - if expert_map is not None: - local_topk_ids = torch.where( - expert_map[topk_ids] != -1, expert_map[topk_ids], -1 - ) - else: - local_topk_ids = topk_ids - - topk = local_topk_ids.size(1) + topk = topk_ids.size(1) a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn), (M * topk, K)) mm1_out = _resize_cache(workspace13, (M * topk, N * 2)) act_out = _resize_cache(workspace2, (M * topk, N)) @@ -979,12 +959,8 @@ def run_cutlass_moe_w4a8_fp8( ) mm2_out = _resize_cache(workspace2, (M * topk, K)) - problem_sizes1 = torch.empty( - (global_num_experts, 3), dtype=torch.int32, device=device - ) - problem_sizes2 = torch.empty( - (global_num_experts, 3), dtype=torch.int32, device=device - ) + problem_sizes1 = torch.empty((local_E, 3), dtype=torch.int32, device=device) + problem_sizes2 = torch.empty((local_E, 3), dtype=torch.int32, device=device) num_expert = global_num_experts if expert_map is None else expert_map.size(0) # permuted a1q reuses workspace2 @@ -997,18 +973,11 @@ def run_cutlass_moe_w4a8_fp8( expert_map, permuted_hidden_states=a1q_perm, ) - expert_offsets = expert_first_token_offset[:-1] - - # For RS gemm SwapAB is always enabled (swap logical M, N in the problem shape) - ops.get_cutlass_moe_mm_problem_sizes( - local_topk_ids, - problem_sizes1, - problem_sizes2, - global_num_experts, - N, - K, - force_swap_ab=True, + # for RS gemm SwapAB is always enabled (swap logical M, N in the problem shape). + ops.get_cutlass_moe_mm_problem_sizes_from_expert_offsets( + expert_first_token_offset, problem_sizes1, problem_sizes2, N, K, True ) + expert_offsets = expert_first_token_offset[:-1] ops.cutlass_w4a8_moe_mm( mm1_out, From 32990585ed794d412b72515e68c663f43ea0159c Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Tue, 6 Jan 2026 20:14:31 +0000 Subject: [PATCH 2/6] remove if else Signed-off-by: yewentao256 --- vllm/model_executor/layers/fused_moe/cutlass_moe.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index dbf2c08c0282..66b80fbdfb7d 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -228,9 +228,7 @@ def run_cutlass_moe_fp8( permuted_hidden_states=mm2_out, topk_weights=topk_weights, inv_permuted_idx=inv_perm, - expert_first_token_offset=( - expert_first_token_offset if expert_map is not None else None - ), + expert_first_token_offset=expert_first_token_offset, ) @@ -1024,9 +1022,7 @@ def run_cutlass_moe_w4a8_fp8( permuted_hidden_states=mm2_out, topk_weights=topk_weights, inv_permuted_idx=inv_perm, - expert_first_token_offset=( - expert_first_token_offset if expert_map is not None else None - ), + expert_first_token_offset=expert_first_token_offset, ) From 905f611d99048c8aecf831b609c768ac2d0326a6 Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Wed, 7 Jan 2026 17:59:33 +0000 Subject: [PATCH 3/6] dispatch bool Signed-off-by: yewentao256 --- .../quantization/w8a8/cutlass/moe/moe_data.cu | 24 +++++++------------ 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/csrc/quantization/w8a8/cutlass/moe/moe_data.cu b/csrc/quantization/w8a8/cutlass/moe/moe_data.cu index f58968249991..2a9ef0672824 100644 --- a/csrc/quantization/w8a8/cutlass/moe/moe_data.cu +++ b/csrc/quantization/w8a8/cutlass/moe/moe_data.cu @@ -3,6 +3,8 @@ #include #include +#include "dispatch_utils.h" + #include constexpr uint64_t THREADS_PER_EXPERT = 512; @@ -119,17 +121,12 @@ inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids, int32_t* ps2_ptr = static_cast(problem_sizes2.data_ptr()); int32_t* atomic_ptr = static_cast(atomic_buffer.data_ptr()); - if (swap_ab) { - compute_problem_sizes<<>>( - topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr, - static_cast(topk_ids.numel()), static_cast(n), - static_cast(k)); - } else { - compute_problem_sizes<<>>( + VLLM_DISPATCH_BOOL(swap_ab, SwapAB, [&] { + compute_problem_sizes<<>>( topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr, static_cast(topk_ids.numel()), static_cast(n), static_cast(k)); - } + }); } } // namespace @@ -233,17 +230,12 @@ void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller( auto* ps1_ptr = static_cast(problem_sizes1.data_ptr()); auto* ps2_ptr = static_cast(problem_sizes2.data_ptr()); - if (swap_ab) { - compute_problem_sizes_from_expert_offsets + VLLM_DISPATCH_BOOL(swap_ab, SwapAB, [&] { + compute_problem_sizes_from_expert_offsets <<>>(offsets_ptr, ps1_ptr, ps2_ptr, num_experts, static_cast(n), static_cast(k)); - } else { - compute_problem_sizes_from_expert_offsets - <<>>(offsets_ptr, ps1_ptr, ps2_ptr, - num_experts, static_cast(n), - static_cast(k)); - } + }); } void get_cutlass_moe_mm_data_caller( From 45079c4b0951aeb863bb6d4d71878758ce5974bd Mon Sep 17 00:00:00 2001 From: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Date: Wed, 7 Jan 2026 13:03:37 -0500 Subject: [PATCH 4/6] Update csrc/quantization/w8a8/cutlass/moe/moe_data.cu MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Luka Govedič Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> --- csrc/quantization/w8a8/cutlass/moe/moe_data.cu | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/csrc/quantization/w8a8/cutlass/moe/moe_data.cu b/csrc/quantization/w8a8/cutlass/moe/moe_data.cu index 2a9ef0672824..fb6bfb07459f 100644 --- a/csrc/quantization/w8a8/cutlass/moe/moe_data.cu +++ b/csrc/quantization/w8a8/cutlass/moe/moe_data.cu @@ -225,10 +225,9 @@ void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller( int const threads = 256; int const blocks = (num_experts + threads - 1) / threads; - auto const* offsets_ptr = - static_cast(expert_first_token_offset.data_ptr()); - auto* ps1_ptr = static_cast(problem_sizes1.data_ptr()); - auto* ps2_ptr = static_cast(problem_sizes2.data_ptr()); + auto const* offsets_ptr = expert_first_token_offset.data_ptr(); + auto* ps1_ptr = problem_sizes1.data_ptr(); + auto* ps2_ptr = problem_sizes2.data_ptr(); VLLM_DISPATCH_BOOL(swap_ab, SwapAB, [&] { compute_problem_sizes_from_expert_offsets From 45dfa3b83f9ec5347ece69c27b3771913450c125 Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Wed, 7 Jan 2026 18:04:07 +0000 Subject: [PATCH 5/6] using data ptr Signed-off-by: yewentao256 --- csrc/quantization/w8a8/cutlass/moe/moe_data.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/csrc/quantization/w8a8/cutlass/moe/moe_data.cu b/csrc/quantization/w8a8/cutlass/moe/moe_data.cu index fb6bfb07459f..d62d18f04ccb 100644 --- a/csrc/quantization/w8a8/cutlass/moe/moe_data.cu +++ b/csrc/quantization/w8a8/cutlass/moe/moe_data.cu @@ -116,10 +116,10 @@ inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids, const bool swap_ab) { int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel()); - const int32_t* topk_ptr = static_cast(topk_ids.data_ptr()); - int32_t* ps1_ptr = static_cast(problem_sizes1.data_ptr()); - int32_t* ps2_ptr = static_cast(problem_sizes2.data_ptr()); - int32_t* atomic_ptr = static_cast(atomic_buffer.data_ptr()); + auto const* topk_ptr = topk_ids.data_ptr(); + auto* ps1_ptr = problem_sizes1.data_ptr(); + auto* ps2_ptr = problem_sizes2.data_ptr(); + auto* atomic_ptr = atomic_buffer.data_ptr(); VLLM_DISPATCH_BOOL(swap_ab, SwapAB, [&] { compute_problem_sizes<<>>( From 19024818a346f34faae2a97d1294ae740532ec61 Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Fri, 9 Jan 2026 16:16:01 +0000 Subject: [PATCH 6/6] address comments Signed-off-by: yewentao256 --- csrc/quantization/w8a8/cutlass/moe/moe_data.cu | 2 +- vllm/model_executor/layers/fused_moe/cutlass_moe.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/csrc/quantization/w8a8/cutlass/moe/moe_data.cu b/csrc/quantization/w8a8/cutlass/moe/moe_data.cu index d62d18f04ccb..28af2e7d4d80 100644 --- a/csrc/quantization/w8a8/cutlass/moe/moe_data.cu +++ b/csrc/quantization/w8a8/cutlass/moe/moe_data.cu @@ -222,7 +222,7 @@ void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller( auto stream = at::cuda::getCurrentCUDAStream( expert_first_token_offset.device().index()); - int const threads = 256; + int const threads = (num_experts < 256) ? num_experts : 256; int const blocks = (num_experts + threads - 1) / threads; auto const* offsets_ptr = expert_first_token_offset.data_ptr(); diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 628d0f4cfd00..fdac768da8f9 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -170,6 +170,7 @@ def run_cutlass_moe_fp8( expert_map, permuted_hidden_states=a1q_perm, ) + # swap_ab is a CUTLASS grouped-GEMM optimization (M <= 64 reduces padding). swap_ab = a1q.size(0) <= 64 ops.get_cutlass_moe_mm_problem_sizes_from_expert_offsets( expert_first_token_offset, problem_sizes1, problem_sizes2, N, K, swap_ab