Skip to content

Commit 63506b0

Browse files
chunyuan-w and DiweiSun authored
add topk_softmax kernel (sgl-project#11)
* fix ci repo * typo fix * add topk_softmax kernel * fix format * add test_topk_softmax.py into CI yml * fix format * fix ci branch * use get_device in utils * put topk_softmax under namespace at::native::xpu * fix kNegInfinity; remove unused headers * use div_up * use int64_t instead of int * use AT_DISPATCH_REDUCED_FLOATING_TYPES * refactor data_ptr * remove rows_for_experts and offsets * add a precision function in utils.py * fix lint * add comment for increased tolerance in the UT --------- Co-authored-by: DiweiSun <[email protected]>
1 parent 767eee6 commit 63506b0

File tree

8 files changed

+374
-6
lines changed

8 files changed

+374
-6
lines changed

.github/workflows/pr-test-xpu.yml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,10 @@ jobs:
2525

2626
- name: Build Docker image
2727
run: |
28-
docker build --no-cache --progress=plain -f Dockerfile.xpu_kernel -t xpu_sglang:pvc .
29-
28+
docker build \
29+
--build-arg SG_LANG_KERNEL_BRANCH=${{ github.head_ref }} \
30+
--build-arg SG_LANG_KERNEL_REPO=${{ github.event.pull_request.head.repo.clone_url }} \
31+
--no-cache --progress=plain -f Dockerfile.xpu_kernel -t xpu_sglang:pvc .
3032
3133
- name: Run container
3234
run: |
@@ -48,7 +50,7 @@ jobs:
4850
timeout-minutes: 20
4951
run: |
5052
docker exec -w /root/sglang ci_sglang_xpu \
51-
/bin/bash -c "cd /root/sglang/sgl-kernel-xpu/tests && python3 -m pytest -v -s test_awq_dequant.py"
53+
/bin/bash -c "cd /root/sglang/sgl-kernel-xpu/tests && python3 -m pytest -v -s test_awq_dequant.py test_topk_softmax.py"
5254
5355
- name: Run E2E Bfloat16 tests
5456
timeout-minutes: 20

include/sgl_kernel_ops.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ void rmsnorm(torch::Tensor& output, torch::Tensor& input, torch::Tensor& weight,
122122
void fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps);
123123
void gemma_rmsnorm(torch::Tensor& output, torch::Tensor& input, torch::Tensor& weight, double eps);
124124
void gemma_fused_add_rmsnorm(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight, double eps);
125+
void topk_softmax(at::Tensor& topk_weights, at::Tensor& topk_indices, at::Tensor& gating_output, bool renormalize);
125126
} // namespace at::native::xpu
126127
void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
127128
void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);

python/sgl_kernel/moe.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, Optional
1+
from typing import Any, Dict, Optional, Tuple
22

33
import torch
44

