
Commit 1eae888

bear-zd and DefTruth authored

[GELU] Add f32/x4, f16/x2/x8/x8 pack kernel (#66)

* [GELU] Add f32/x4, f16/x2/x8/x8 pack kernel.
* Update README.md
* Update gelu.cu
* Update gelu.py
* Update README.md

Co-authored-by: DefTruth <[email protected]>

1 parent 3b56750 commit 1eae888

File tree

5 files changed: +482 −0 lines changed

README.md

Lines changed: 6 additions & 0 deletions

@@ -42,6 +42,12 @@
  | ✔️ [relu_f16x2](./relu/relu.cu)|f16|/|[link](./relu/)|⭐️|
  | ✔️ [relu_f16x8](./relu/relu.cu)|f16|/|[link](./relu/)|⭐️|
  | ✔️ [relu_f16x8_pack](./relu/relu.cu)|f16|/|[link](./relu/)|⭐️⭐️|
+ | ✔️ [gelu_f32](./gelu/gelu.cu)|f32|/|[link](./gelu/)|⭐️|
+ | ✔️ [gelu_f32x4](./gelu/gelu.cu)|f32|/|[link](./gelu/)|⭐️|
+ | ✔️ [gelu_f16](./gelu/gelu.cu)|f16|/|[link](./gelu/)|⭐️|
+ | ✔️ [gelu_f16x2](./gelu/gelu.cu)|f16|/|[link](./gelu/)|⭐️|
+ | ✔️ [gelu_f16x8](./gelu/gelu.cu)|f16|/|[link](./gelu/)|⭐️|
+ | ✔️ [gelu_f16x8_pack](./gelu/gelu.cu)|f16|/|[link](./gelu/)|⭐️⭐️|
  | ✔️ [warp_reduce_[all]](./reduce/reduce.cu)|all|all|[link](./reduce/)|⭐️⭐️|
  | ✔️ [reduce_f32_f32](./reduce/reduce.cu)|f32|f32|[link](./reduce/)|⭐️⭐️|
  | ✔️ [reduce_f32x4_f32](./reduce/reduce.cu)|f32|f32|[link](./reduce/)|⭐️⭐️|

gelu/.gitignore

Lines changed: 10 additions & 0 deletions

*.so
*.a
*.dylib
*.dll
*.lib
.DS_Store
build
*.whl
tmp

gelu/README.md

Lines changed: 163 additions & 0 deletions

# GELU

## 0x00 Description

Contains the following:

- [X] gelu_f32_kernel
- [X] gelu_f32x4_kernel (float4 vectorized version; see the sketch after this list)
- [X] gelu_f16_kernel
- [X] gelu_f16x2_kernel (half2 vectorized)
- [X] gelu_f16x8_kernel (unpacked version)
- [X] gelu_f16x8_pack_kernel (packed version)
- [X] PyTorch bindings
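As a rough illustration of the float4-vectorized variant (a minimal sketch; the kernel and helper names here are illustrative, not the exact code in gelu.cu), each thread loads four floats with a single 128-bit transaction and applies the tanh-approximated GELU element-wise:

```c++
#include <cuda_runtime.h>
#include <math.h>

// tanh-approximated GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
__device__ __forceinline__ float gelu_tanh_f32(float x) {
  float inner = 0.7978845608028454f * (x + 0.044715f * x * x * x);
  return 0.5f * x * (1.0f + tanhf(inner));
}

// Each thread handles 4 contiguous floats via one float4 load/store.
// Assumes x and y are 16-byte aligned; the tail (N not divisible by 4)
// would be handled separately in a real kernel.
__global__ void gelu_f32x4_sketch(const float* x, float* y, int N) {
  int idx = 4 * (blockIdx.x * blockDim.x + threadIdx.x);
  if (idx + 3 < N) {
    float4 v = *reinterpret_cast<const float4*>(x + idx);
    float4 r;
    r.x = gelu_tanh_f32(v.x);
    r.y = gelu_tanh_f32(v.y);
    r.z = gelu_tanh_f32(v.z);
    r.w = gelu_tanh_f32(v.w);
    *reinterpret_cast<float4*>(y + idx) = r;
  }
}
```

The f16x8 and f16x8_pack variants apply the same idea to half data, with the packed version presumably issuing wider (128-bit) loads covering 8 halves at a time.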
## Testing

For the half-precision (half) GELU, CUDA's half-precision math does not provide a tanh operation, so hexp is used to emulate it, which introduces a comparatively large error (this could perhaps be addressed at the assembly/PTX level). PyTorch, by contrast, handles it by converting the data type. Testing this is straightforward: modify the f16 code in the .cu file to cast to float first:

```c++
y[idx] = HALF_GELU_OPS(__half2float(v));           // line 96
reg_y.x = HALF_GELU_OPS(__half2float(reg_x.x));    // line 109
reg_y.y = HALF_GELU_OPS(__half2float(reg_x.y));    // line 110
```
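For reference, tanh for half values can be built from hexp alone, which is the kind of substitution described above (a minimal sketch; the actual HALF_GELU_OPS in gelu.cu may differ):

```c++
#include <cuda_fp16.h>

// tanh(x) = (exp(2x) - 1) / (exp(2x) + 1), expressed with hexp because
// cuda_fp16.h offers no half-precision tanh intrinsic (illustrative only).
__device__ __forceinline__ half half_tanh_via_hexp(half x) {
  half e2x = hexp(__hmul(x, __float2half(2.0f)));
  half one = __float2half(1.0f);
  return __hdiv(__hsub(e2x, one), __hadd(e2x, one));
}
```

Evaluating the exponential in half precision is what introduces the extra error relative to torch's cast-to-float path.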
The test results are as follows (since not all inputs exhibit the error, a case that does was chosen; note that after the modification out_f16 and out_f16x2 now match torch):

```bash
-------------------------------------------------------------------------------------
S=2048, K=4096
out_f32: [-0.08196318, -0.1613517], time:0.13425708ms
out_f32x4: [-0.08196318, -0.1613517], time:0.14128804ms
out_f32_th: [-0.08196313, -0.1613517], time:0.08195782ms
-------------------------------------------------------------------------------------
out_f16: [-0.08197021, -0.16137695], time:0.12120271ms
out_f16x2: [-0.08197021, -0.16137695], time:0.12122369ms
out_f16x8: [-0.08251953, -0.16137695], time:0.04196978ms
out_f16x8pack: [-0.08251953, -0.16137695], time:0.04215288ms
out_f16_th: [-0.08197021, -0.16137695], time:0.04287958ms
-------------------------------------------------------------------------------------
```
References:
- [pytorch-c10-BFloat16.h](https://github.com/pytorch/pytorch/blob/main/c10/util/BFloat16.h)
- [math ptx](https://github.com/pavanky/math_ptx)
In addition, following torch, the float GELU is implemented with both the tanh and none (exact) approximations; which version gets compiled can be selected by editing the macro in gelu.cu.
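A sketch of such a macro switch with both approximations written out (the macro and function names are hypothetical; the actual macro in gelu.cu may differ):

```c++
#include <math.h>

// Hypothetical switch between the two float GELU variants, mirroring
// torch's approximate='tanh' vs. approximate='none'.
#define GELU_APPROX_TANH 1

__device__ __forceinline__ float gelu_f32_sketch(float x) {
#if GELU_APPROX_TANH
  // tanh approximation
  float inner = 0.7978845608028454f * (x + 0.044715f * x * x * x);
  return 0.5f * x * (1.0f + tanhf(inner));
#else
  // exact ("none") form: 0.5 * x * (1 + erf(x / sqrt(2)))
  return 0.5f * x * (1.0f + erff(x * 0.7071067811865475f));
#endif
}
```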
```bash
# Only build for the Ada architecture. If unset, all architectures are compiled by default,
# which takes much longer: Volta, Ampere, Ada, Hopper, ...
export TORCH_CUDA_ARCH_LIST=Ada
python3 gelu.py
```
Output (without the type conversion, so the half results show the error):

```bash
-------------------------------------------------------------------------------------
S=1024, K=1024
out_f32: [-0.13358943, -0.06881647], time:0.01621890ms
out_f32x4: [-0.13358943, -0.06881647], time:0.01278400ms
out_f32_th: [-0.13358943, -0.06881647], time:0.00897789ms
-------------------------------------------------------------------------------------
out_f16: [-0.13378906, -0.06884766], time:0.00663781ms
out_f16x2: [-0.13378906, -0.06884766], time:0.00366306ms
out_f16x8: [-0.13378906, -0.06884766], time:0.00343323ms
out_f16x8pack: [-0.13378906, -0.06884766], time:0.00331473ms
out_f16_th: [-0.13354492, -0.06884766], time:0.00907278ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=1024, K=2048
out_f32: [1.38783729, -0.06707606], time:0.02223682ms
out_f32x4: [1.38783729, -0.06707606], time:0.02367806ms
out_f32_th: [1.38783729, -0.06707606], time:0.00959325ms
-------------------------------------------------------------------------------------
out_f16: [1.38769531, -0.06713867], time:0.00834370ms
out_f16x2: [1.38769531, -0.06713867], time:0.00784707ms
out_f16x8: [1.38769531, -0.06713867], time:0.00499964ms
out_f16x8pack: [1.38769531, -0.06713867], time:0.00461078ms
out_f16_th: [1.38769531, -0.06707764], time:0.00895357ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=1024, K=4096
out_f32: [0.47386399, 0.05760021], time:0.04273629ms
out_f32x4: [0.47386399, 0.05760021], time:0.05011940ms
out_f32_th: [0.47386405, 0.05760022], time:0.00933146ms
-------------------------------------------------------------------------------------
out_f16: [0.47387695, 0.05761719], time:0.01495123ms
out_f16x2: [0.47387695, 0.05761719], time:0.01039743ms
out_f16x8: [0.47387695, 0.05761719], time:0.00936055ms
out_f16x8pack: [0.47387695, 0.05761719], time:0.00845838ms
out_f16_th: [0.47387695, 0.05758667], time:0.00918818ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=2048, K=1024
out_f32: [1.3562144, 0.40408486], time:0.03009892ms
out_f32x4: [1.3562144, 0.40408486], time:0.02289677ms
out_f32_th: [1.3562144, 0.40408486], time:0.00921512ms
-------------------------------------------------------------------------------------
out_f16: [1.35644531, 0.40405273], time:0.01173806ms
out_f16x2: [1.35644531, 0.40405273], time:0.00565076ms
out_f16x8: [1.35644531, 0.40405273], time:0.00502610ms
out_f16x8pack: [1.35644531, 0.40405273], time:0.00457048ms
out_f16_th: [1.35644531, 0.40429688], time:0.00904894ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=2048, K=2048
out_f32: [-0.16498716, -0.15077244], time:0.04273534ms
out_f32x4: [-0.16498716, -0.15077244], time:0.04386163ms
out_f32_th: [-0.16498716, -0.15077244], time:0.00913596ms
-------------------------------------------------------------------------------------
out_f16: [-0.16516113, -0.15075684], time:0.01495862ms
out_f16x2: [-0.16516113, -0.15075684], time:0.01407337ms
out_f16x8: [-0.16516113, -0.15075684], time:0.00796247ms
out_f16x8pack: [-0.16516113, -0.15075684], time:0.00734925ms
out_f16_th: [-0.16503906, -0.15075684], time:0.00917435ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=2048, K=4096
out_f32: [-0.03888749, 0.32139146], time:0.08363676ms
out_f32x4: [-0.03888749, 0.32139146], time:0.09505510ms
out_f32_th: [-0.03888749, 0.32139146], time:0.04022837ms
-------------------------------------------------------------------------------------
out_f16: [-0.03887939, 0.3215332], time:0.02813959ms
out_f16x2: [-0.03887939, 0.3215332], time:0.01906514ms
out_f16x8: [-0.03887939, 0.3215332], time:0.01664281ms
out_f16x8pack: [-0.03887939, 0.3215332], time:0.01474833ms
out_f16_th: [-0.03887939, 0.32128906], time:0.01357365ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=4096, K=1024
out_f32: [-0.13875209, 1.08477271], time:0.05790567ms
out_f32x4: [-0.13875209, 1.08477271], time:0.04317236ms
out_f32_th: [-0.13875209, 1.08477271], time:0.00910425ms
-------------------------------------------------------------------------------------
out_f16: [-0.13903809, 1.08496094], time:0.02198315ms
out_f16x2: [-0.13903809, 1.08496094], time:0.00964355ms
out_f16x8: [-0.13903809, 1.08496094], time:0.00780869ms
out_f16x8pack: [-0.13903809, 1.08496094], time:0.00729132ms
out_f16_th: [-0.13879395, 1.08496094], time:0.00926042ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=4096, K=2048
out_f32: [0.82045084, -0.0894338], time:0.08363843ms
out_f32x4: [0.82045084, -0.0894338], time:0.08431888ms
out_f32_th: [0.82045084, -0.0894338], time:0.03837347ms
-------------------------------------------------------------------------------------
out_f16: [0.8203125, -0.08947754], time:0.02813506ms
out_f16x2: [0.8203125, -0.08947754], time:0.02643061ms
out_f16x8: [0.8203125, -0.08947754], time:0.01383305ms
out_f16x8pack: [0.8203125, -0.08947754], time:0.01273918ms
out_f16_th: [0.82080078, -0.0894165], time:0.01357722ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=4096, K=4096
out_f32: [-0.06997654, -0.16092129], time:0.19113564ms
out_f32x4: [-0.06997654, -0.16092129], time:0.20371628ms
out_f32_th: [-0.06997654, -0.16092129], time:0.20496607ms
-------------------------------------------------------------------------------------
out_f16: [-0.07012939, -0.16113281], time:0.05451322ms
out_f16x2: [-0.07012939, -0.16113281], time:0.03633785ms
out_f16x8: [-0.07012939, -0.16113281], time:0.03115463ms
out_f16x8pack: [-0.07012939, -0.16113281], time:0.02735877ms
out_f16_th: [-0.07000732, -0.16088867], time:0.03889561ms
-------------------------------------------------------------------------------------
```
