#include <ATen/ATen.h>
#include <ATen/OpMathType.h>
#include <ATen/Parallel.h>
#include <c10/xpu/XPUStream.h>
#include <torch/all.h>

#include <cmath>
#include <cstdint>
#include <iostream>
#include <sycl/sycl.hpp>
#include <vector>

#include "SYCLHelpers.h"
#include "Utils.h"
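
// Device-side printf helpers: format strings passed to the SYCL experimental
// printf extension must live in the OpenCL constant address space, so
// DPCPP_K_PRINT first materializes the literal as a DPCPP_CONSTANT array and
// then forwards it (plus any varargs) to the printf builtin.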
#define DPCPP_CONSTANT __attribute__((opencl_constant))

#define DPCPP_KER_STRING(var, str) static const DPCPP_CONSTANT char var[] = str;
#define DPCPP_KER_PRINTF sycl::ext::oneapi::experimental::printf

#define DPCPP_K_PRINT(fmt_str, ...)           \
  {                                           \
    DPCPP_KER_STRING(fmt_var, fmt_str);       \
    DPCPP_KER_PRINTF(fmt_var, ##__VA_ARGS__); \
  }
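
// Fixed-size, over-aligned array wrapper so a group of vec_size elements can
// be moved with a single vectorized load/store.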
template <typename scalar_t, int vec_size>
struct alignas(sizeof(scalar_t) * vec_size) aligned_vector_loop {
  scalar_t val[vec_size];

  scalar_t& operator[](int index) {
    return val[index];
  }

  scalar_t const& operator[](int index) const {
    return val[index];
  }
};
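
// SiLU(a) * b, computed in the higher-precision accumulation type:
// silu(a) = a * sigmoid(a) = a / (1 + exp(-a)).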
template <typename scalar_t, typename accscalar_t>
struct silu_mul_dpcpp_functor {
  scalar_t operator()(scalar_t a, scalar_t b) const {
    return (accscalar_t(a)) / (1.0f + expf(accscalar_t(-a))) * accscalar_t(b);
  }
};
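
// Tanh-approximation GELU(a) * b:
// gelu_tanh(a) = 0.5 * a * (1 + tanh(sqrt(2/pi) * (a + 0.044715 * a^3))),
// where kBeta = M_SQRT2 * M_2_SQRTPI * 0.5 = sqrt(2/pi).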
template <typename scalar_t, typename accscalar_t>
struct gelu_tanh_mul_dpcpp_functor {
  scalar_t operator()(scalar_t a, scalar_t b) const {
    const accscalar_t kBeta = M_SQRT2 * M_2_SQRTPI * accscalar_t(0.5);
    const accscalar_t kKappa = 0.044715;
    auto x_cube = accscalar_t(a) * accscalar_t(a) * accscalar_t(a);
    auto inner = kBeta * (accscalar_t(a) + kKappa * x_cube);
    return (accscalar_t(0.5) * accscalar_t(a) * (accscalar_t(1) + std::tanh(accscalar_t(inner)))) * accscalar_t(b);
  }
};
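
// Exact (erf-based) GELU(a) * b:
// gelu(a) = 0.5 * a * (1 + erf(a / sqrt(2))).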
template <typename scalar_t, typename accscalar_t>
struct gelu_erf_mul_dpcpp_functor {
  scalar_t operator()(scalar_t a, scalar_t b) const {
    return (accscalar_t(a) * accscalar_t(0.5) * (accscalar_t(1) + ::erf(accscalar_t(a) * accscalar_t(M_SQRT1_2)))) *
        accscalar_t(b);
  }
};
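
// Fused "activation and multiply" kernel. Each token's input row holds
// 2 * dim elements: the first dim values feed the unary functor and the last
// dim values are the multiplicative gate, so
//   output[token][i] = fn(input[token][i], input[token][i + dim]).
// One work-group handles one token; each work-item strides over the row in
// chunks of N elements, loaded and stored as aligned vectors.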
template <typename scalar_t, typename func_t, int N>
struct op_and_mul_functor {
  void operator()(sycl::nd_item<1> item) const {
    using accscalar_t = at::opmath_type<scalar_t>;
    int64_t offset = item.get_local_linear_id();
    int64_t step = item.get_local_range(0);
    int64_t token_id = item.get_group(0);
    func_t fn;
    int64_t bound = dim / N;
    for (int64_t i = offset; i < bound; i += step) {
      auto unary_val = reinterpret_cast<aligned_vector_loop<scalar_t, N>*>(input_ptr)[token_id * bound * 2 + i];
      auto mul_val = reinterpret_cast<aligned_vector_loop<scalar_t, N>*>(input_ptr)[token_id * bound * 2 + i + bound];
#pragma unroll
      for (int j = 0; j < N; ++j) {
        unary_val[j] = fn(unary_val[j], mul_val[j]);
      }
      reinterpret_cast<aligned_vector_loop<scalar_t, N>*>(output_ptr)[token_id * bound + i] = unary_val;
    }
  }

  scalar_t* input_ptr;
  scalar_t* output_ptr;
  int64_t num_;
  int64_t dim;
};
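
// Instantiates and submits the kernel for one compile-time vector width; used
// as the body of a switch over the runtime-selected vec_size.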
#define VEC_LAUNCH(KERNEL, N)                                                \
  case N: {                                                                  \
    op_and_mul_functor<T_to, KERNEL<T_to, accscalar_t>, N> kfn = {           \
        .input_ptr = _input, .output_ptr = _out, .num_ = numel, .dim = dim}; \
    sycl_kernel_submit(num_group * wg_size, wg_size, q, kfn);                \
    break;                                                                   \
  }
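
// Launch configuration: one work-group per token, work-group size capped by
// the device limit. The vector width starts at 16 bytes' worth of elements
// (e.g. 8 for half/bfloat16) and is halved while a narrower width would still
// let one work-group pass cover the whole row; if dim is not divisible by the
// chosen width, fall back to scalar accesses.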
template <typename T = float>
void get_config(
    const Tensor& input,
    const Tensor& out,
    int64_t& numel,
    int64_t& dim,
    int64_t& wg_size,
    int64_t& num_group,
    int& vec_size) {
  auto dev_id = dpcppGetDeviceIdOfCurrentQueue();
  int64_t max_wg_size = dpcppMaxWorkGroupSize(dev_id);
  numel = out.numel();
  dim = out.size(-1);
  int64_t tokens = numel / dim;
  wg_size = std::min(dim, max_wg_size);
  num_group = tokens;

  vec_size = sizeof(float) * 4 / sizeof(T);
  while ((vec_size >> 1) * wg_size >= dim) {
    vec_size = vec_size >> 1;
  }
  if (dim % vec_size != 0) vec_size = 1;
}
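
// Reinterprets the ATen storage as the matching SYCL element type
// (at::Half -> sycl::half, at::BFloat16 -> sycl::ext::oneapi::bfloat16) and
// dispatches on the vector width chosen by get_config.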
template <typename T_to = float, typename T_from = float>
void silu_and_mul_sycl(sycl::queue& q, Tensor& input, Tensor& out) {
  auto _input = reinterpret_cast<T_to*>(input.data_ptr<T_from>());
  auto _out = reinterpret_cast<T_to*>(out.data_ptr<T_from>());

  int64_t numel;
  int64_t dim;
  int64_t wg_size;
  int64_t num_group;
  int vec_size;
  get_config<T_to>(input, out, numel, dim, wg_size, num_group, vec_size);

  using accscalar_t = at::opmath_type<T_to>;
  switch (vec_size) {
    VEC_LAUNCH(silu_mul_dpcpp_functor, 1);
    VEC_LAUNCH(silu_mul_dpcpp_functor, 2);
    VEC_LAUNCH(silu_mul_dpcpp_functor, 4);
    VEC_LAUNCH(silu_mul_dpcpp_functor, 8);
    VEC_LAUNCH(silu_mul_dpcpp_functor, 16);
    default:
      TORCH_CHECK(false, "Unsupported vector size: ", vec_size);
  }
}
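
// Host entry point. Expects out shaped [..., d] and input shaped [..., 2 * d]
// with matching leading dimensions; only Half and BFloat16 are supported.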
void silu_and_mul(Tensor& out, Tensor& input) {
  input = input.contiguous();
  out = out.contiguous();

  auto stream = at::xpu::getCurrentXPUStream();
  auto queue = stream.queue();

  if (input.scalar_type() == at::ScalarType::Half) {
    silu_and_mul_sycl<sycl::half, at::Half>(queue, input, out);
  } else if (input.scalar_type() == at::ScalarType::BFloat16) {
    silu_and_mul_sycl<sycl::ext::oneapi::bfloat16, at::BFloat16>(queue, input, out);
  } else {
    TORCH_CHECK(false, "silu_and_mul: unsupported dtype ", input.scalar_type());
  }
}
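
// Same launch path as silu_and_mul_sycl, selecting the tanh-approximation
// GELU functor.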
template <typename T_to = float, typename T_from = float>
void gelu_tanh_and_mul_sycl(sycl::queue& q, Tensor& input, Tensor& out) {
  auto _input = reinterpret_cast<T_to*>(input.data_ptr<T_from>());
  auto _out = reinterpret_cast<T_to*>(out.data_ptr<T_from>());

  int64_t numel;
  int64_t dim;
  int64_t wg_size;
  int64_t num_group;
  int vec_size;
  get_config<T_to>(input, out, numel, dim, wg_size, num_group, vec_size);

  using accscalar_t = at::opmath_type<T_to>;
  switch (vec_size) {
    VEC_LAUNCH(gelu_tanh_mul_dpcpp_functor, 1);
    VEC_LAUNCH(gelu_tanh_mul_dpcpp_functor, 2);
    VEC_LAUNCH(gelu_tanh_mul_dpcpp_functor, 4);
    VEC_LAUNCH(gelu_tanh_mul_dpcpp_functor, 8);
    VEC_LAUNCH(gelu_tanh_mul_dpcpp_functor, 16);
    default:
      TORCH_CHECK(false, "Unsupported vector size: ", vec_size);
  }
}
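
// Host entry point for the tanh-approximation GELU variant; same shape and
// dtype contract as silu_and_mul.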
void gelu_tanh_and_mul(Tensor& out, Tensor& input) {
  input = input.contiguous();
  out = out.contiguous();

  auto stream = at::xpu::getCurrentXPUStream();
  auto queue = stream.queue();

  if (input.scalar_type() == at::ScalarType::Half) {
    gelu_tanh_and_mul_sycl<sycl::half, at::Half>(queue, input, out);
  } else if (input.scalar_type() == at::ScalarType::BFloat16) {
    gelu_tanh_and_mul_sycl<sycl::ext::oneapi::bfloat16, at::BFloat16>(queue, input, out);
  } else {
    TORCH_CHECK(false, "gelu_tanh_and_mul: unsupported dtype ", input.scalar_type());
  }
}
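
// Same launch path again, selecting the exact (erf-based) GELU functor.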
template <typename T_to = float, typename T_from = float>
void gelu_and_mul_sycl(sycl::queue& q, Tensor& input, Tensor& out) {
  auto _input = reinterpret_cast<T_to*>(input.data_ptr<T_from>());
  auto _out = reinterpret_cast<T_to*>(out.data_ptr<T_from>());

  int64_t numel;
  int64_t dim;
  int64_t wg_size;
  int64_t num_group;
  int vec_size;
  get_config<T_to>(input, out, numel, dim, wg_size, num_group, vec_size);

  using accscalar_t = at::opmath_type<T_to>;
  switch (vec_size) {
    VEC_LAUNCH(gelu_erf_mul_dpcpp_functor, 1);
    VEC_LAUNCH(gelu_erf_mul_dpcpp_functor, 2);
    VEC_LAUNCH(gelu_erf_mul_dpcpp_functor, 4);
    VEC_LAUNCH(gelu_erf_mul_dpcpp_functor, 8);
    VEC_LAUNCH(gelu_erf_mul_dpcpp_functor, 16);
    default:
      TORCH_CHECK(false, "Unsupported vector size: ", vec_size);
  }
}
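
// Host entry point for the exact GELU variant; same shape and dtype contract
// as silu_and_mul.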
void gelu_and_mul(Tensor& out, Tensor& input) {
  input = input.contiguous();
  out = out.contiguous();

  auto stream = at::xpu::getCurrentXPUStream();
  auto queue = stream.queue();

  if (input.scalar_type() == at::ScalarType::Half) {
    gelu_and_mul_sycl<sycl::half, at::Half>(queue, input, out);
  } else if (input.scalar_type() == at::ScalarType::BFloat16) {
    gelu_and_mul_sycl<sycl::ext::oneapi::bfloat16, at::BFloat16>(queue, input, out);
  } else {
    TORCH_CHECK(false, "gelu_and_mul: unsupported dtype ", input.scalar_type());
  }
}
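
// A minimal usage sketch (not part of the original file): one plausible way to
// expose these entry points to Python through the PyTorch dispatcher. The
// library name "activation_ops" is an assumption for illustration; the real
// registration for this project may live elsewhere.
//
// TORCH_LIBRARY_FRAGMENT(activation_ops, m) {
//   m.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
//   m.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
//   m.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
//   m.impl("silu_and_mul", torch::kXPU, &silu_and_mul);
//   m.impl("gelu_tanh_and_mul", torch::kXPU, &gelu_tanh_and_mul);
//   m.impl("gelu_and_mul", torch::kXPU, &gelu_and_mul);
// }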