
Commit f4b4a32

Add elementwise kernels.
1 parent e36af02

File tree

6 files changed: +502 −9 lines changed

include/sgl_kernel_ops.h

Lines changed: 3 additions & 3 deletions

@@ -123,9 +123,9 @@ void sgl_fused_add_rmsnorm(
     torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps, bool enable_pdl);
 void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, bool enable_pdl);
 void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, bool enable_pdl);
-void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t sycl_stream);
-void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t sycl_stream);
-void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t sycl_stream);
+void silu_and_mul(at::Tensor& out, at::Tensor& input);
+void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input);
+void gelu_and_mul(at::Tensor& out, at::Tensor& input);
 void apply_rope_pos_ids_cos_sin_cache(
     at::Tensor q,
     at::Tensor k,

python/sgl_kernel/elementwise.py

Lines changed: 3 additions & 3 deletions

@@ -179,7 +179,7 @@ def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
             device=input.device,
             dtype=input.dtype,
         )
-    torch.ops.sgl_kernel.silu_and_mul.default(out, input, get_cuda_stream())
+    torch.ops.sgl_kernel.silu_and_mul.default(out, input)
     return out

@@ -194,7 +194,7 @@ def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
             device=input.device,
             dtype=input.dtype,
         )
-    torch.ops.sgl_kernel.gelu_tanh_and_mul.default(out, input, get_cuda_stream())
+    torch.ops.sgl_kernel.gelu_tanh_and_mul.default(out, input)
     return out

@@ -209,7 +209,7 @@ def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
             device=input.device,
             dtype=input.dtype,
         )
-    torch.ops.sgl_kernel.gelu_and_mul.default(out, input, get_cuda_stream())
+    torch.ops.sgl_kernel.gelu_and_mul.default(out, input)
     return out
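
With the stream handle gone from the op schema, the wrappers now call the kernels with just (out, input); the current XPU stream is resolved inside the C++ entry points. Typical usage is unchanged. A minimal sketch, assuming the wrapper is re-exported at the package root and that "xpu" is the device these SYCL kernels target:

import torch
from sgl_kernel import silu_and_mul

x = torch.randn(8, 2 * 4096, device="xpu", dtype=torch.float16)  # last dim = 2*d
y = silu_and_mul(x)  # allocates and returns an [8, 4096] output
# Or write into a preallocated buffer:
out = torch.empty(8, 4096, device="xpu", dtype=x.dtype)
silu_and_mul(x, out)
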

src/sycl/TripleOps.cpp

Lines changed: 267 additions & 0 deletions

@@ -0,0 +1,267 @@
+#include <ATen/ATen.h>
+#include <ATen/Parallel.h>
+#include <c10/xpu/XPUStream.h>
+#include <torch/all.h>
+#include <ATen/OpMathType.h>
+
+#include <cmath>
+#include <cstdint>
+#include <iostream>
+#include <sycl/sycl.hpp>
+#include <vector>
+#include "Utils.h"
+
+#include "SYCLHelpers.h"
+
+#define DPCPP_CONSTANT __attribute__((opencl_constant))
+
+#define DPCPP_KER_STRING(var, str) static const DPCPP_CONSTANT char var[] = str;
+#define DPCPP_KER_PRINTF sycl::ext::oneapi::experimental::printf
+
+#define DPCPP_K_PRINT(fmt_str, ...)           \
+  {                                           \
+    DPCPP_KER_STRING(fmt_var, fmt_str);       \
+    DPCPP_KER_PRINTF(fmt_var, ##__VA_ARGS__); \
+  }
+
+template <typename scalar_t, int vec_size>
+struct alignas(sizeof(scalar_t) * vec_size) aligned_vector_loop {
+  scalar_t val[vec_size];
+
+  scalar_t& operator[](int index) {
+    return val[index];
+  }
+
+  scalar_t const& operator[](int index) const {
+    return val[index];
+  }
+};
+
+template <typename scalar_t, typename accscalar_t>
+struct silu_mul_dpcpp_functor {
+  scalar_t operator()(scalar_t a, scalar_t b) const {
+    return (accscalar_t(a)) / (1.0f + expf(accscalar_t(-a))) * accscalar_t(b);
+  }
+};
+
+template <typename scalar_t, typename accscalar_t>
+struct gelu_tanh_mul_dpcpp_functor {
+  scalar_t operator()(scalar_t a, scalar_t b) const {
+    const accscalar_t kBeta = M_SQRT2 * M_2_SQRTPI * accscalar_t(0.5);
+    const accscalar_t kKappa = 0.044715;
+    auto x_cube = accscalar_t(a) * accscalar_t(a) * accscalar_t(a);
+    auto inner = kBeta * (accscalar_t(a) + kKappa * x_cube);
+    return (accscalar_t(0.5) * accscalar_t(a) * (accscalar_t(1) + std::tanh(accscalar_t(inner)))) * accscalar_t(b);
+  }
+};
+
+template <typename scalar_t, typename accscalar_t>
+struct gelu_erf_mul_dpcpp_functor {
+  scalar_t operator()(scalar_t a, scalar_t b) const {
+    return (accscalar_t(a) * accscalar_t(0.5) * (accscalar_t(1) + ::erf(accscalar_t(a) * accscalar_t(M_SQRT1_2)))) * accscalar_t(b);
+  }
+};
+
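
Each functor computes an activation of its first argument in higher (opmath) precision and multiplies by the second. In the launched kernel, a comes from the first half of a row's last dimension and b (the gate) from the second half. A PyTorch reference of the same math, for illustration only (these *_ref helpers are not part of the commit):

import torch
import torch.nn.functional as F

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    a, b = x.chunk(2, dim=-1)  # activation half, gate half
    return F.silu(a) * b       # a * sigmoid(a) * b

def gelu_tanh_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    a, b = x.chunk(2, dim=-1)
    return F.gelu(a, approximate="tanh") * b  # tanh approximation, kKappa = 0.044715

def gelu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    a, b = x.chunk(2, dim=-1)
    return F.gelu(a) * b       # exact erf-based GELU
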
+template <typename scalar_t, typename func_t, int N>
+struct op_and_mul_functor {
+  // One work-group per token (row); each work-item strides over the row in
+  // N-element vector chunks. The first half of a row holds the activation
+  // input, the second half the multiplicative gate.
+  void operator()(sycl::nd_item<1> item) const {
+    int64_t offset = item.get_local_linear_id();
+    int64_t step = item.get_local_range(0);
+    int64_t token_id = item.get_group(0);
+    func_t fn;
+    int64_t bound = dim / N;
+    for (int64_t i = offset; i < bound; i += step) {
+      auto unary_val = reinterpret_cast<aligned_vector_loop<scalar_t, N>*>(input_ptr)[token_id * bound * 2 + i];
+      auto mul_val = reinterpret_cast<aligned_vector_loop<scalar_t, N>*>(input_ptr)[token_id * bound * 2 + i + bound];
+#pragma unroll
+      for (int j = 0; j < N; ++j) {
+        unary_val[j] = fn(unary_val[j], mul_val[j]);
+      }
+      reinterpret_cast<aligned_vector_loop<scalar_t, N>*>(output_ptr)[token_id * bound + i] = unary_val;
+    }
+  }
+
+  scalar_t* input_ptr;
+  scalar_t* output_ptr;
+  int64_t num_;
+  int64_t dim;
+};
+
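
In the functor above, chunk i of a token's activation half sits at flat vector index token_id * 2 * bound + i and the matching gate chunk at token_id * 2 * bound + i + bound, with bound = dim / N. A small emulation of that addressing (hypothetical helper, illustration only):

def chunk_offsets(token_id: int, dim: int, N: int):
    # bound = number of N-wide vector chunks per output row
    bound = dim // N
    return [(token_id * 2 * bound + i,          # activation chunk
             token_id * 2 * bound + i + bound)  # gate chunk
            for i in range(bound)]

# chunk_offsets(1, 8, 4) -> [(4, 6), (5, 7)]
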
+#define VEC_LAUNCH(KERNEL, N)                                      \
+  case N: {                                                        \
+    op_and_mul_functor<T_to, KERNEL<T_to, accscalar_t>, N> kfn = { \
+        .input_ptr = _input,                                       \
+        .output_ptr = _out,                                        \
+        .num_ = numel,                                             \
+        .dim = dim};                                               \
+    sycl_kernel_submit(num_group * wg_size, wg_size, q, kfn);      \
+    break;                                                         \
+  }
+
+template <typename T = float>
+void get_config(
+    const Tensor& input,
+    const Tensor& out,
+    int64_t& numel,
+    int64_t& dim,
+    int64_t& wg_size,
+    int64_t& num_group,
+    int& vec_size) {
+  auto dev_id = torch_ipex::xpu::dpcpp::dpcppGetDeviceIdOfCurrentQueue();
+  int64_t max_wg_size = torch_ipex::xpu::dpcpp::dpcppMaxWorkGroupSize(dev_id);
+  numel = out.numel();
+  dim = out.size(-1);
+  int64_t tokens = numel / dim;
+  wg_size = std::min(dim, max_wg_size);
+  num_group = tokens;
+
+  vec_size = sizeof(float) * 4 / sizeof(T);
+  while ((vec_size >> 1) * wg_size >= dim) {
+    vec_size = vec_size >> 1;
+  }
+  if (dim % vec_size != 0) {
+    vec_size = 1;
+  }
+}
+
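
get_config launches one work-group per token, caps the work-group size at the device maximum, and starts the vector width at 16 bytes of elements (sizeof(float) * 4 / sizeof(T), i.e. 8 for half or bfloat16). The width is halved while a single work-group pass at the halved width would still cover the row, and falls back to 1 when the row length is not divisible by the chosen width. A quick emulation of that selection (hypothetical helper, illustrative only):

def pick_vec_size(dim: int, wg_size: int, elem_bytes: int = 2) -> int:
    vec_size = 4 * 4 // elem_bytes  # 16-byte loads: 8 elements for fp16/bf16
    while (vec_size >> 1) * wg_size >= dim:
        vec_size >>= 1
    return vec_size if dim % vec_size == 0 else 1

# pick_vec_size(4096, 1024) == 4: width 4 covers the row in one pass of
# 1024 work-items, while width 2 would not.
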
+template <typename T_to = float, typename T_from = float>
+void silu_and_mul_sycl(
+    sycl::queue& q,
+    Tensor& input,
+    Tensor& out) {
+  auto _input = reinterpret_cast<T_to*>(input.data_ptr<T_from>());
+  auto _out = reinterpret_cast<T_to*>(out.data_ptr<T_from>());
+
+  int64_t numel;
+  int64_t dim;
+  int64_t wg_size;
+  int64_t num_group;
+  int vec_size;
+  get_config<T_to>(input, out, numel, dim, wg_size, num_group, vec_size);
+
+  using accscalar_t = at::opmath_type<T_to>;
+  switch (vec_size) {
+    VEC_LAUNCH(silu_mul_dpcpp_functor, 1);
+    VEC_LAUNCH(silu_mul_dpcpp_functor, 2);
+    VEC_LAUNCH(silu_mul_dpcpp_functor, 4);
+    VEC_LAUNCH(silu_mul_dpcpp_functor, 8);
+    VEC_LAUNCH(silu_mul_dpcpp_functor, 16);
+    default:
+      TORCH_CHECK(false, "Unsupported vector size: ", vec_size);
+  }
+}
+
+void silu_and_mul(Tensor& out, Tensor& input) {
+  input = input.contiguous();
+  out = out.contiguous();
+
+  auto stream = at::xpu::getCurrentXPUStream();
+  auto queue = stream.queue();
+
+  if (input.scalar_type() == at::ScalarType::Half) {
+    silu_and_mul_sycl<sycl::half, at::Half>(queue, input, out);
+  } else {
+    silu_and_mul_sycl<sycl::ext::oneapi::bfloat16, at::BFloat16>(queue, input, out);
+  }
+}
+
+template <typename T_to = float, typename T_from = float>
+void gelu_tanh_and_mul_sycl(
+    sycl::queue& q,
+    Tensor& input,
+    Tensor& out) {
+  auto _input = reinterpret_cast<T_to*>(input.data_ptr<T_from>());
+  auto _out = reinterpret_cast<T_to*>(out.data_ptr<T_from>());
+
+  int64_t numel;
+  int64_t dim;
+  int64_t wg_size;
+  int64_t num_group;
+  int vec_size;
+  get_config<T_to>(input, out, numel, dim, wg_size, num_group, vec_size);
+
+  using accscalar_t = at::opmath_type<T_to>;
+  switch (vec_size) {
+    VEC_LAUNCH(gelu_tanh_mul_dpcpp_functor, 1);
+    VEC_LAUNCH(gelu_tanh_mul_dpcpp_functor, 2);
+    VEC_LAUNCH(gelu_tanh_mul_dpcpp_functor, 4);
+    VEC_LAUNCH(gelu_tanh_mul_dpcpp_functor, 8);
+    VEC_LAUNCH(gelu_tanh_mul_dpcpp_functor, 16);
+    default:
+      TORCH_CHECK(false, "Unsupported vector size: ", vec_size);
+  }
+}
+
+void gelu_tanh_and_mul(Tensor& out, Tensor& input) {
+  input = input.contiguous();
+  out = out.contiguous();
+
+  auto stream = at::xpu::getCurrentXPUStream();
+  auto queue = stream.queue();
+
+  if (input.scalar_type() == at::ScalarType::Half) {
+    gelu_tanh_and_mul_sycl<sycl::half, at::Half>(queue, input, out);
+  } else {
+    gelu_tanh_and_mul_sycl<sycl::ext::oneapi::bfloat16, at::BFloat16>(queue, input, out);
+  }
+}
+
+template <typename T_to = float, typename T_from = float>
+void gelu_and_mul_sycl(
+    sycl::queue& q,
+    Tensor& input,
+    Tensor& out) {
+  auto _input = reinterpret_cast<T_to*>(input.data_ptr<T_from>());
+  auto _out = reinterpret_cast<T_to*>(out.data_ptr<T_from>());
+
+  int64_t numel;
+  int64_t dim;
+  int64_t wg_size;
+  int64_t num_group;
+  int vec_size;
+  get_config<T_to>(input, out, numel, dim, wg_size, num_group, vec_size);
+
+  using accscalar_t = at::opmath_type<T_to>;
+  switch (vec_size) {
+    VEC_LAUNCH(gelu_erf_mul_dpcpp_functor, 1);
+    VEC_LAUNCH(gelu_erf_mul_dpcpp_functor, 2);
+    VEC_LAUNCH(gelu_erf_mul_dpcpp_functor, 4);
+    VEC_LAUNCH(gelu_erf_mul_dpcpp_functor, 8);
+    VEC_LAUNCH(gelu_erf_mul_dpcpp_functor, 16);
+    default:
+      TORCH_CHECK(false, "Unsupported vector size: ", vec_size);
+  }
+}
+
+void gelu_and_mul(Tensor& out, Tensor& input) {
+  input = input.contiguous();
+  out = out.contiguous();
+
+  auto stream = at::xpu::getCurrentXPUStream();
+  auto queue = stream.queue();
+
+  if (input.scalar_type() == at::ScalarType::Half) {
+    gelu_and_mul_sycl<sycl::half, at::Half>(queue, input, out);
+  } else {
+    gelu_and_mul_sycl<sycl::ext::oneapi::bfloat16, at::BFloat16>(queue, input, out);
+  }
+}
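
All three wrappers follow the same pattern: force contiguous layouts, grab the current XPU stream's queue, and dispatch on Half vs. BFloat16, the only dtypes the entry points handle. A quick consistency check against the eager reference sketched earlier (illustrative only; reuses silu_and_mul_ref from above):

import torch

x = torch.randn(4, 2 * 1024, device="xpu", dtype=torch.bfloat16)
out = torch.empty(4, 1024, device="xpu", dtype=x.dtype)
torch.ops.sgl_kernel.silu_and_mul.default(out, x)
torch.testing.assert_close(out, silu_and_mul_ref(x), rtol=2e-2, atol=2e-2)
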
