
Commit 2a30f38

kernel: add permutation with align block
1 parent a0e9088 commit 2a30f38

File tree

4 files changed: +296 −177 lines changed

src/kernels/dispatch.h

Lines changed: 6 additions & 0 deletions

@@ -9,6 +9,12 @@ namespace llm::kernel {
 #define DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

+#define DISPATCH_CASE_INTEGRAL_TYPES(...) \
+  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)
+
+#define DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
+
 // NOLINTEND(cppcoreguidelines-macro-usage)

 }  // namespace llm::kernel
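The new macros mirror the floating-point dispatcher above: DISPATCH_INTEGRAL_TYPES switches on a tensor's scalar type (only at::ScalarType::Int is covered so far) and binds scalar_t to the matching C++ type inside the lambda. A minimal usage sketch follows; the first_topk_id helper and the include path are illustrative, not part of the commit:

#include <torch/torch.h>

#include "src/kernels/dispatch.h"  // include path assumed; adjust to the build setup

// Illustrative helper: read the first entry of an int32 tensor on the host.
// DISPATCH_CASE_INTEGRAL_TYPES only covers at::ScalarType::Int, so any other
// dtype falls through to AT_DISPATCH_SWITCH's unsupported-type error.
int64_t first_topk_id(const torch::Tensor& topk_ids) {
  int64_t value = 0;
  DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "first_topk_id", [&] {
    // scalar_t is the concrete element type selected by the dispatch (int32_t here).
    value = static_cast<int64_t>(topk_ids.data_ptr<scalar_t>()[0]);
  });
  return value;
}

permute_align_block in the new kernel file below uses the same macro to dispatch on topk_ids.scalar_type() before launching the kernels.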

src/kernels/moe/CMakeLists.txt

Lines changed: 3 additions & 2 deletions

@@ -2,13 +2,14 @@ include(cc_library)
 include(cc_test)

 cc_library(
-  NAME
+  NAME
     moe.kernels
-  SRCS
+  SRCS
     topk_softmax_kernel.cu
     grouped_topk_sigmoid_kernel.cu
     permutation_index_kernel.cu
     permutation_mask_kernel.cu
+    permutation_align_block_kernel.cu
   DEPS
     cutlass
     glog::glog
src/kernels/moe/permutation_align_block_kernel.cu

Lines changed: 287 additions & 0 deletions

@@ -0,0 +1,287 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/torch.h>

#include <cub/cub.cuh>
#include <cute/config.hpp>
#include <cute/numeric/numeric_types.hpp>
#include <cute/tensor.hpp>

#include "../dispatch.h"
#include "cute/int_tuple.hpp"

// Adapted from
// https://github.com/sgl-project/sglang/blob/main/sgl-kernel/csrc/moe/moe_align_kernel.cu

// clang-format off
// for example: n_tokens = 2, n_experts = 8, topk = 2
// f_idx: idx in flatten indices
// p_idx: idx in permuted tokens
// k_idx: topk idx
// t_idx: token idx
// row_id_map: [topk, n_tokens] => idx in permuted tokens
// ______________________________________________________________________________________
// |                 |       flatten indices        |           sort indices              |
// |      Steps      |    sort by (tokens, topk)    |        by (experts, tokens)         |
// |_________________|______________________________|_____________________________________|
// |                 |      [n_tokens * topk]       | [n_tokens * topk] => f_idx          |
// |       Dim       |                              | f_idx: idx in flatten indices       |
// |_________________|______________________________|_____________________________________|
// |                 |                              |                                     |
// |   top0, top1    | f_idx:   | 0 | 1 | 2 | 3 |   | p_idx:   | 0  | 1  | 2  | 3  |      |
// | t0 -> [e2, e1]  | experts: | 2 | 1 | 2 | 5 |   | f_idx:   | 1  | 0  | 2  | 3  |      |
// | t1 -> [e2, e5]  | tokens:  |   t0  |   t1  |   | tokens:  | t0 | t0 | t1 | t1 |      |
// |                 |                              | experts: | e1 |    e2   | e5 |      |
// |                 |                              |                                     |
// |                 |                              |                                     |
// |_________________|______________________________|_____________________________________|
// clang-format on
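// Continuing the example with block_size = 4, each expert's segment is padded
// up to a multiple of block_size before the scatter:
//   per-expert counts:            e1 -> 1, e2 -> 2, e5 -> 1 (others 0)
//   padded cumsum [n_experts+1]:  [0, 0, 4, 8, 8, 8, 12, 12, 12]
//   total_tokens_post_pad = 12    => 3 blocks of block_size tokens
//   expert_ids (one per block):   [1, 2, 5]
//   sorted_token_ids (f_idx, * = padding slot):
//                                 [1, *, *, *, 0, 2, *, *, 3, *, *, *]
// The order of entries inside one expert's segment depends on which thread
// claims a slot first, so expert 2's segment may also appear as [2, 0].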

namespace llm::kernel::moe {

namespace {
template <typename T>
inline T* data_ptr(torch::Tensor& t) {
  return reinterpret_cast<T*>(t.data_ptr());
}

template <typename T>
inline const T* const_data_ptr(torch::Tensor& t) {
  return reinterpret_cast<const T*>(t.const_data_ptr());
}

// Scatter each flattened (token, topk) index into its expert's padded
// segment; cumsum_buffer[e] holds the running write offset for expert e.
template <typename scalar_t>
__global__ void count_and_sort_expert_tokens_kernel(
    const scalar_t* __restrict__ topk_ids,
    int32_t* __restrict__ sorted_token_ids,
    int32_t* __restrict__ cumsum_buffer,
    size_t numel) {
  const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;

  for (size_t i = tid; i < numel; i += stride) {
    int32_t expert_id = topk_ids[i];
    int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1);
    sorted_token_ids[rank_post_pad] = i;
  }
}

template <typename scalar_t>
__global__ void moe_align_block_size_kernel(
    const scalar_t* __restrict__ topk_ids,        // [n_tokens, topk]
    int32_t* __restrict__ sorted_token_ids,       // [n_permuted_tokens+]
    int32_t* __restrict__ expert_ids,             // [n_blocks+]
    int32_t* __restrict__ total_tokens_post_pad,  // [1]
    int32_t num_experts,
    int32_t padded_num_experts,
    int32_t experts_per_warp,
    int32_t block_size,
    size_t numel,
    int32_t* __restrict__ cumsum  // [n_experts+1]
) {
  constexpr int32_t WARP_SIZE = 32;
  // [padded_num_experts] token counts, one slot per expert
  extern __shared__ int32_t shared_counts[];

  const int warp_id = threadIdx.x / WARP_SIZE;
  const int my_expert_start = warp_id * experts_per_warp;

  // init the shared token counts for the experts owned by this warp
  for (int i = 0; i < experts_per_warp; ++i) {
    if (my_expert_start + i < padded_num_experts) {
      shared_counts[warp_id * experts_per_warp + i] = 0;
    }
  }

  __syncthreads();

  const size_t tid = threadIdx.x;
  const size_t stride = blockDim.x;

  // process the token shard
  for (size_t i = tid; i < numel; i += stride) {
    int expert_id = topk_ids[i];
    int warp_idx = expert_id / experts_per_warp;
    int expert_offset = expert_id % experts_per_warp;
    // accumulate token counts for each expert
    atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset], 1);
  }

  __syncthreads();

  if (threadIdx.x == 0) {
    cumsum[0] = 0;
    for (int i = 1; i <= num_experts; ++i) {
      int expert_count = 0;
      int warp_idx = (i - 1) / experts_per_warp;
      int expert_offset = (i - 1) % experts_per_warp;
      expert_count = shared_counts[warp_idx * experts_per_warp + expert_offset];
      // why not just expert_count = shared_counts[i - 1]?

      cumsum[i] = cumsum[i - 1] + cute::round_up(expert_count, block_size);
    }
    *total_tokens_post_pad = cumsum[num_experts];
  }

  __syncthreads();

  // update the expert id for each block
  if (threadIdx.x < num_experts) {
    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
         i += block_size) {
      expert_ids[i / block_size] = threadIdx.x;
    }
  }
}

template <typename scalar_t>
__global__ void small_align_block_kernel(
    const scalar_t* __restrict__ topk_ids,        // [n_tokens, topk]
    int32_t* __restrict__ sorted_token_ids,       // [n_permuted_tokens+]
    int32_t* __restrict__ expert_ids,             // [n_blocks+]
    int32_t* __restrict__ total_tokens_post_pad,  // [1]
    int32_t num_experts,
    int32_t block_size,
    size_t numel) {
  const size_t tid = threadIdx.x;
  const size_t stride = blockDim.x;

  // dynamic shared memory: cumsum followed by per-shard token counts
  extern __shared__ int32_t shared_mem[];
  // [n_experts+1]
  int32_t* cumsum = shared_mem;
  // [n_shards+1][n_experts]
  int32_t* tokens_cnts = (int32_t*)(shared_mem + num_experts + 1);

  // init token counts for each expert in the shard
  for (int i = 0; i < num_experts; ++i) {
    tokens_cnts[(threadIdx.x + 1) * num_experts + i] = 0;
  }

  // count tokens per expert within this thread's shard
  for (size_t i = tid; i < numel; i += stride) {
    // ++tokens_cnts[threadIdx.x+1][topk_ids[i]];
    ++tokens_cnts[(threadIdx.x + 1) * num_experts + topk_ids[i]];
  }

  __syncthreads();

  // calculate the prefix sum of token counts for each expert within the block
  if (threadIdx.x < num_experts) {
    tokens_cnts[threadIdx.x] = 0;
    for (int i = 1; i <= blockDim.x; ++i) {
      tokens_cnts[i * num_experts + threadIdx.x] +=
          tokens_cnts[(i - 1) * num_experts + threadIdx.x];
    }
  }

  __syncthreads();

  // calculate the padded cumulative token counts over experts
  if (threadIdx.x == 0) {
    cumsum[0] = 0;
    for (int i = 1; i <= num_experts; ++i) {
      cumsum[i] = cumsum[i - 1] +
                  cute::round_up(tokens_cnts[blockDim.x * num_experts + i - 1],
                                 block_size);
    }
    *total_tokens_post_pad = static_cast<int32_t>(cumsum[num_experts]);
  }

  __syncthreads();

  // one thread per expert fills expert_ids for that expert's blocks
  if (threadIdx.x < num_experts) {
    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
         i += block_size) {
      expert_ids[i / block_size] = threadIdx.x;
    }
  }

  // each thread scatters its shard of flattened indices into sorted_token_ids
  for (size_t i = tid; i < numel; i += stride) {
    int32_t expert_id = topk_ids[i];
    int32_t rank_post_pad =
        tokens_cnts[threadIdx.x * num_experts + expert_id] + cumsum[expert_id];
    sorted_token_ids[rank_post_pad] = i;
    ++tokens_cnts[threadIdx.x * num_experts + expert_id];
  }
}

}  // namespace

void permute_align_block(torch::Tensor topk_ids,
                         int64_t num_experts,
                         int64_t block_size,
                         torch::Tensor sorted_token_ids,
                         torch::Tensor experts_ids,
                         torch::Tensor num_tokens_post_pad,
                         torch::Tensor cumsum_buffer) {
  constexpr int threads = 1024;
  constexpr int32_t WARP_SIZE = 32;

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "align_block_kernel", [&] {
    bool small_batch_expert_mode =
        (topk_ids.numel() < 1024) && (num_experts <= 64);

    if (small_batch_expert_mode) {
      const int32_t threads = max((int32_t)num_experts, WARP_SIZE);
      const int32_t shared_mem_size =
          ((threads + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);

      auto small_batch_expert_kernel = small_align_block_kernel<scalar_t>;
      small_batch_expert_kernel<<<1, threads, shared_mem_size, stream>>>(
          topk_ids.data_ptr<scalar_t>(),
          sorted_token_ids.data_ptr<int32_t>(),
          experts_ids.data_ptr<int32_t>(),
          num_tokens_post_pad.data_ptr<int32_t>(),
          num_experts,
          block_size,
          topk_ids.numel());
    } else {
      // why is it faster?
      // use more SMs to sort
      auto align_kernel = moe_align_block_size_kernel<scalar_t>;

      int experts_per_warp = WARP_SIZE;
      int64_t padded_num_experts =
          ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
      size_t num_warps = cute::ceil_div(padded_num_experts, experts_per_warp);
      size_t shared_mem_size = num_warps * experts_per_warp * sizeof(int32_t);

      // can be removed.
      // [n_experts+1]
      cumsum_buffer.zero_();
      // threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;

      align_kernel<<<1, threads, shared_mem_size, stream>>>(
          topk_ids.data_ptr<scalar_t>(),
          sorted_token_ids.data_ptr<int32_t>(),
          experts_ids.data_ptr<int32_t>(),
          num_tokens_post_pad.data_ptr<int32_t>(),
          num_experts,
          padded_num_experts,
          experts_per_warp,
          block_size,
          topk_ids.numel(),
          cumsum_buffer.data_ptr<int32_t>());

      // use up to 256 threads to sort
      const int block_threads = std::min(256, (int)threads);
      // partition permuted tokens into blocks
      const int num_blocks =
          (topk_ids.numel() + block_threads - 1) / block_threads;
      const int max_blocks = 65535;
      const int actual_blocks = std::min(num_blocks, max_blocks);

      auto sort_kernel = count_and_sort_expert_tokens_kernel<scalar_t>;
      sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(
          topk_ids.data_ptr<scalar_t>(),
          sorted_token_ids.data_ptr<int32_t>(),
          cumsum_buffer.data_ptr<int32_t>(),
          topk_ids.numel());
    }
  });
}

}  // namespace llm::kernel::moe
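For reference, a host-side sketch of calling permute_align_block. The forward declaration matches the definition above, but the worst-case buffer sizes and the fill value for sorted_token_ids are assumptions derived from the padding scheme (each expert's segment is rounded up to a multiple of block_size, so at most numel + num_experts * (block_size - 1) permuted slots are needed); the commit itself does not show the calling convention.

#include <torch/torch.h>

namespace llm::kernel::moe {
// Signature as defined in permutation_align_block_kernel.cu above.
void permute_align_block(torch::Tensor topk_ids,
                         int64_t num_experts,
                         int64_t block_size,
                         torch::Tensor sorted_token_ids,
                         torch::Tensor experts_ids,
                         torch::Tensor num_tokens_post_pad,
                         torch::Tensor cumsum_buffer);
}  // namespace llm::kernel::moe

void permute_align_block_example() {
  const int64_t n_tokens = 2, topk = 2, num_experts = 8, block_size = 4;
  const auto opts = torch::dtype(torch::kInt32).device(torch::kCUDA);

  // The worked example from the kernel comments: t0 -> [e2, e1], t1 -> [e2, e5].
  auto topk_ids = torch::tensor({2, 1, 2, 5}, opts).view({n_tokens, topk});

  // Worst case: every expert's segment gains up to block_size - 1 padding slots.
  const int64_t numel = n_tokens * topk;
  const int64_t max_padded = numel + num_experts * (block_size - 1);
  const int64_t max_num_blocks = (max_padded + block_size - 1) / block_size;

  // Padding slots are never written by the kernels; prefill with numel as an
  // illustrative "invalid token" marker.
  auto sorted_token_ids = torch::full({max_padded}, numel, opts);
  auto experts_ids = torch::empty({max_num_blocks}, opts);
  auto num_tokens_post_pad = torch::zeros({1}, opts);
  auto cumsum_buffer = torch::zeros({num_experts + 1}, opts);

  llm::kernel::moe::permute_align_block(topk_ids,
                                        num_experts,
                                        block_size,
                                        sorted_token_ids,
                                        experts_ids,
                                        num_tokens_post_pad,
                                        cumsum_buffer);

  // Per the worked example: num_tokens_post_pad == 12 and the first three
  // entries of experts_ids are [1, 2, 5].
}

On the two code paths: the small-batch branch does counting, cumsum, and the scatter in a single thread block, while the general branch splits the padded cumsum (moe_align_block_size_kernel) from the scatter (count_and_sort_expert_tokens_kernel) so the scatter can spread across many SMs, matching the "use more SMs to sort" note in the launcher.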