@@ -28,11 +28,11 @@ def moe_align_block_size(
2828
def topk_softmax(
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool = False,
) -> None:
    """Run the fused softmax + top-k expert selection kernel.

    Thin wrapper that forwards its arguments unchanged to
    ``torch.ops.sgl_kernel.topk_softmax``. Nothing is returned; the op
    writes its results into the caller-provided output tensors.

    Args:
        topk_weights: Output tensor for the selected routing weights
            (the XPU implementation expects float32, ``[n_tokens, n_topk]``).
        topk_ids: Output tensor for the selected expert indices
            (the XPU implementation expects int32, ``[n_tokens, n_topk]``).
        gating_output: Router logits tensor (``[n_tokens, n_experts]``).
        renormalize: Whether the selected top-k weights should be rescaled
            so that they sum to 1 for every token.
    """
    torch.ops.sgl_kernel.topk_softmax.default(
        topk_weights, topk_ids, gating_output, renormalize
    )
3737

3838

src/sycl/MoEOps.cpp

Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,267 @@
1+
#include <ATen/ATen.h>
2+
#include <c10/xpu/XPUStream.h>
3+
#include <torch/all.h>
4+
5+
#include <sycl/sycl.hpp>
6+
7+
#include "SYCLHelpers.h"
8+
#include "Utils.h"
9+
10+
namespace at::native::xpu {
11+
12+
namespace TopKSoftmaxImpl {
13+
14+
template <typename T>
15+
struct FusedTopkSoftmax {
16+
static constexpr int sub_group_size = 32;
17+
static constexpr int max_group_size = 1024;
18+
static constexpr int malloc_per_item = 8;
19+
static constexpr float kNegInfinity = -std::numeric_limits<float>::infinity();
20+
21+
FusedTopkSoftmax(
22+
float* topk_weights,
23+
int* topk_ids,
24+
const T* gating_output,
25+
const bool renormalize,
26+
const int tokens,
27+
const int experts,
28+
const int top_k)
29+
: topk_weights(topk_weights),
30+
topk_ids(topk_ids),
31+
gating_output(gating_output),
32+
renormalize(renormalize),
33+
tokens(tokens),
34+
experts(experts),
35+
top_k(top_k) {}
36+
37+
static inline sycl::nd_range<3> get_nd_range(const int tokens, const int experts) {
38+
int calc_per_item = div_up(experts, sub_group_size);
39+
int group_size = div_up(experts, calc_per_item);
40+
group_size = group_size < sub_group_size ? sub_group_size : group_size;
41+
group_size = group_size < max_group_size ? group_size : max_group_size;
42+
int sub_groups_per_group = div_up(group_size, sub_group_size);
43+
group_size = sub_groups_per_group * sub_group_size;
44+
int global_size = div_up(tokens, sub_groups_per_group);
45+
46+
sycl::range<3> local(1, 1, group_size);
47+
sycl::range<3> global(1, 1, global_size);
48+
return sycl::nd_range<3>(global * local, local);
49+
}
50+
51+
static inline T Sigmoid(T x) {
52+
float sycl_x = static_cast<float>(x);
53+
float result = 1.0f / (1.0f + sycl::exp(-sycl_x));
54+
return static_cast<T>(result);
55+
}
56+
57+
[[sycl::reqd_sub_group_size(sub_group_size)]] void operator()(sycl::nd_item<3> item) const {
58+
int group_id = item.get_group_linear_id();
59+
int local_range = item.get_local_range(2);
60+
int sub_groups_per_group = local_range / sub_group_size;
61+
int calc_per_item = (experts + sub_group_size - 1) / sub_group_size;
62+
63+
sycl::sub_group sg = item.get_sub_group();
64+
int sg_id = sg.get_group_id();
65+
int sg_local_id = sg.get_local_id();
66+
67+
int tid = group_id * sub_groups_per_group + sg_id;
68+
69+
if (tid >= tokens) {
70+
return; // Out of bounds
71+
}
72+
73+
T local_elems[malloc_per_item];
74+
int local_idx[malloc_per_item];
75+
76+
int start_offset = sg_local_id * calc_per_item;
77+
int local_num = calc_per_item;
78+
79+
if (start_offset + local_num >= experts) {
80+
local_num = experts - start_offset;
81+
if (local_num < 0) {
82+
local_num = 0; // No elements to process
83+
}
84+
}
85+
86+
for (int e = 0; e < calc_per_item; ++e) {
87+
local_elems[e] = kNegInfinity;
88+
local_idx[e] = -1;
89+
}
90+
91+
for (int e = 0; e < local_num; ++e) {
92+
local_elems[e] = gating_output[tid * experts + start_offset + e];
93+
local_idx[e] = start_offset + e;
94+
}
95+
96+
// Perform top-k selection
97+
T topk_weights_local[malloc_per_item];
98+
int topk_ids_local[malloc_per_item];
99+
100+
for (int k = 0; k < top_k; ++k) {
101+
T k_max = kNegInfinity;
102+
int k_max_idx = -1;
103+
int remove_ix = -1;
104+
for (int e = 0; e < calc_per_item; ++e) {
105+
T my_val = local_elems[e];
106+
int my_idx = local_idx[e];
107+
for (int offset = sub_group_size / 2; offset > 0; offset /= 2) {
108+
T other_val = sycl::permute_group_by_xor(sg, my_val, offset);
109+
int other_idx = sycl::permute_group_by_xor(sg, my_idx, offset);
110+
if (other_val > my_val || (other_val == my_val && other_idx < my_idx)) {
111+
my_val = other_val;
112+
my_idx = other_idx;
113+
}
114+
}
115+
if (my_val > k_max || (my_val == k_max && my_idx < k_max_idx)) {
116+
k_max = my_val;
117+
k_max_idx = my_idx;
118+
119+
if (k_max_idx == local_idx[e]) {
120+
remove_ix = e; // Mark this index for removal
121+
} else
122+
remove_ix = -1;
123+
}
124+
}
125+
topk_weights_local[k] = k_max;
126+
topk_ids_local[k] = k_max_idx;
127+
if (remove_ix != -1) {
128+
// Reset the score to avoid re-selection
129+
local_elems[remove_ix] = kNegInfinity;
130+
local_idx[remove_ix] = -1;
131+
remove_ix = -1;
132+
}
133+
}
134+
135+
float max_score = topk_weights_local[0];
136+
float sum_exp = 0;
137+
138+
for (int i = 0; i < top_k; ++i) {
139+
float score = topk_weights_local[i];
140+
sum_exp += sycl::exp(score - max_score);
141+
}
142+
143+
for (int e = 0; e < calc_per_item; ++e) {
144+
float score = local_elems[e];
145+
float my_val = sycl::exp(score - max_score);
146+
for (int offset = sub_group_size / 2; offset > 0; offset /= 2) {
147+
float other_val = sycl::permute_group_by_xor(sg, my_val, offset);
148+
my_val += other_val;
149+
}
150+
sum_exp += my_val;
151+
}
152+
153+
for (int i = 0; i < top_k; ++i) {
154+
float score = topk_weights_local[i];
155+
topk_weights_local[i] = sycl::exp(score - max_score) / sum_exp;
156+
}
157+
158+
if (renormalize) {
159+
// Renormalize the top-k weights
160+
float sum = 0;
161+
for (int i = 0; i < top_k; ++i) {
162+
sum += topk_weights_local[i];
163+
}
164+
if (sum > 0) {
165+
for (int i = 0; i < top_k; ++i) {
166+
topk_weights_local[i] /= sum;
167+
}
168+
}
169+
}
170+
171+
if (sg_local_id == 0) {
172+
int offset = tid * top_k;
173+
for (int i = 0; i < top_k; ++i) {
174+
topk_weights[offset + i] = topk_weights_local[i];
175+
if (topk_ids_local[i] < 0 || topk_ids_local[i] >= experts) {
176+
// Ensure valid index
177+
topk_ids[offset + i] = 0;
178+
continue;
179+
}
180+
topk_ids[offset + i] = topk_ids_local[i];
181+
}
182+
}
183+
}
184+
float* topk_weights;
185+
int* topk_ids;
186+
const T* gating_output;
187+
const bool renormalize;
188+
const int tokens;
189+
const int experts;
190+
const int top_k;
191+
};
192+
193+
template <typename T>
194+
void launch_fused_topk_softmax(
195+
sycl::queue& queue,
196+
const T* gating_output,
197+
float* topk_weights,
198+
int* topk_indices,
199+
const bool renormalize,
200+
const int top_k,
201+
const int num_tokens,
202+
const int num_experts) {
203+
using Kernel = FusedTopkSoftmax<T>;
204+
auto range = Kernel::get_nd_range(num_tokens, num_experts);
205+
206+
auto global_range = range.get_global_range();
207+
auto local_range = range.get_local_range();
208+
209+
Kernel task(topk_weights, topk_indices, gating_output, renormalize, num_tokens, num_experts, top_k);
210+
211+
sycl_kernel_submit(global_range, local_range, queue, task);
212+
return;
213+
}
214+
215+
template <typename T>
216+
void fused_topk_softmax(
217+
const T* gating_output,
218+
float* topk_weights,
219+
int* topk_indices,
220+
const bool renormalize,
221+
const int num_tokens,
222+
const int num_experts,
223+
const int topk) {
224+
auto stream = at::xpu::getCurrentXPUStream();
225+
auto queue = stream.queue();
226+
227+
launch_fused_topk_softmax(
228+
queue, gating_output, topk_weights, topk_indices, renormalize, topk, num_tokens, num_experts);
229+
};
230+
}; // namespace TopKSoftmaxImpl
231+
232+
/**
233+
* @brief Perform topk after softmax on gating_output.
234+
* @param topk_weights The topk_weights tensor of shape [n_tokens, n_topk].
235+
* @param topk_indices The topk_indices tensor of shape [n_tokens, n_topk].
236+
* @param gating_output The gating output tensor of shape [n_tokens, n_experts].
237+
* @param renormalize The renormalize bool whether the topk_weights needs to be renormalized.
238+
* @return void.
239+
*/
240+
void topk_softmax(at::Tensor& topk_weights, at::Tensor& topk_indices, at::Tensor& gating_output, bool renormalize) {
241+
auto shape = gating_output.sizes().vec();
242+
TORCH_CHECK(shape.size() == 2, "gating_output must be 2D tensor, but got ", shape.size(), "D");
243+
int64_t n_tokens = shape[0];
244+
int64_t n_experts = shape[1];
245+
246+
TORCH_CHECK(n_experts <= 128, "n_experts only support up to 128, but got ", n_experts);
247+
248+
TORCH_CHECK(topk_weights.scalar_type() == at::kFloat, "topk_weights should be Float");
249+
TORCH_CHECK(topk_indices.scalar_type() == at::kInt, "topk_indices should be Int");
250+
251+
constexpr int64_t alignment = 8;
252+
int64_t n_experts_aligned = div_up(n_experts, alignment) * alignment; // align to 8
253+
254+
int64_t n_topk = topk_weights.size(1);
255+
256+
AT_DISPATCH_REDUCED_FLOATING_TYPES(gating_output.scalar_type(), "fused_topk_softmax_kernel", [&]() {
257+
TopKSoftmaxImpl::fused_topk_softmax<scalar_t>(
258+
gating_output.data_ptr<scalar_t>(),
259+
topk_weights.data_ptr<float>(),
260+
topk_indices.data_ptr<int>(),
261+
renormalize,
262+
n_tokens,
263+
n_experts,
264+
n_topk);
265+
});
266+
}
267+
} // namespace at::native::xpu

src/sycl/Utils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,3 +205,8 @@ int get_min(Func limit_func, int X, Args*... args) {
205205
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
206206
} \
207207
}
208+
209+
/**
 * @brief Integer ceiling division.
 *
 * Returns the smallest integer q with q * y >= x, assuming x >= 0 and y > 0.
 * Usable in constant expressions. Both arguments must have the same integral
 * type T. The caller must make sure that x + y - 1 does not overflow T and
 * that y is non-zero.
 */
template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
constexpr T div_up(T x, T y) {
  return (x + y - 1) / y;
}

src/torch_extension_sycl.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
4646
m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps) -> ()");
4747
m.impl("gemma_fused_add_rmsnorm", torch::kXPU, &at::native::xpu::gemma_fused_add_rmsnorm);
4848

49+
m.def("topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor gating_output, bool renormalize) -> ()");
50+
m.impl("topk_softmax", torch::kXPU, &at::native::xpu::topk_softmax);
51+
4952
// m.def(
5053
// "fp8_blockwise_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype,
5154
// -> Tensor");

0 commit comments

Comments
 (0)