#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <c10/xpu/XPUStream.h>
#include <torch/all.h>
#include <ATen/OpMathType.h>

#include <cmath>
#include <cstdint>
#include <iostream>
#include <sycl/sycl.hpp>
#include <vector>
#include "Utils.h"

#include "SYCLHelpers.h"

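// Device-side printf helpers: DPCPP_K_PRINT places the format string in the
// constant address space and forwards to the SYCL experimental printf.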
#define DPCPP_CONSTANT __attribute__((opencl_constant))

#define DPCPP_KER_STRING(var, str) static const DPCPP_CONSTANT char var[] = str;
#define DPCPP_KER_PRINTF sycl::ext::oneapi::experimental::printf

#define DPCPP_K_PRINT(fmt_str, ...)             \
  {                                             \
    DPCPP_KER_STRING(fmt_var, fmt_str);         \
    DPCPP_KER_PRINTF(fmt_var, ##__VA_ARGS__);   \
  }

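// Fixed-size, over-aligned array wrapper used to issue vectorized loads and
// stores of `vec_size` elements at a time.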
template <typename scalar_t, int vec_size>
struct alignas(sizeof(scalar_t) * vec_size) aligned_vector_loop {
  scalar_t val[vec_size];

  scalar_t& operator[](int index) {
    return val[index];
  }

  scalar_t const& operator[](int index) const {
    return val[index];
  }
};

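// SiLU-and-multiply: SiLU(a) * b, where SiLU(a) = a * sigmoid(a), evaluated in
// the higher-precision accumulator type.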
template <typename scalar_t, typename accscalar_t>
struct silu_mul_dpcpp_functor {
  scalar_t operator()(scalar_t a, scalar_t b) const {
    return (accscalar_t(a)) / (1.0f + expf(accscalar_t(-a))) * accscalar_t(b);
  }
};

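// Tanh-approximation GELU-and-multiply: GELU(a) * b with the usual constants
// kBeta = sqrt(2 / pi) and kKappa = 0.044715.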
template <typename scalar_t, typename accscalar_t>
struct gelu_tanh_mul_dpcpp_functor {
  scalar_t operator()(scalar_t a, scalar_t b) const {
    const accscalar_t kBeta = M_SQRT2 * M_2_SQRTPI * accscalar_t(0.5);
    const accscalar_t kKappa = 0.044715;
    auto x_cube = accscalar_t(a) * accscalar_t(a) * accscalar_t(a);
    auto inner = kBeta * (accscalar_t(a) + kKappa * x_cube);
    return (accscalar_t(0.5) * accscalar_t(a) *
            (accscalar_t(1) + std::tanh(accscalar_t(inner)))) *
        accscalar_t(b);
  }
};

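// Exact (erf-based) GELU-and-multiply: GELU(a) * b.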
template <typename scalar_t, typename accscalar_t>
struct gelu_erf_mul_dpcpp_functor {
  scalar_t operator()(scalar_t a, scalar_t b) const {
    return (accscalar_t(a) * accscalar_t(0.5) *
            (accscalar_t(1) + ::erf(accscalar_t(a) * accscalar_t(M_SQRT1_2)))) *
        accscalar_t(b);
  }
};

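// Fused "activation and multiply" kernel. Each work-group handles one token
// (row). An input row holds 2 * dim elements: the first half is passed
// through `func_t` and multiplied element-wise by the second half, producing
// an output row of `dim` elements. Elements are processed N at a time via
// aligned_vector_loop.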
template <typename scalar_t, typename func_t, int N>
struct op_and_mul_functor {
  void operator()(sycl::nd_item<1> item) const {
    using accscalar_t = at::opmath_type<scalar_t>;
    int64_t offset = item.get_local_linear_id();
    int64_t step = item.get_local_range(0);
    int64_t token_id = item.get_group(0);
    func_t fn;
    int64_t bound = dim / N;
    for (int64_t i = offset; i < bound; i += step) {
      // Vectorized loads: activation operand from the first half of the row,
      // multiplier from the second half.
      auto unary_val = reinterpret_cast<aligned_vector_loop<scalar_t, N>*>(
          input_ptr)[token_id * bound * 2 + i];
      auto mul_val = reinterpret_cast<aligned_vector_loop<scalar_t, N>*>(
          input_ptr)[token_id * bound * 2 + i + bound];
#pragma unroll
      for (int j = 0; j < N; ++j) {
        unary_val[j] = fn(unary_val[j], mul_val[j]);
      }
      reinterpret_cast<aligned_vector_loop<scalar_t, N>*>(
          output_ptr)[token_id * bound + i] = unary_val;
    }
  }

  scalar_t* input_ptr;
  scalar_t* output_ptr;
  int64_t num_;
  int64_t dim;
};

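// Instantiates and submits op_and_mul_functor for a given vector width N.
// Relies on _input, _out, numel, dim, num_group, wg_size and q being in scope
// at the expansion site.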
#define VEC_LAUNCH(KERNEL, N)                                       \
  case N: {                                                         \
    op_and_mul_functor<T_to, KERNEL<T_to, accscalar_t>, N> kfn = {  \
        .input_ptr = _input,                                        \
        .output_ptr = _out,                                         \
        .num_ = numel,                                              \
        .dim = dim};                                                \
    sycl_kernel_submit(num_group * wg_size, wg_size, q, kfn);       \
    break;                                                          \
  }

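// Computes the launch configuration: numel and dim come from the output
// tensor, one work-group is launched per token, and wg_size is capped at the
// device maximum. vec_size starts at a 16-byte vector (sizeof(float) * 4
// bytes) and is halved while a narrower vector still lets one work-group pass
// cover the row; it falls back to 1 when dim is not divisible by the width.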
template <typename T = float>
void get_config(
    const Tensor& input,
    const Tensor& out,
    int64_t& numel,
    int64_t& dim,
    int64_t& wg_size,
    int64_t& num_group,
    int& vec_size) {
  auto dev_id = torch_ipex::xpu::dpcpp::dpcppGetDeviceIdOfCurrentQueue();
  int64_t max_wg_size = torch_ipex::xpu::dpcpp::dpcppMaxWorkGroupSize(dev_id);
  numel = out.numel();
  dim = out.size(-1);
  int64_t tokens = numel / dim;
  wg_size = std::min(dim, max_wg_size);
  num_group = tokens;

  vec_size = sizeof(float) * 4 / sizeof(T);
  while ((vec_size >> 1) * wg_size >= dim) {
    vec_size = vec_size >> 1;
  }
  if (dim % vec_size != 0)
    vec_size = 1;
}

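// Type-dispatched launcher: reinterprets the ATen storage (at::Half /
// at::BFloat16) as the matching SYCL element type, computes the launch
// configuration, and selects the kernel instantiation for the chosen vector
// width.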
template <typename T_to = float, typename T_from = float>
void silu_and_mul_sycl(
    sycl::queue& q,
    Tensor& input,
    Tensor& out) {
  auto _input = reinterpret_cast<T_to*>(input.data_ptr<T_from>());
  auto _out = reinterpret_cast<T_to*>(out.data_ptr<T_from>());

  int64_t numel;
  int64_t dim;
  int64_t wg_size;
  int64_t num_group;
  int vec_size;
  get_config<T_to>(input, out, numel, dim, wg_size, num_group, vec_size);

  using accscalar_t = at::opmath_type<T_to>;
  switch (vec_size) {
    VEC_LAUNCH(silu_mul_dpcpp_functor, 1);
    VEC_LAUNCH(silu_mul_dpcpp_functor, 2);
    VEC_LAUNCH(silu_mul_dpcpp_functor, 4);
    VEC_LAUNCH(silu_mul_dpcpp_functor, 8);
    VEC_LAUNCH(silu_mul_dpcpp_functor, 16);
    default:
      TORCH_CHECK(false, "Unsupported vector size: ", vec_size);
  }

  return;
}

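// Host entry point: forces contiguous layouts, grabs the current XPU stream's
// queue, and dispatches on dtype (fp16 vs. bf16). Given the kernel's indexing,
// input is expected to have shape [..., 2 * d] and out shape [..., d], where
// the first half of each input row is the SiLU operand and the second half the
// multiplier. A hypothetical call site (assuming this file is built into a
// custom XPU op library):
//   silu_and_mul(out, input);  // both tensors on the same XPU device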
void silu_and_mul(Tensor& out, Tensor& input) {
  input = input.contiguous();
  out = out.contiguous();

  auto stream = at::xpu::getCurrentXPUStream();
  auto queue = stream.queue();

  if (input.scalar_type() == at::ScalarType::Half) {
    silu_and_mul_sycl<sycl::half, at::Half>(queue, input, out);
  } else {
    silu_and_mul_sycl<sycl::ext::oneapi::bfloat16, at::BFloat16>(
        queue, input, out);
  }
  return;
}

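// Tanh-approximation GELU variant; same launch path as silu_and_mul_sycl.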
template <typename T_to = float, typename T_from = float>
void gelu_tanh_and_mul_sycl(
    sycl::queue& q,
    Tensor& input,
    Tensor& out) {
  auto _input = reinterpret_cast<T_to*>(input.data_ptr<T_from>());
  auto _out = reinterpret_cast<T_to*>(out.data_ptr<T_from>());

  int64_t numel;
  int64_t dim;
  int64_t wg_size;
  int64_t num_group;
  int vec_size;
  get_config<T_to>(input, out, numel, dim, wg_size, num_group, vec_size);

  using accscalar_t = at::opmath_type<T_to>;
  switch (vec_size) {
    VEC_LAUNCH(gelu_tanh_mul_dpcpp_functor, 1);
    VEC_LAUNCH(gelu_tanh_mul_dpcpp_functor, 2);
    VEC_LAUNCH(gelu_tanh_mul_dpcpp_functor, 4);
    VEC_LAUNCH(gelu_tanh_mul_dpcpp_functor, 8);
    VEC_LAUNCH(gelu_tanh_mul_dpcpp_functor, 16);
    default:
      TORCH_CHECK(false, "Unsupported vector size: ", vec_size);
  }

  return;
}

void gelu_tanh_and_mul(Tensor& out, Tensor& input) {
  input = input.contiguous();
  out = out.contiguous();

  auto stream = at::xpu::getCurrentXPUStream();
  auto queue = stream.queue();

  if (input.scalar_type() == at::ScalarType::Half) {
    gelu_tanh_and_mul_sycl<sycl::half, at::Half>(queue, input, out);
  } else {
    gelu_tanh_and_mul_sycl<sycl::ext::oneapi::bfloat16, at::BFloat16>(
        queue, input, out);
  }
  return;
}

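// Exact (erf-based) GELU variant; same launch path as silu_and_mul_sycl.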
template <typename T_to = float, typename T_from = float>
void gelu_and_mul_sycl(
    sycl::queue& q,
    Tensor& input,
    Tensor& out) {
  auto _input = reinterpret_cast<T_to*>(input.data_ptr<T_from>());
  auto _out = reinterpret_cast<T_to*>(out.data_ptr<T_from>());

  int64_t numel;
  int64_t dim;
  int64_t wg_size;
  int64_t num_group;
  int vec_size;
  get_config<T_to>(input, out, numel, dim, wg_size, num_group, vec_size);

  using accscalar_t = at::opmath_type<T_to>;
  switch (vec_size) {
    VEC_LAUNCH(gelu_erf_mul_dpcpp_functor, 1);
    VEC_LAUNCH(gelu_erf_mul_dpcpp_functor, 2);
    VEC_LAUNCH(gelu_erf_mul_dpcpp_functor, 4);
    VEC_LAUNCH(gelu_erf_mul_dpcpp_functor, 8);
    VEC_LAUNCH(gelu_erf_mul_dpcpp_functor, 16);
    default:
      TORCH_CHECK(false, "Unsupported vector size: ", vec_size);
  }

  return;
}

void gelu_and_mul(Tensor& out, Tensor& input) {
  input = input.contiguous();
  out = out.contiguous();

  auto stream = at::xpu::getCurrentXPUStream();
  auto queue = stream.queue();

  if (input.scalar_type() == at::ScalarType::Half) {
    gelu_and_mul_sycl<sycl::half, at::Half>(queue, input, out);
  } else {
    gelu_and_mul_sycl<sycl::ext::oneapi::bfloat16, at::BFloat16>(
        queue, input, out);
  }
  return;
}