# GELU

## 0x00 Overview

The following are included:

- [X] gelu_f32_kernel
- [X] gelu_f32x4_kernel (float4 vectorized version; see the sketch after this list)
- [X] gelu_f16_kernel
- [X] gelu_f16x2_kernel (half2 vectorized)
- [X] gelu_f16x8_kernel (unpack version)
- [X] gelu_f16x8_pack_kernel (pack version)
- [X] PyTorch bindings

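As a rough illustration of how the scalar and float4-vectorized f32 variants relate (this is the sketch referenced in the list above; the kernel names and the `gelu_op` helper are illustrative, not the exact code in gelu.cu, which wraps the elementwise math in macros such as HALF_GELU_OPS):

```c++
#include <cuda_runtime.h>

// Illustrative elementwise op (tanh-approximated GELU).
__device__ __forceinline__ float gelu_op(float v) {
  float inner = 0.7978845608028841f * (v + 0.044715f * v * v * v);  // sqrt(2/pi)*(x + 0.044715*x^3)
  return 0.5f * v * (1.0f + tanhf(inner));
}

// Scalar version: one float per thread.
__global__ void gelu_f32_kernel_sketch(const float* x, float* y, int N) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < N) y[idx] = gelu_op(x[idx]);
}

// float4-vectorized version: each thread loads/stores 4 contiguous floats in a
// single 128-bit transaction (assumes N is a multiple of 4).
__global__ void gelu_f32x4_kernel_sketch(const float* x, float* y, int N) {
  int idx = 4 * (blockIdx.x * blockDim.x + threadIdx.x);
  if (idx < N) {
    float4 reg_x = reinterpret_cast<const float4*>(x + idx)[0];
    float4 reg_y;
    reg_y.x = gelu_op(reg_x.x);
    reg_y.y = gelu_op(reg_x.y);
    reg_y.z = gelu_op(reg_x.z);
    reg_y.w = gelu_op(reg_x.w);
    reinterpret_cast<float4*>(y + idx)[0] = reg_y;
  }
}
```
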
## Testing

For the half-precision (half) GELU, CUDA's half-precision math does not provide a tanh operation, so hexp is used in its place, which introduces a comparatively large error. (This could perhaps be solved at the assembly/PTX level.) PyTorch instead handles it by converting the data type. Verifying this is easy: add a forced type conversion to the f16 code in the .cu file:

```c++
y[idx] = HALF_GELU_OPS(__half2float(v));         // line 96
reg_y.x = HALF_GELU_OPS(__half2float(reg_x.x));  // line 109, line 110
reg_y.y = HALF_GELU_OPS(__half2float(reg_x.y));
```
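
For context, the hexp-based tanh emulation referred to above presumably looks something like the following minimal sketch, which assumes the identity tanh(x) = (e^{2x} - 1) / (e^{2x} + 1); the actual HALF_GELU_OPS macro in gelu.cu may differ in detail:

```c++
#include <cuda_fp16.h>

// Sketch: emulate tanh in half precision with hexp, since there is no native
// half tanh intrinsic. tanh(x) = (exp(2x) - 1) / (exp(2x) + 1).
__device__ __forceinline__ half htanh_approx(half x) {
  half e2x = hexp(__hmul(__float2half(2.0f), x));
  return __hdiv(__hsub(e2x, __float2half(1.0f)),
                __hadd(e2x, __float2half(1.0f)));
}
```

Carrying every intermediate in half is presumably where the extra rounding error comes from, which the __half2float conversion above avoids.
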
The test results are shown below (not every input exhibits the error, so a case that does was chosen; after the change, out_f16 and out_f16x2 match the torch results):
```bash
-------------------------------------------------------------------------------------
                       S=2048, K=4096
          out_f32: [-0.08196318, -0.1613517], time:0.13425708ms
        out_f32x4: [-0.08196318, -0.1613517], time:0.14128804ms
       out_f32_th: [-0.08196313, -0.1613517], time:0.08195782ms
-------------------------------------------------------------------------------------
          out_f16: [-0.08197021, -0.16137695], time:0.12120271ms
        out_f16x2: [-0.08197021, -0.16137695], time:0.12122369ms
        out_f16x8: [-0.08251953, -0.16137695], time:0.04196978ms
    out_f16x8pack: [-0.08251953, -0.16137695], time:0.04215288ms
       out_f16_th: [-0.08197021, -0.16137695], time:0.04287958ms
-------------------------------------------------------------------------------------
```
References:
- [pytorch-c10-BFloat16.h](https://github.com/pytorch/pytorch/blob/main/c10/util/BFloat16.h)
- [math ptx](https://github.com/pavanky/math_ptx)

In addition, following torch, the float GELU is implemented with both the tanh and none approximations; editing the macro in gelu.cu selects which version is compiled (see the sketch below).

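A minimal sketch of what the two float variants might look like (the macro name GELU_TANH_APPROX here is illustrative; the actual macro in gelu.cu may be named differently). The "none" form is the exact definition GELU(x) = x * Phi(x) via erff, while the "tanh" form is the usual fast approximation:

```c++
#include <cuda_runtime.h>

#define GELU_TANH_APPROX 1  // illustrative switch: 1 = tanh approximation, 0 = exact ("none")

__device__ __forceinline__ float gelu_f32_sketch(float x) {
#if GELU_TANH_APPROX
  // tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
  float inner = 0.7978845608028841f * (x + 0.044715f * x * x * x);
  return 0.5f * x * (1.0f + tanhf(inner));
#else
  // exact ("none") form: x * Phi(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
  return 0.5f * x * (1.0f + erff(x * 0.7071067811865475f));
#endif
}
```
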
```bash
# Build for the Ada architecture only; if unset, all architectures are compiled
# (Volta, Ampere, Ada, Hopper, ...), which takes much longer.
export TORCH_CUDA_ARCH_LIST=Ada
python3 gelu.py
```

Output (without the type conversion, so the half variants show the error):

```bash
-------------------------------------------------------------------------------------
                       S=1024, K=1024
          out_f32: [-0.13358943, -0.06881647], time:0.01621890ms
        out_f32x4: [-0.13358943, -0.06881647], time:0.01278400ms
       out_f32_th: [-0.13358943, -0.06881647], time:0.00897789ms
-------------------------------------------------------------------------------------
          out_f16: [-0.13378906, -0.06884766], time:0.00663781ms
        out_f16x2: [-0.13378906, -0.06884766], time:0.00366306ms
        out_f16x8: [-0.13378906, -0.06884766], time:0.00343323ms
    out_f16x8pack: [-0.13378906, -0.06884766], time:0.00331473ms
       out_f16_th: [-0.13354492, -0.06884766], time:0.00907278ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
                       S=1024, K=2048
          out_f32: [1.38783729, -0.06707606], time:0.02223682ms
        out_f32x4: [1.38783729, -0.06707606], time:0.02367806ms
       out_f32_th: [1.38783729, -0.06707606], time:0.00959325ms
-------------------------------------------------------------------------------------
          out_f16: [1.38769531, -0.06713867], time:0.00834370ms
        out_f16x2: [1.38769531, -0.06713867], time:0.00784707ms
        out_f16x8: [1.38769531, -0.06713867], time:0.00499964ms
    out_f16x8pack: [1.38769531, -0.06713867], time:0.00461078ms
       out_f16_th: [1.38769531, -0.06707764], time:0.00895357ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
                       S=1024, K=4096
          out_f32: [0.47386399, 0.05760021], time:0.04273629ms
        out_f32x4: [0.47386399, 0.05760021], time:0.05011940ms
       out_f32_th: [0.47386405, 0.05760022], time:0.00933146ms
-------------------------------------------------------------------------------------
          out_f16: [0.47387695, 0.05761719], time:0.01495123ms
        out_f16x2: [0.47387695, 0.05761719], time:0.01039743ms
        out_f16x8: [0.47387695, 0.05761719], time:0.00936055ms
    out_f16x8pack: [0.47387695, 0.05761719], time:0.00845838ms
       out_f16_th: [0.47387695, 0.05758667], time:0.00918818ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
                       S=2048, K=1024
          out_f32: [1.3562144, 0.40408486], time:0.03009892ms
        out_f32x4: [1.3562144, 0.40408486], time:0.02289677ms
       out_f32_th: [1.3562144, 0.40408486], time:0.00921512ms
-------------------------------------------------------------------------------------
          out_f16: [1.35644531, 0.40405273], time:0.01173806ms
        out_f16x2: [1.35644531, 0.40405273], time:0.00565076ms
        out_f16x8: [1.35644531, 0.40405273], time:0.00502610ms
    out_f16x8pack: [1.35644531, 0.40405273], time:0.00457048ms
       out_f16_th: [1.35644531, 0.40429688], time:0.00904894ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
                       S=2048, K=2048
          out_f32: [-0.16498716, -0.15077244], time:0.04273534ms
        out_f32x4: [-0.16498716, -0.15077244], time:0.04386163ms
       out_f32_th: [-0.16498716, -0.15077244], time:0.00913596ms
-------------------------------------------------------------------------------------
          out_f16: [-0.16516113, -0.15075684], time:0.01495862ms
        out_f16x2: [-0.16516113, -0.15075684], time:0.01407337ms
        out_f16x8: [-0.16516113, -0.15075684], time:0.00796247ms
    out_f16x8pack: [-0.16516113, -0.15075684], time:0.00734925ms
       out_f16_th: [-0.16503906, -0.15075684], time:0.00917435ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
                       S=2048, K=4096
          out_f32: [-0.03888749, 0.32139146], time:0.08363676ms
        out_f32x4: [-0.03888749, 0.32139146], time:0.09505510ms
       out_f32_th: [-0.03888749, 0.32139146], time:0.04022837ms
-------------------------------------------------------------------------------------
          out_f16: [-0.03887939, 0.3215332], time:0.02813959ms
        out_f16x2: [-0.03887939, 0.3215332], time:0.01906514ms
        out_f16x8: [-0.03887939, 0.3215332], time:0.01664281ms
    out_f16x8pack: [-0.03887939, 0.3215332], time:0.01474833ms
       out_f16_th: [-0.03887939, 0.32128906], time:0.01357365ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
                       S=4096, K=1024
          out_f32: [-0.13875209, 1.08477271], time:0.05790567ms
        out_f32x4: [-0.13875209, 1.08477271], time:0.04317236ms
       out_f32_th: [-0.13875209, 1.08477271], time:0.00910425ms
-------------------------------------------------------------------------------------
          out_f16: [-0.13903809, 1.08496094], time:0.02198315ms
        out_f16x2: [-0.13903809, 1.08496094], time:0.00964355ms
        out_f16x8: [-0.13903809, 1.08496094], time:0.00780869ms
    out_f16x8pack: [-0.13903809, 1.08496094], time:0.00729132ms
       out_f16_th: [-0.13879395, 1.08496094], time:0.00926042ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
                       S=4096, K=2048
          out_f32: [0.82045084, -0.0894338], time:0.08363843ms
        out_f32x4: [0.82045084, -0.0894338], time:0.08431888ms
       out_f32_th: [0.82045084, -0.0894338], time:0.03837347ms
-------------------------------------------------------------------------------------
          out_f16: [0.8203125, -0.08947754], time:0.02813506ms
        out_f16x2: [0.8203125, -0.08947754], time:0.02643061ms
        out_f16x8: [0.8203125, -0.08947754], time:0.01383305ms
    out_f16x8pack: [0.8203125, -0.08947754], time:0.01273918ms
       out_f16_th: [0.82080078, -0.0894165], time:0.01357722ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
                       S=4096, K=4096
          out_f32: [-0.06997654, -0.16092129], time:0.19113564ms
        out_f32x4: [-0.06997654, -0.16092129], time:0.20371628ms
       out_f32_th: [-0.06997654, -0.16092129], time:0.20496607ms
-------------------------------------------------------------------------------------
          out_f16: [-0.07012939, -0.16113281], time:0.05451322ms
        out_f16x2: [-0.07012939, -0.16113281], time:0.03633785ms
        out_f16x8: [-0.07012939, -0.16113281], time:0.03115463ms
    out_f16x8pack: [-0.07012939, -0.16113281], time:0.02735877ms
       out_f16_th: [-0.07000732, -0.16088867], time:0.03889561ms
-------------------------------------------------------------------------------------
```