
Commit c4db4f8

[SWISH] support Swish F32/F16 kernel (#85)
* [SWISH][FP16] first commit, add FP16, FP32 and fp16x8_pack kernels.
* [SWISH][FP16] add README.md.
* Update swish.cu
* Update README.md
* Update README.md

Co-authored-by: DefTruth <[email protected]>
1 parent a83ff8d commit c4db4f8

File tree

5 files changed: +394 −0 lines changed

README.md

Lines changed: 6 additions & 0 deletions
@@ -48,6 +48,12 @@
  | ✔️ [gelu_f16x2](./gelu/gelu.cu)|f16|/|[link](./gelu/)|⭐️|
  | ✔️ [gelu_f16x8](./gelu/gelu.cu)|f16|/|[link](./gelu/)|⭐️|
  | ✔️ [gelu_f16x8_pack](./gelu/gelu.cu)|f16|/|[link](./gelu/)|⭐️⭐️|
+ | ✔️ [swish_f32](./swish/swish.cu)|f32|/|[link](./swish/)|⭐️|
+ | ✔️ [swish_f32x4](./swish/swish.cu)|f32|/|[link](./swish/)|⭐️|
+ | ✔️ [swish_f16](./swish/swish.cu)|f16|/|[link](./swish/)|⭐️|
+ | ✔️ [swish_f16x2](./swish/swish.cu)|f16|/|[link](./swish/)|⭐️|
+ | ✔️ [swish_f16x8](./swish/swish.cu)|f16|/|[link](./swish/)|⭐️|
+ | ✔️ [swish_f16x8_pack](./swish/swish.cu)|f16|/|[link](./swish/)|⭐️⭐️|
  | ✔️ [embedding_f32](./embedding/embedding.cu)|f32|/|[link](./embedding/)|⭐️|
  | ✔️ [embedding_f32x4](./embedding/embedding.cu)|f32|/|[link](./embedding/)|⭐️|
  | ✔️ [embedding_f32x4_pack](./embedding/embedding.cu)|f32|/|[link](./embedding/)|⭐️|

swish/.gitignore

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+ *.so
+ *.a
+ *.dylib
+ *.dll
+ *.lib
+ .DS_Store
+ build
+ *.whl
+ tmp
+

swish/README.md

Lines changed: 136 additions & 0 deletions
@@ -0,0 +1,136 @@
# Swish

## 0x00 Overview

This directory contains:

- [X] swish_f32_kernel
- [X] swish_f32x4_kernel (float4 vectorized version)
- [X] swish_f16_kernel (fp16 version)
- [X] swish_f16x2_kernel (fp16 vectorized version)
- [X] swish_f16x8_kernel (fp16 vectorized version)
- [X] swish_f16x8_pack_kernel (fp16 vectorized, packed version)
- [X] PyTorch bindings

## Test

```bash
# Build only for the Ada architecture; without this flag all architectures are compiled (Volta, Ampere, Ada, Hopper, ...), which takes much longer.
export TORCH_CUDA_ARCH_LIST=Ada
python3 swish.py
```

Output:

```bash
-------------------------------------------------------------------------------------
S=1024, K=1024
out_f32: ['0.46177661 ', '-0.10888041 '], time:0.01246500ms
out_f32x4: ['0.46177661 ', '-0.10888041 '], time:0.01006508ms
out_f32_th: ['0.46177667 ', '-0.10888041 '], time:0.03012419ms
-------------------------------------------------------------------------------------
out_f16: ['0.46191406 ', '-0.10894775 '], time:0.01299334ms
out_f16x2: ['0.46191406 ', '-0.10894775 '], time:0.01036119ms
out_f16x8: ['0.46191406 ', '-0.10894775 '], time:0.00979590ms
out_f16x8pack: ['0.46191406 ', '-0.10894775 '], time:0.00972557ms
out_f16_th: ['0.46191406 ', '-0.10888672 '], time:0.02423882ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=1024, K=2048
out_f32: ['-0.27797085 ', '0.71514565 '], time:0.01415992ms
out_f32x4: ['-0.27797085 ', '0.71514565 '], time:0.01159716ms
out_f32_th: ['-0.27797085 ', '0.71514559 '], time:0.02964258ms
-------------------------------------------------------------------------------------
out_f16: ['-0.27807617 ', '0.71582031 '], time:0.01473880ms
out_f16x2: ['-0.27807617 ', '0.71582031 '], time:0.01404881ms
out_f16x8: ['-0.27807617 ', '0.71582031 '], time:0.01127148ms
out_f16x8pack: ['-0.27807617 ', '0.71582031 '], time:0.01101518ms
out_f16_th: ['-0.27807617 ', '0.71533203 '], time:0.02657008ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=1024, K=4096
out_f32: ['0.29988611 ', '-0.2541697 '], time:0.01959276ms
out_f32x4: ['0.29988611 ', '-0.2541697 '], time:0.01605868ms
out_f32_th: ['0.29988611 ', '-0.25416973 '], time:0.03745818ms
-------------------------------------------------------------------------------------
out_f16: ['0.30004883 ', '-0.25415039 '], time:0.02078271ms
out_f16x2: ['0.30004883 ', '-0.25415039 '], time:0.01729155ms
out_f16x8: ['0.30004883 ', '-0.25415039 '], time:0.01489425ms
out_f16x8pack: ['0.30004883 ', '-0.25415039 '], time:0.01351643ms
out_f16_th: ['0.29980469 ', '-0.25415039 '], time:0.03149080ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=2048, K=1024
out_f32: ['-0.07777861 ', '-0.27842814 '], time:0.01640201ms
out_f32x4: ['-0.07777861 ', '-0.27842814 '], time:0.01180029ms
out_f32_th: ['-0.07777861 ', '-0.27842814 '], time:0.02952218ms
-------------------------------------------------------------------------------------
out_f16: ['-0.07775879 ', '-0.27856445 '], time:0.01758027ms
out_f16x2: ['-0.07775879 ', '-0.27856445 '], time:0.01236153ms
out_f16x8: ['-0.07775879 ', '-0.27856445 '], time:0.01109338ms
out_f16x8pack: ['-0.07775879 ', '-0.27856445 '], time:0.01091790ms
out_f16_th: ['-0.07775879 ', '-0.27856445 '], time:0.02657914ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=2048, K=2048
out_f32: ['-0.14754841 ', '-0.21989606 '], time:0.01957679ms
out_f32x4: ['-0.14754841 ', '-0.21989606 '], time:0.01496792ms
out_f32_th: ['-0.14754841 ', '-0.21989603 '], time:0.03751612ms
-------------------------------------------------------------------------------------
out_f16: ['-0.14758301 ', '-0.21984863 '], time:0.02085924ms
out_f16x2: ['-0.14758301 ', '-0.21984863 '], time:0.01961517ms
out_f16x8: ['-0.14758301 ', '-0.21984863 '], time:0.01386237ms
out_f16x8pack: ['-0.14758301 ', '-0.21984863 '], time:0.01334929ms
out_f16_th: ['-0.14758301 ', '-0.21984863 '], time:0.03151488ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=2048, K=4096
out_f32: ['1.07876182 ', '-0.27844051 '], time:0.03036070ms
out_f32x4: ['1.07876182 ', '-0.27844051 '], time:0.02339220ms
out_f32_th: ['1.07876182 ', '-0.27844048 '], time:0.05310464ms
-------------------------------------------------------------------------------------
out_f16: ['1.078125 ', '-0.27832031 '], time:0.03291988ms
out_f16x2: ['1.078125 ', '-0.27832031 '], time:0.02590466ms
out_f16x8: ['1.078125 ', '-0.27832031 '], time:0.02027988ms
out_f16x8pack: ['1.078125 ', '-0.27832031 '], time:0.01811814ms
out_f16_th: ['1.07910156 ', '-0.27832031 '], time:0.04083204ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=4096, K=1024
out_f32: ['0.31169948 ', '-0.18232882 '], time:0.02427077ms
out_f32x4: ['0.31169948 ', '-0.18232882 '], time:0.01515222ms
out_f32_th: ['0.31169948 ', '-0.18232881 '], time:0.03754425ms
-------------------------------------------------------------------------------------
out_f16: ['0.31152344 ', '-0.18237305 '], time:0.02679300ms
out_f16x2: ['0.31152344 ', '-0.18237305 '], time:0.01617312ms
out_f16x8: ['0.31152344 ', '-0.18237305 '], time:0.01357770ms
out_f16x8pack: ['0.31152344 ', '-0.18237305 '], time:0.01324248ms
out_f16_th: ['0.31152344 ', '-0.18225098 '], time:0.03149295ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=4096, K=2048
out_f32: ['1.5033319 ', '0.17473438 '], time:0.03030729ms
out_f32x4: ['1.5033319 ', '0.17473438 '], time:0.02150083ms
out_f32_th: ['1.5033319 ', '0.17473438 '], time:0.05257607ms
-------------------------------------------------------------------------------------
out_f16: ['1.50390625 ', '0.17468262 '], time:0.03289509ms
out_f16x2: ['1.50390625 ', '0.17468262 '], time:0.03073120ms
out_f16x8: ['1.50390625 ', '0.17468262 '], time:0.01862860ms
out_f16x8pack: ['1.50390625 ', '0.17468262 '], time:0.01772857ms
out_f16_th: ['1.50390625 ', '0.17468262 '], time:0.04082441ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=4096, K=4096
out_f32: ['-0.05288643 ', '-0.14218464 '], time:0.19254756ms
out_f32x4: ['-0.05288643 ', '-0.14218464 '], time:0.19258785ms
out_f32_th: ['-0.05288643 ', '-0.14218464 '], time:0.48660636ms
-------------------------------------------------------------------------------------
out_f16: ['-0.052948 ', '-0.14221191 '], time:0.05689216ms
out_f16x2: ['-0.052948 ', '-0.14221191 '], time:0.04335928ms
out_f16x8: ['-0.052948 ', '-0.14221191 '], time:0.03096652ms
out_f16x8pack: ['-0.052948 ', '-0.14221191 '], time:0.02706647ms
out_f16_th: ['-0.05288696 ', '-0.14221191 '], time:0.05971408ms
-------------------------------------------------------------------------------------
```
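
The benchmark driver `swish.py` is the fifth file in this commit, but its contents are not shown in the diff. For orientation only, here is a minimal sketch of how such a script is typically wired up: it JIT-builds `swish.cu` with `torch.utils.cpp_extension.load` and checks one binding against the PyTorch reference `x * sigmoid(x)`. The module name, compiler flags, and tolerance below are assumptions, not the actual contents of `swish.py`.

```python
# Hypothetical sketch, not the swish.py from this commit: build the extension
# and compare swish_f32 against the PyTorch reference x * sigmoid(x) (SiLU).
import torch
from torch.utils.cpp_extension import load

lib = load(name="swish_lib", sources=["swish.cu"], extra_cuda_cflags=["-O3"])

S, K = 1024, 1024
x = torch.randn((S, K), device="cuda", dtype=torch.float32).contiguous()
y = torch.zeros_like(x)

lib.swish_f32(x, y)              # custom kernel, writes into the preallocated y
y_ref = x * torch.sigmoid(x)     # reference Swish/SiLU

print(torch.allclose(y, y_ref, atol=1e-6))
```

The bound functions take the output tensor as a second argument (see `TORCH_BINDING_SWISH` below), so the caller allocates `y` up front rather than receiving a new tensor.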

swish/swish.cu

Lines changed: 157 additions & 0 deletions
@@ -0,0 +1,157 @@
#include <stdio.h>
#include <stdlib.h>
#include <float.h>
#include <vector>
#include <algorithm>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <torch/types.h>
#include <torch/extension.h>

#define WARP_SIZE 32
#define INT4(value) (reinterpret_cast<int4*>(&(value))[0])
#define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
#define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])
// -------------------------------------- FP32 --------------------------------------
// Swish  x: N, y: N,  y = x * sigmoid(x)
__device__ __forceinline__ float swish(float x) {
  return x / (1.0f + expf(-x));
}

__global__ void swish_f32_kernel(float* x, float* y, int N) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < N) y[idx] = swish(x[idx]);
}

__global__ void swish_f32x4_kernel(float* x, float* y, int N) {
  int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4;
  if (idx < N) {
    float4 reg_x = FLOAT4(x[idx]);
    float4 reg_y;
    reg_y.x = swish(reg_x.x);
    reg_y.y = swish(reg_x.y);
    reg_y.z = swish(reg_x.z);
    reg_y.w = swish(reg_x.w);
    FLOAT4(y[idx]) = reg_y;
  }
}
// -------------------------------------- FP16 --------------------------------------
__device__ __forceinline__ half swish_half(half x) {
  return __hmul(x, __hdiv(
    __float2half(1.0f), __hadd(__float2half(1.0f), hexp(__hneg(x)))));
}

__global__ void swish_f16_kernel(half* x, half* y, int N) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < N) y[idx] = swish_half(x[idx]);
}

__global__ void swish_f16x2_kernel(half* x, half* y, int N) {
  int idx = 2 * (blockIdx.x * blockDim.x + threadIdx.x);
  if (idx < N) {
    half2 reg_x = HALF2(x[idx]);
    half2 reg_y;
    reg_y.x = swish_half(reg_x.x);
    reg_y.y = swish_half(reg_x.y);
    HALF2(y[idx]) = reg_y;
  }
}

// Each thread handles 8 halves via four half2 loads; the loads are not
// bounds-checked, only the stores are guarded.
__global__ void swish_f16x8_kernel(half* x, half* y, int N) {
  int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
  half2 reg_x_0 = HALF2(x[idx + 0]);
  half2 reg_x_1 = HALF2(x[idx + 2]);
  half2 reg_x_2 = HALF2(x[idx + 4]);
  half2 reg_x_3 = HALF2(x[idx + 6]);
  half2 reg_y_0, reg_y_1, reg_y_2, reg_y_3;
  reg_y_0.x = swish_half(reg_x_0.x);
  reg_y_0.y = swish_half(reg_x_0.y);
  reg_y_1.x = swish_half(reg_x_1.x);
  reg_y_1.y = swish_half(reg_x_1.y);
  reg_y_2.x = swish_half(reg_x_2.x);
  reg_y_2.y = swish_half(reg_x_2.y);
  reg_y_3.x = swish_half(reg_x_3.x);
  reg_y_3.y = swish_half(reg_x_3.y);
  if ((idx + 0) < N) { HALF2(y[idx + 0]) = reg_y_0; }
  if ((idx + 2) < N) { HALF2(y[idx + 2]) = reg_y_1; }
  if ((idx + 4) < N) { HALF2(y[idx + 4]) = reg_y_2; }
  if ((idx + 6) < N) { HALF2(y[idx + 6]) = reg_y_3; }
}

// Packed variant: one 128-bit load and one 128-bit store per thread (8 halves).
__global__ void swish_f16x8_pack_kernel(half* x, half* y, int N) {
  int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
  half pack_x[8], pack_y[8];
  LDST128BITS(pack_x[0]) = LDST128BITS(x[idx]);

#pragma unroll
  for (int i = 0; i < 8; i++) {
    pack_y[i] = swish_half(pack_x[i]);
  }
  if ((idx + 7) < N) { LDST128BITS(y[idx]) = LDST128BITS(pack_y[0]); }
}
// --------------------- PyTorch bindings for custom kernel -----------------------
#define STRINGFY(str) #str
#define TORCH_BINDING_COMMON_EXTENSION(func) \
  m.def(STRINGFY(func), &func, STRINGFY(func));

#define CHECK_TORCH_TENSOR_DTYPE(T, th_type)                   \
  if (((T).options().dtype() != (th_type))) {                  \
    std::cout << "Tensor Info:" << (T).options() << std::endl; \
    throw std::runtime_error("values must be " #th_type);      \
  }

#define TORCH_BINDING_SWISH(packed_type, th_type, element_type, n_elements)   \
  void swish_##packed_type(torch::Tensor x, torch::Tensor y) {                \
    CHECK_TORCH_TENSOR_DTYPE(x, (th_type))                                    \
    CHECK_TORCH_TENSOR_DTYPE(y, (th_type))                                    \
    const int ndim = x.dim();                                                 \
    if (ndim != 2) {                                                          \
      int N = 1;                                                              \
      for (int i = 0; i < ndim; ++i) { N *= x.size(i); }                      \
      dim3 block(256 / (n_elements));                                         \
      dim3 grid((N + 256 - 1) / 256);                                         \
      swish_##packed_type##_kernel<<<grid, block>>>(                          \
        reinterpret_cast<element_type*>(x.data_ptr()),                        \
        reinterpret_cast<element_type*>(y.data_ptr()), N);                    \
    } else {                                                                  \
      const int S = x.size(0);                                                \
      const int K = x.size(1);                                                \
      const int N = S * K;                                                    \
      if ((K / (n_elements)) <= 1024) {                                       \
        dim3 block(K / (n_elements));                                         \
        dim3 grid(S);                                                         \
        swish_##packed_type##_kernel<<<grid, block>>>(                        \
          reinterpret_cast<element_type*>(x.data_ptr()),                      \
          reinterpret_cast<element_type*>(y.data_ptr()), N);                  \
      } else {                                                                \
        int N = 1;                                                            \
        for (int i = 0; i < ndim; ++i) { N *= x.size(i); }                    \
        dim3 block(256 / (n_elements));                                       \
        dim3 grid((N + 256 - 1) / 256);                                       \
        swish_##packed_type##_kernel<<<grid, block>>>(                        \
          reinterpret_cast<element_type*>(x.data_ptr()),                      \
          reinterpret_cast<element_type*>(y.data_ptr()), N);                  \
      }                                                                       \
    }                                                                         \
  }

TORCH_BINDING_SWISH(f32,        torch::kFloat32, float, 1)
TORCH_BINDING_SWISH(f32x4,      torch::kFloat32, float, 4)
TORCH_BINDING_SWISH(f16,        torch::kHalf,    half,  1)
TORCH_BINDING_SWISH(f16x2,      torch::kHalf,    half,  2)
TORCH_BINDING_SWISH(f16x8,      torch::kHalf,    half,  8)
TORCH_BINDING_SWISH(f16x8_pack, torch::kHalf,    half,  8)

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  TORCH_BINDING_COMMON_EXTENSION(swish_f32)
  TORCH_BINDING_COMMON_EXTENSION(swish_f32x4)
  TORCH_BINDING_COMMON_EXTENSION(swish_f16)
  TORCH_BINDING_COMMON_EXTENSION(swish_f16x2)
  TORCH_BINDING_COMMON_EXTENSION(swish_f16x8)
  TORCH_BINDING_COMMON_EXTENSION(swish_f16x8_pack)
}
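
One part of `swish.cu` that is easy to gloss over is the launch configuration chosen inside `TORCH_BINDING_SWISH`: for a 2-D input it launches one block per row whenever a row, divided by the per-thread element count, fits within 1024 threads; otherwise, and for non-2-D inputs, it flattens the tensor and covers it with 256-element blocks. The snippet below merely restates that arithmetic in Python for illustration; it is not part of the commit.

```python
# Illustration only: mirror the grid/block choice made by TORCH_BINDING_SWISH.
def swish_launch_dims(shape, n_elements):
    """Return (grid, block) the way the binding macro selects them."""
    if len(shape) == 2:
        S, K = shape
        if K // n_elements <= 1024:
            return S, K // n_elements        # one block per row of the (S, K) input
    # Fallback (non-2-D input, or rows too wide): flatten to N elements and
    # cover them with 256-element blocks, each thread handling n_elements.
    N = 1
    for d in shape:
        N *= d
    return (N + 256 - 1) // 256, 256 // n_elements

# e.g. the S=4096, K=4096 run with the f16x8_pack kernel (n_elements = 8):
print(swish_launch_dims((4096, 4096), 8))    # -> (4096, 512)
```

For that case the binding launches 4096 blocks of 512 threads, and each thread processes 8 half values, which covers exactly the 4096×4096 elements.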

0 commit comments
