
Commit 961a896

shixianc authored and amd-xiaoyu12 committed

[Kernel] CUTLASS MoE FP8: Integrate cuda moe permute/unpermute (vllm-project#23045)

Signed-off-by: Shixian Cui <[email protected]>
Signed-off-by: Xiao Yu <[email protected]>

1 parent 2f689df commit 961a896

File tree

15 files changed: +369 -121 lines changed

benchmarks/kernels/benchmark_grouped_gemm_cutlass.py

Lines changed: 34 additions & 1 deletion

@@ -80,6 +80,11 @@ def bench_run(
         a, score, topk, renormalize=False
     )

+    ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
+    ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
+    c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
+    c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
+
     def run_triton_moe(
         a: torch.Tensor,
         w1: torch.Tensor,
@@ -111,6 +116,10 @@ def run_cutlass_moe(
         w2: torch.Tensor,
         w1_scale: torch.Tensor,
         w2_scale: torch.Tensor,
+        ab_strides1: torch.Tensor,
+        ab_strides2: torch.Tensor,
+        c_strides1: torch.Tensor,
+        c_strides2: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         per_act_token: bool,
@@ -125,6 +134,10 @@ def run_cutlass_moe(
             topk_ids,
             w1_scale,
             w2_scale,
+            ab_strides1,
+            ab_strides2,
+            c_strides1,
+            c_strides2,
             per_act_token,
             a1_scale=None,
         )
@@ -136,6 +149,10 @@ def run_cutlass_from_graph(
         w2_q: torch.Tensor,
         w1_scale: torch.Tensor,
         w2_scale: torch.Tensor,
+        ab_strides1: torch.Tensor,
+        ab_strides2: torch.Tensor,
+        c_strides1: torch.Tensor,
+        c_strides2: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
     ):
@@ -150,6 +167,10 @@ def run_cutlass_from_graph(
             topk_ids,
             w1_scale,
             w2_scale,
+            ab_strides1,
+            ab_strides2,
+            c_strides1,
+            c_strides2,
             per_act_token,
             a1_scale=None,
         )
@@ -194,6 +215,10 @@ def replay_graph(graph, num_repeats):
             w2_q,
             w1_scale,
             w2_scale,
+            ab_strides1,
+            ab_strides2,
+            c_strides1,
+            c_strides2,
             topk_weights,
             topk_ids,
         )
@@ -231,6 +256,10 @@ def replay_graph(graph, num_repeats):
         "w1_scale": w1_scale,
         "w2_scale": w2_scale,
         "per_act_token": per_act_token,
+        "ab_strides1": ab_strides1,
+        "ab_strides2": ab_strides2,
+        "c_strides1": c_strides1,
+        "c_strides2": c_strides2,
         # cuda graph params
         "cutlass_graph": cutlass_graph,
         "triton_graph": triton_graph,
@@ -289,6 +318,10 @@ def replay_graph(graph, num_repeats):
             w2_q,
             w1_scale,
             w2_scale,
+            ab_strides1,
+            ab_strides2,
+            c_strides1,
+            c_strides2,
             topk_weights,
             topk_ids,
             per_act_token,
@@ -297,7 +330,7 @@ def replay_graph(graph, num_repeats):

     results.append(
         benchmark.Timer(
-            stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)",  # noqa: E501
+            stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, topk_weights, topk_ids, per_act_token, num_runs)",  # noqa: E501
             globals=globals,
             label=label,
             sub_label=sub_label,
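The four stride tensors built above encode, per expert, the leading dimensions of the two grouped GEMMs: with w1 of shape (E, 2n, k) and w2 of shape (E, k, n), the first GEMM reads its A/B operands with stride k and writes a (m_e, 2n) output, and the second reads with stride n and writes a (m_e, k) output. A minimal sketch of that layout logic (the helper name is illustrative, not a vLLM API):

    import torch

    def make_moe_strides(num_experts: int, n: int, k: int, device="cuda"):
        # GEMM1: (m_e, k) @ (k, 2n) -> (m_e, 2n); GEMM2: (m_e, n) @ (n, k) -> (m_e, k)
        ab_strides1 = torch.full((num_experts,), k, device=device, dtype=torch.int64)
        ab_strides2 = torch.full((num_experts,), n, device=device, dtype=torch.int64)
        c_strides1 = torch.full((num_experts,), 2 * n, device=device, dtype=torch.int64)
        c_strides2 = torch.full((num_experts,), k, device=device, dtype=torch.int64)
        return ab_strides1, ab_strides2, c_strides1, c_strides2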

csrc/moe/moe_permute_unpermute_op.cu

Lines changed: 14 additions & 19 deletions

@@ -45,8 +45,6 @@ void moe_permute(
   auto copy_topk_ids = topk_ids.clone();  // copy topk_ids for preprocess
   auto permuted_experts_id = torch::empty_like(topk_ids);
   auto sorted_row_idx = torch::empty_like(inv_permuted_idx);
-  auto align_expert_first_token_offset =
-      torch::zeros_like(expert_first_token_offset);

   CubKeyValueSorter sorter{};
   int64_t* valid_num_ptr = nullptr;
@@ -85,12 +83,14 @@ void moe_permute(
   });

   // get m_indices and update expert_first_token_offset with align block
-  getMIndices(get_ptr<int64_t>(expert_first_token_offset),
-              get_ptr<int64_t>(align_expert_first_token_offset),
-              get_ptr<int>(m_indices), n_local_expert, align_block_size_value,
-              stream);
+  // this is only required for DeepGemm and not required for CUTLASS group gemm
   if (align_block_size.has_value()) {
-    // update align_expert_first_token_offset
+    auto align_expert_first_token_offset =
+        torch::zeros_like(expert_first_token_offset);
+    getMIndices(get_ptr<int64_t>(expert_first_token_offset),
+                get_ptr<int64_t>(align_expert_first_token_offset),
+                get_ptr<int>(m_indices), n_local_expert, align_block_size_value,
+                stream);
     expert_first_token_offset.copy_(align_expert_first_token_offset);
   }
 }
@@ -195,19 +195,14 @@ void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights,
                  torch::Tensor& expert_first_token_offset,
                  torch::Tensor& src_row_id2dst_row_id_map,
                  torch::Tensor& m_indices) {
-  TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0");
+  TORCH_CHECK(false, "moe_permute is not supported on CUDA < 12.0");
 }

-void moe_unpermute(const torch::Tensor& input,
-                   const torch::Tensor& topk_weights, torch::Tensor& topk_ids,
-                   const torch::Tensor& token_expert_indices,
-                   const std::optional<torch::Tensor>& expert_map,
-                   int64_t n_expert, int64_t n_local_expert, int64_t topk,
-                   const std::optional<int64_t>& align_block_size,
-                   torch::Tensor& permuted_input,
-                   torch::Tensor& expert_first_token_offset,
-                   torch::Tensor& src_row_id2dst_row_id_map,
-                   torch::Tensor& m_indices) {
+void moe_unpermute(
+    const torch::Tensor& permuted_hidden_states,
+    const torch::Tensor& topk_weights, const torch::Tensor& inv_permuted_idx,
+    const std::optional<torch::Tensor>& expert_first_token_offset, int64_t topk,
+    torch::Tensor& hidden_states) {
   TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0");
 }

@@ -224,4 +219,4 @@ bool moe_permute_unpermute_supported() {
 TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
   m.impl("moe_permute", &moe_permute);
   m.impl("moe_unpermute", &moe_unpermute);
-}
+}
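For reference, the dataflow the new moe_unpermute signature exposes (permuted activations in, inv_permuted_idx plus topk_weights to gather and reduce) can be stated in plain PyTorch. This is a semantic sketch, not the CUDA implementation:

    import torch

    def ref_permute(hidden_states: torch.Tensor, topk_ids: torch.Tensor):
        # Group the (token, k) copies by expert id; inv_permuted_idx maps each
        # original (token, k) slot to its row in the permuted buffer.
        n_token, topk = topk_ids.shape
        sorted_idx = torch.argsort(topk_ids.reshape(-1), stable=True)
        inv_permuted_idx = torch.empty_like(sorted_idx)
        inv_permuted_idx[sorted_idx] = torch.arange(sorted_idx.numel(),
                                                    device=sorted_idx.device)
        return hidden_states[sorted_idx // topk], inv_permuted_idx.view(n_token, topk)

    def ref_unpermute(permuted: torch.Tensor, topk_weights: torch.Tensor,
                      inv_permuted_idx: torch.Tensor) -> torch.Tensor:
        # Gather each token's top-k expert outputs and reduce with its weights.
        n_token, topk = inv_permuted_idx.shape
        gathered = permuted[inv_permuted_idx.reshape(-1)].view(n_token, topk, -1)
        return (gathered * topk_weights.unsqueeze(-1)).sum(dim=1)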

csrc/ops.h

Lines changed: 5 additions & 0 deletions

@@ -229,6 +229,11 @@ void get_cutlass_moe_mm_data(
     const int64_t num_experts, const int64_t n, const int64_t k,
     const std::optional<torch::Tensor>& blockscale_offsets);

+void get_cutlass_moe_mm_problem_sizes(
+    const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
+    torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
+    const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets);
+
 void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
                                   torch::Tensor& problem_sizes1,
                                   torch::Tensor& problem_sizes2,

csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh

Lines changed: 4 additions & 2 deletions

@@ -10,7 +10,7 @@

 template <typename ElementAB, typename ElementC, typename ElementAccumulator>
 __global__ void get_group_gemm_starts(
-    int32_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets,
+    int64_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets,
     ElementC** out_offsets, ElementAccumulator** a_scales_offsets,
     ElementAccumulator** b_scales_offsets, ElementAB* a_base_as_int,
     ElementAB* b_base_as_int, ElementC* out_base_as_int,
@@ -34,7 +34,7 @@ __global__ void get_group_gemm_starts(
   else if (out_tensors.dtype() == TENSOR_C_TYPE) {                        \
     get_group_gemm_starts<cutlass::float_e4m3_t, C_TYPE, float>           \
         <<<1, num_experts, 0, stream>>>(                                  \
-            static_cast<int32_t*>(expert_offsets.data_ptr()),             \
+            static_cast<int64_t*>(expert_offsets.data_ptr()),             \
             static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()),      \
             static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()),      \
             static_cast<C_TYPE**>(out_ptrs.data_ptr()),                   \
@@ -61,6 +61,8 @@ void run_get_group_gemm_starts(
   TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
+  // expect int64_t to avoid overflow during offset calculations
+  TORCH_CHECK(expert_offsets.dtype() == torch::kInt64);

   int num_experts = static_cast<int>(expert_offsets.size(0));
   bool per_act_token = a_scales.numel() != 1;
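The int32 to int64 widening matters because expert_offsets is turned into raw element offsets into the activation and output tensors, and with large batches those products overflow 32 bits. Rough arithmetic with illustrative sizes:

    # offset ~ (rows routed before this expert) * hidden_size, in elements
    tokens, topk, hidden = 65_536, 8, 7_168
    max_offset = tokens * topk * hidden   # 3,758,096,384
    assert max_offset > 2**31 - 1         # exceeds INT32_MAX -> needs int64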

csrc/quantization/cutlass_w8a8/moe/moe_data.cu

Lines changed: 50 additions & 15 deletions

@@ -104,6 +104,53 @@ __global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids,
   }
 }

+namespace {
+inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids,
+                                         torch::Tensor& problem_sizes1,
+                                         torch::Tensor& problem_sizes2,
+                                         torch::Tensor& atomic_buffer,
+                                         int64_t num_experts, int64_t n,
+                                         int64_t k, cudaStream_t stream,
+                                         const bool swap_ab) {
+  int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
+
+  const int32_t* topk_ptr = static_cast<const int32_t*>(topk_ids.data_ptr());
+  int32_t* ps1_ptr = static_cast<int32_t*>(problem_sizes1.data_ptr());
+  int32_t* ps2_ptr = static_cast<int32_t*>(problem_sizes2.data_ptr());
+  int32_t* atomic_ptr = static_cast<int32_t*>(atomic_buffer.data_ptr());
+
+  if (swap_ab) {
+    compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>(
+        topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
+        static_cast<int>(topk_ids.numel()), static_cast<int>(n),
+        static_cast<int>(k));
+  } else {
+    compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>(
+        topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
+        static_cast<int>(topk_ids.numel()), static_cast<int>(n),
+        static_cast<int>(k));
+  }
+}
+}  // namespace
+
+void get_cutlass_moe_mm_problem_sizes_caller(
+    const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
+    torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
+    const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets) {
+  auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
+  auto options_int32 =
+      torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
+  torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);
+
+  // Swap-AB should be disabled for FP4 path
+  bool may_swap_ab = (!blockscale_offsets.has_value()) &&
+                     (topk_ids.numel() <= SWAP_AB_THRESHOLD);
+
+  launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2,
+                               atomic_buffer, num_experts, n, k, stream,
+                               may_swap_ab);
+}
+
 void get_cutlass_moe_mm_data_caller(
     const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
     torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
@@ -121,21 +168,9 @@ void get_cutlass_moe_mm_data_caller(
   bool may_swap_ab = (!blockscale_offsets.has_value()) &&
                      (topk_ids.numel() <= SWAP_AB_THRESHOLD);

-  if (may_swap_ab) {
-    compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>(
-        static_cast<const int32_t*>(topk_ids.data_ptr()),
-        static_cast<int32_t*>(problem_sizes1.data_ptr()),
-        static_cast<int32_t*>(problem_sizes2.data_ptr()),
-        static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n,
-        k);
-  } else {
-    compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>(
-        static_cast<const int32_t*>(topk_ids.data_ptr()),
-        static_cast<int32_t*>(problem_sizes1.data_ptr()),
-        static_cast<int32_t*>(problem_sizes2.data_ptr()),
-        static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n,
-        k);
-  }
+  launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2,
+                               atomic_buffer, num_experts, n, k, stream,
+                               may_swap_ab);

   if (blockscale_offsets.has_value()) {
     // fp4 path
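What the new entry point computes can be stated compactly: each expert's M dimension is the number of (token, k) pairs routed to it, and N/K come from the layer shape. A CPU reference under the non-swap-AB layout (a sketch; when swap-AB is taken, the kernel swaps the roles of M and N within each triple):

    import torch

    def ref_problem_sizes(topk_ids: torch.Tensor, num_experts: int, n: int, k: int):
        # m_e = number of routed (token, k) pairs per expert
        m = torch.bincount(topk_ids.reshape(-1).long(), minlength=num_experts)
        full = lambda v: torch.full_like(m, v)
        ps1 = torch.stack([m, full(2 * n), full(k)], dim=1).int()  # GEMM1: (m_e, 2n, k)
        ps2 = torch.stack([m, full(k), full(n)], dim=1).int()      # GEMM2: (m_e, k, n)
        return ps1, ps2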

csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu

Lines changed: 24 additions & 0 deletions

@@ -76,6 +76,11 @@ void get_cutlass_moe_mm_data_caller(
     const int64_t num_experts, const int64_t n, const int64_t k,
     const std::optional<torch::Tensor>& blockscale_offsets);

+void get_cutlass_moe_mm_problem_sizes_caller(
+    const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
+    torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
+    const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets);
+
 void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
                                          torch::Tensor& problem_sizes1,
                                          torch::Tensor& problem_sizes2,
@@ -293,6 +298,25 @@ void get_cutlass_moe_mm_data(
       version_num, ". Required capability: 90 or 100");
 }

+void get_cutlass_moe_mm_problem_sizes(
+    const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
+    torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
+    const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets) {
+  int32_t version_num = get_sm_version_num();
+#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
+    (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100)
+  get_cutlass_moe_mm_problem_sizes_caller(topk_ids, problem_sizes1,
+                                          problem_sizes2, num_experts, n, k,
+                                          blockscale_offsets);
+  return;
+#endif
+  TORCH_CHECK_NOT_IMPLEMENTED(
+      false,
+      "No compiled get_cutlass_moe_mm_problem_sizes: no cutlass_scaled_mm "
+      "kernel for CUDA device capability: ",
+      version_num, ". Required capability: 90 or 100");
+}
+
 void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
                                   torch::Tensor& problem_sizes1,
                                   torch::Tensor& problem_sizes2,
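The new wrapper follows the existing guard pattern: it dispatches only when the extension was compiled with ENABLE_CUTLASS_MOE_SM90/SM100, otherwise it raises with the device's SM version. An illustrative caller-side check mirroring that error message (not a vLLM API):

    import torch

    def cutlass_moe_supported(device: int = 0) -> bool:
        # The entry point requires capability 90 or 100 per its error message.
        major, minor = torch.cuda.get_device_capability(device)
        return 10 * major + minor in (90, 100)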

csrc/torch_bindings.cpp

Lines changed: 13 additions & 0 deletions

@@ -440,6 +440,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       {stride_tag});
   ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data);

+  // A function that computes problem sizes for each expert's multiplication
+  // used by the two mms called from fused MoE operation. It takes topk_ids as
+  // an input, and computes problem_sizes1 and problem_sizes2 only.
+  ops.def(
+      "get_cutlass_moe_mm_problem_sizes(Tensor topk_ids, "
+      "                                 Tensor! problem_sizes1, "
+      "                                 Tensor! problem_sizes2, "
+      "                                 int num_experts, int n, int k, "
+      "                                 Tensor? blockscale_offsets) -> ()",
+      {stride_tag});
+  ops.impl("get_cutlass_moe_mm_problem_sizes", torch::kCUDA,
+           &get_cutlass_moe_mm_problem_sizes);
+
   // A function that computes data required to run fused MoE with w8a8 grouped
   // GEMM and PPLX. It takes expert_num_tokens and non_zero_expert_idxs
   // as an input, and computes expert_offsets (token start indices of each
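Once registered, the op is callable from Python through torch.ops. A hedged usage sketch, assuming the problem-size tensors follow the (num_experts, 3) int32 convention used by get_cutlass_moe_mm_data:

    import torch

    num_experts, n, k = 8, 1024, 2048
    topk_ids = torch.randint(0, num_experts, (256, 2),
                             dtype=torch.int32, device="cuda")
    problem_sizes1 = torch.empty((num_experts, 3), dtype=torch.int32, device="cuda")
    problem_sizes2 = torch.empty((num_experts, 3), dtype=torch.int32, device="cuda")
    torch.ops._C.get_cutlass_moe_mm_problem_sizes(
        topk_ids, problem_sizes1, problem_sizes2, num_experts, n, k, None)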

tests/kernels/moe/test_cutlass_moe.py

Lines changed: 14 additions & 4 deletions

@@ -207,6 +207,10 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
         'topk_ids': topk_ids,
         'w1_scale': moe_tensors.w1_scale,
         'w2_scale': moe_tensors.w2_scale,
+        'ab_strides1': moe_tensors.ab_strides1,
+        'ab_strides2': moe_tensors.ab_strides2,
+        'c_strides1': moe_tensors.c_strides1,
+        'c_strides2': moe_tensors.c_strides2,
         'per_act_token': per_act_token,
         'a1_scale': None  #moe_tensors.a_scale
     }
@@ -424,8 +428,8 @@ def test_run_cutlass_moe_fp8(
     topk_ids[0][1] = 1

     workspace13_shape = (m * topk, max(2 * n, k))
-    workspace2_shape = (m * topk, n)
-    output_shape = (m * topk, k)
+    workspace2_shape = (m * topk, max(n, k))
+    output_shape = (m, k)

     workspace13 = torch.empty(prod(workspace13_shape),
                               device="cuda",
@@ -440,6 +444,11 @@ def test_run_cutlass_moe_fp8(
         expert_map[start:end] = list(range(num_local_experts))
     expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")

+    ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
+    ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
+    c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
+    c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
+
     activation = lambda o, i: torch.ops._C.silu_and_mul(o, i)
     a1q, a1q_scale = moe_kernel_quantize_input(mt.a, mt.a_scale,
                                                torch.float8_e4m3fn,
@@ -448,8 +457,9 @@ def test_run_cutlass_moe_fp8(
     func = lambda output: run_cutlass_moe_fp8(
         output, a1q, mt.w1_q, mt.w2_q, topk_ids, activation,
         global_num_experts, expert_map, mt.w1_scale, mt.w2_scale,
-        a1q_scale, None, workspace13, workspace2, None, mt.a.dtype,
-        per_act_token, per_out_channel, False)
+        a1q_scale, None, ab_strides1, ab_strides2, c_strides1, c_strides2,
+        workspace13, workspace2, None, mt.a.dtype, per_act_token,
+        per_out_channel, False, topk_weights)

     workspace13.random_()
     output_random_workspace = torch.empty(output_shape,
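The reshaped buffers follow from the fused unpermute: the two GEMMs still operate on one row per (token, expert) pair, but the kernel now applies the top-k weighted reduction itself (hence topk_weights in the call), so the final output shrinks from (m * topk, k) to (m, k). Illustrative sizing with hypothetical dimensions:

    m, n, k, topk = 224, 1024, 2048, 2
    workspace13_shape = (m * topk, max(2 * n, k))  # GEMM1 output / GEMM2 staging
    workspace2_shape = (m * topk, max(n, k))       # activation / unpermute staging
    output_shape = (m, k)                          # reduced over top-k per token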

tests/kernels/moe/test_moe_permute_unpermute.py

Lines changed: 5 additions & 1 deletion

@@ -238,7 +238,11 @@ def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int,
                                atol=0,
                                rtol=0)
     # check mindice
-    torch.testing.assert_close(gold_m_indices, m_indices, atol=0, rtol=0)
+    # current kernel usage assumes deepgemm requires align_block_size
+    # when it's not provided then we don't compute m_indices (for cutlass)
+    if align_block_size is not None:
+        torch.testing.assert_close(gold_m_indices, m_indices, atol=0, rtol=0)
+
     # check permuted_hidden_states, only valid token
     torch.testing.assert_close(gold_permuted_hidden_states[valid_row_idx],
                                permuted_hidden_states[valid_row_idx],
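The relaxed check reflects that m_indices is now produced only on the DeepGEMM path (align_block_size set). A sketch, under that assumption, of what a reference m_indices looks like: each expert's id repeated once per row of its block-aligned token range (hypothetical helper, for illustration only):

    import torch

    def ref_m_indices(expert_first_token_offset: torch.Tensor,
                      align_block_size: int) -> torch.Tensor:
        # per-expert token counts, padded up to a multiple of align_block_size
        counts = expert_first_token_offset[1:] - expert_first_token_offset[:-1]
        padded = (counts + align_block_size - 1) // align_block_size * align_block_size
        experts = torch.arange(counts.numel(), dtype=torch.int32,
                               device=counts.device)
        return torch.repeat_interleave(experts, padded)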
