Commit d3f063d

Add elementwise kernels.
2 parents e36af02 + 3b52c17 commit d3f063d

File tree

6 files changed: +444 -10 lines changed

include/sgl_kernel_ops.h (file mode 100755 → 100644)

Lines changed: 3 additions & 4 deletions
@@ -22,7 +22,6 @@ limitations under the License.
 #include <torch/torch.h>
 
 #include <sycl/sycl.hpp>
-
 #include <tuple>
 #include <vector>
 
@@ -123,9 +122,9 @@ void sgl_fused_add_rmsnorm(
     torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps, bool enable_pdl);
 void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, bool enable_pdl);
 void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, bool enable_pdl);
-void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t sycl_stream);
-void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t sycl_stream);
-void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t sycl_stream);
+void silu_and_mul(at::Tensor& out, at::Tensor& input);
+void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input);
+void gelu_and_mul(at::Tensor& out, at::Tensor& input);
 void apply_rope_pos_ids_cos_sin_cache(
     at::Tensor q,
     at::Tensor k,
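
Note: the trailing int64_t sycl_stream parameter is dropped from all three activation-and-mul declarations; the SYCL implementations added in src/sycl/TripleOps.cpp below obtain the current XPU stream themselves via at::xpu::getCurrentXPUStream(), so callers no longer thread a stream handle through.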

python/sgl_kernel/elementwise.py

Lines changed: 3 additions & 3 deletions
@@ -179,7 +179,7 @@ def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
             device=input.device,
             dtype=input.dtype,
         )
-    torch.ops.sgl_kernel.silu_and_mul.default(out, input, get_cuda_stream())
+    torch.ops.sgl_kernel.silu_and_mul.default(out, input)
     return out
 
 
@@ -194,7 +194,7 @@ def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
             device=input.device,
             dtype=input.dtype,
         )
-    torch.ops.sgl_kernel.gelu_tanh_and_mul.default(out, input, get_cuda_stream())
+    torch.ops.sgl_kernel.gelu_tanh_and_mul.default(out, input)
     return out
 
 
@@ -209,7 +209,7 @@ def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
             device=input.device,
             dtype=input.dtype,
         )
-    torch.ops.sgl_kernel.gelu_and_mul.default(out, input, get_cuda_stream())
+    torch.ops.sgl_kernel.gelu_and_mul.default(out, input)
     return out
 
 

src/sycl/TripleOps.cpp (new file)

Lines changed: 246 additions & 0 deletions
@@ -0,0 +1,246 @@
#include <ATen/ATen.h>
#include <ATen/OpMathType.h>
#include <ATen/Parallel.h>
#include <c10/xpu/XPUStream.h>
#include <torch/all.h>

#include <cmath>
#include <cstdint>
#include <iostream>
#include <sycl/sycl.hpp>
#include <vector>

#include "SYCLHelpers.h"
#include "Utils.h"

#define DPCPP_CONSTANT __attribute__((opencl_constant))

#define DPCPP_KER_STRING(var, str) static const DPCPP_CONSTANT char var[] = str;
#define DPCPP_KER_PRINTF sycl::ext::oneapi::experimental::printf

#define DPCPP_K_PRINT(fmt_str, ...)           \
  {                                           \
    DPCPP_KER_STRING(fmt_var, fmt_str);       \
    DPCPP_KER_PRINTF(fmt_var, ##__VA_ARGS__); \
  }

template <typename scalar_t, int vec_size>
struct alignas(sizeof(scalar_t) * vec_size) aligned_vector_loop {
  scalar_t val[vec_size];

  scalar_t& operator[](int index) {
    return val[index];
  }

  scalar_t const& operator[](int index) const {
    return val[index];
  }
};

template <typename scalar_t, typename accscalar_t>
struct silu_mul_dpcpp_functor {
  scalar_t operator()(scalar_t a, scalar_t b) const {
    return (accscalar_t(a)) / (1.0f + expf(accscalar_t(-a))) * accscalar_t(b);
  }
};

template <typename scalar_t, typename accscalar_t>
struct gelu_tanh_mul_dpcpp_functor {
  scalar_t operator()(scalar_t a, scalar_t b) const {
    const accscalar_t kBeta = M_SQRT2 * M_2_SQRTPI * accscalar_t(0.5);
    const accscalar_t kKappa = 0.044715;
    auto x_cube = accscalar_t(a) * accscalar_t(a) * accscalar_t(a);
    auto inner = kBeta * (accscalar_t(a) + kKappa * x_cube);
    return (accscalar_t(0.5) * accscalar_t(a) * (accscalar_t(1) + std::tanh(accscalar_t(inner)))) * accscalar_t(b);
  }
};

template <typename scalar_t, typename accscalar_t>
struct gelu_erf_mul_dpcpp_functor {
  scalar_t operator()(scalar_t a, scalar_t b) const {
    return (accscalar_t(a) * accscalar_t(0.5) * (accscalar_t(1) + ::erf(accscalar_t(a) * accscalar_t(M_SQRT1_2)))) *
        accscalar_t(b);
  }
};

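For reference, all three functors fuse an activation with an element-wise product. Writing a for an element of the first half of a row and b for its partner in the second half:

    silu_and_mul:      out = a / (1 + exp(-a)) * b
    gelu_tanh_and_mul: out = 0.5 * a * (1 + tanh(sqrt(2/pi) * (a + 0.044715 * a^3))) * b
    gelu_and_mul:      out = 0.5 * a * (1 + erf(a / sqrt(2))) * b

The tanh-GELU constant folds to kBeta = M_SQRT2 * M_2_SQRTPI * 0.5 = sqrt(2) * (2 / sqrt(pi)) / 2 = sqrt(2 / pi).
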
template <typename scalar_t, typename func_t, int N>
struct op_and_mul_functor {
  void operator()(sycl::nd_item<1> item) const {
    using accscalar_t = at::opmath_type<scalar_t>;
    int64_t offset = item.get_local_linear_id();
    int64_t step = item.get_local_range(0);
    int64_t token_id = item.get_group(0);
    func_t fn;
    // Each input row holds 2 * dim elements: the activation half followed by
    // the multiplier half. Work-items stride over the row in N-wide vectors.
    int64_t bound = dim / N;
    for (int64_t i = offset; i < bound; i += step) {
      auto unary_val = reinterpret_cast<aligned_vector_loop<scalar_t, N>*>(input_ptr)[token_id * bound * 2 + i];
      auto mul_val = reinterpret_cast<aligned_vector_loop<scalar_t, N>*>(input_ptr)[token_id * bound * 2 + i + bound];
#pragma unroll
      for (int j = 0; j < N; ++j) {
        unary_val[j] = fn(unary_val[j], mul_val[j]);
      }
      reinterpret_cast<aligned_vector_loop<scalar_t, N>*>(output_ptr)[token_id * bound + i] = unary_val;
    }
  }

  scalar_t* input_ptr;
  scalar_t* output_ptr;
  int64_t num_;
  int64_t dim;
};

#define VEC_LAUNCH(KERNEL, N)                                                \
  case N: {                                                                  \
    op_and_mul_functor<T_to, KERNEL<T_to, accscalar_t>, N> kfn = {           \
        .input_ptr = _input, .output_ptr = _out, .num_ = numel, .dim = dim}; \
    sycl_kernel_submit(num_group * wg_size, wg_size, q, kfn);                \
    break;                                                                   \
  }

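Each VEC_LAUNCH case instantiates op_and_mul_functor at one vector width and hands it to the sycl_kernel_submit helper (from the included SYCLHelpers.h) with a 1-D range of num_group work-groups of wg_size work-items each, i.e. one work-group per token row.
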
template <typename T = float>
void get_config(
    const Tensor& input,
    const Tensor& out,
    int64_t& numel,
    int64_t& dim,
    int64_t& wg_size,
    int64_t& num_group,
    int& vec_size) {
  auto dev_id = dpcppGetDeviceIdOfCurrentQueue();
  int64_t max_wg_size = dpcppMaxWorkGroupSize(dev_id);
  numel = out.numel();
  dim = out.size(-1);
  int64_t tokens = numel / dim;
  wg_size = std::min(dim, max_wg_size);
  num_group = tokens;

  // Start from a 16-byte vector and shrink while a narrower vector would
  // still let one work-group cover the row; fall back to scalar access if
  // dim is not divisible by the chosen width.
  vec_size = sizeof(float) * 4 / sizeof(T);
  while ((vec_size >> 1) * wg_size >= dim) {
    vec_size = vec_size >> 1;
  }
  if (dim % vec_size != 0) vec_size = 1;
}

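A worked example of the heuristic, assuming a device maximum work-group size of 1024: for fp16 with dim = 4096, vec_size starts at sizeof(float) * 4 / sizeof(T) = 8 elements (16 bytes); one halving passes ((8 >> 1) * 1024 = 4096 >= 4096) and the next fails ((4 >> 1) * 1024 = 2048 < 4096), leaving vec_size = 4, so each of the 1024 work-items in a group handles exactly one 4-element vector per half-row.
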
template <typename T_to = float, typename T_from = float>
void silu_and_mul_sycl(sycl::queue& q, Tensor& input, Tensor& out) {
  auto _input = reinterpret_cast<T_to*>(input.data_ptr<T_from>());
  auto _out = reinterpret_cast<T_to*>(out.data_ptr<T_from>());

  int64_t numel;
  int64_t dim;
  int64_t wg_size;
  int64_t num_group;
  int vec_size;
  get_config<T_to>(input, out, numel, dim, wg_size, num_group, vec_size);

  using accscalar_t = at::opmath_type<T_to>;
  switch (vec_size) {
    VEC_LAUNCH(silu_mul_dpcpp_functor, 1);
    VEC_LAUNCH(silu_mul_dpcpp_functor, 2);
    VEC_LAUNCH(silu_mul_dpcpp_functor, 4);
    VEC_LAUNCH(silu_mul_dpcpp_functor, 8);
    VEC_LAUNCH(silu_mul_dpcpp_functor, 16);
    default:
      TORCH_CHECK(false, "Unsupported vector size: ", vec_size);
  }

  return;
}

void silu_and_mul(Tensor& out, Tensor& input) {
  input = input.contiguous();
  out = out.contiguous();

  auto stream = at::xpu::getCurrentXPUStream();
  auto queue = stream.queue();

  if (input.scalar_type() == at::ScalarType::Half) {
    silu_and_mul_sycl<sycl::half, at::Half>(queue, input, out);
  } else {
    silu_and_mul_sycl<sycl::ext::oneapi::bfloat16, at::BFloat16>(queue, input, out);
  }
  return;
}

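The same dtype dispatch repeats in the two entry points below: at::Half maps to sycl::half, and any other dtype falls through to the bf16 instantiation, so in practice these kernels accept fp16 and bf16 only (an fp32 tensor would fail the scalar-type check inside data_ptr<at::BFloat16>()).
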
template <typename T_to = float, typename T_from = float>
void gelu_tanh_and_mul_sycl(sycl::queue& q, Tensor& input, Tensor& out) {
  auto _input = reinterpret_cast<T_to*>(input.data_ptr<T_from>());
  auto _out = reinterpret_cast<T_to*>(out.data_ptr<T_from>());

  int64_t numel;
  int64_t dim;
  int64_t wg_size;
  int64_t num_group;
  int vec_size;
  get_config<T_to>(input, out, numel, dim, wg_size, num_group, vec_size);

  using accscalar_t = at::opmath_type<T_to>;
  switch (vec_size) {
    VEC_LAUNCH(gelu_tanh_mul_dpcpp_functor, 1);
    VEC_LAUNCH(gelu_tanh_mul_dpcpp_functor, 2);
    VEC_LAUNCH(gelu_tanh_mul_dpcpp_functor, 4);
    VEC_LAUNCH(gelu_tanh_mul_dpcpp_functor, 8);
    VEC_LAUNCH(gelu_tanh_mul_dpcpp_functor, 16);
    default:
      TORCH_CHECK(false, "Unsupported vector size: ", vec_size);
  }

  return;
}

void gelu_tanh_and_mul(Tensor& out, Tensor& input) {
  input = input.contiguous();
  out = out.contiguous();

  auto stream = at::xpu::getCurrentXPUStream();
  auto queue = stream.queue();

  if (input.scalar_type() == at::ScalarType::Half) {
    gelu_tanh_and_mul_sycl<sycl::half, at::Half>(queue, input, out);
  } else {
    gelu_tanh_and_mul_sycl<sycl::ext::oneapi::bfloat16, at::BFloat16>(queue, input, out);
  }
  return;
}

template <typename T_to = float, typename T_from = float>
void gelu_and_mul_sycl(sycl::queue& q, Tensor& input, Tensor& out) {
  auto _input = reinterpret_cast<T_to*>(input.data_ptr<T_from>());
  auto _out = reinterpret_cast<T_to*>(out.data_ptr<T_from>());

  int64_t numel;
  int64_t dim;
  int64_t wg_size;
  int64_t num_group;
  int vec_size;
  get_config<T_to>(input, out, numel, dim, wg_size, num_group, vec_size);

  using accscalar_t = at::opmath_type<T_to>;
  switch (vec_size) {
    VEC_LAUNCH(gelu_erf_mul_dpcpp_functor, 1);
    VEC_LAUNCH(gelu_erf_mul_dpcpp_functor, 2);
    VEC_LAUNCH(gelu_erf_mul_dpcpp_functor, 4);
    VEC_LAUNCH(gelu_erf_mul_dpcpp_functor, 8);
    VEC_LAUNCH(gelu_erf_mul_dpcpp_functor, 16);
    default:
      TORCH_CHECK(false, "Unsupported vector size: ", vec_size);
  }

  return;
}

void gelu_and_mul(Tensor& out, Tensor& input) {
  input = input.contiguous();
  out = out.contiguous();

  auto stream = at::xpu::getCurrentXPUStream();
  auto queue = stream.queue();

  if (input.scalar_type() == at::ScalarType::Half) {
    gelu_and_mul_sycl<sycl::half, at::Half>(queue, input, out);
  } else {
    gelu_and_mul_sycl<sycl::ext::oneapi::bfloat16, at::BFloat16>(queue, input, out);
  }
  return;
}
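
For orientation, a minimal host-side sketch of the calling contract, not part of the commit: it assumes the declarations from include/sgl_kernel_ops.h are in scope, the extension is linked, and an XPU device with fp16 support is available; the sizes are made up.

#include <ATen/ATen.h>
#include "sgl_kernel_ops.h"  // assumed include path for silu_and_mul(out, input)

int main() {
  const int64_t tokens = 8, d = 4096;
  // The input packs [activation | multiplier] along the last dim, so it is 2 * d wide.
  at::Tensor input = at::randn({tokens, 2 * d}, at::device(at::kXPU).dtype(at::kHalf));
  at::Tensor out = at::empty({tokens, d}, input.options());
  // Computes out[t, j] = silu(input[t, j]) * input[t, d + j] on the current XPU stream.
  silu_and_mul(out, input);
  return 0;
}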
