|
| 1 | +# Swish |
| 2 | + |
| 3 | +## 0x00 说明 |
| 4 | + |
| 5 | +包含以下内容: |
| 6 | + |
| 7 | +- [X] swish_f32_kernel |
| 8 | +- [X] swish_f32x4_kernel(float4向量化版本) |
| 9 | +- [X] swish_f16_kernel(fp16版本) |
| 10 | +- [X] swish_f16x2_kernel(fp16向量化版本) |
| 11 | +- [X] swish_f16x8_kernel(fp16向量化版本) |
| 12 | +- [X] swish_f16x8_pack_kernel(fp16向量化,pack版本) |
| 13 | +- [X] PyTorch bindings |
| 14 | + |
| 15 | + |
| 16 | +## 测试 |
| 17 | + |
| 18 | +```bash |
| 19 | +# 只测试 Ada 架构;若不指定,默认会编译所有架构(Volta、Ampere、Ada、Hopper 等),耗时较长 |
| 20 | +export TORCH_CUDA_ARCH_LIST=Ada |
| 21 | +python3 swish.py |
| 22 | +``` |
| 23 | + |
| 24 | +输出: |
| 25 | + |
| 26 | +```text |
| 27 | +------------------------------------------------------------------------------------- |
| 28 | + S=1024, K=1024 |
| 29 | + out_f32: ['0.46177661 ', '-0.10888041 '], time:0.01246500ms |
| 30 | + out_f32x4: ['0.46177661 ', '-0.10888041 '], time:0.01006508ms |
| 31 | + out_f32_th: ['0.46177667 ', '-0.10888041 '], time:0.03012419ms |
| 32 | +------------------------------------------------------------------------------------- |
| 33 | + out_f16: ['0.46191406 ', '-0.10894775 '], time:0.01299334ms |
| 34 | + out_f16x2: ['0.46191406 ', '-0.10894775 '], time:0.01036119ms |
| 35 | + out_f16x8: ['0.46191406 ', '-0.10894775 '], time:0.00979590ms |
| 36 | + out_f16x8pack: ['0.46191406 ', '-0.10894775 '], time:0.00972557ms |
| 37 | + out_f16_th: ['0.46191406 ', '-0.10888672 '], time:0.02423882ms |
| 38 | +------------------------------------------------------------------------------------- |
| 39 | +------------------------------------------------------------------------------------- |
| 40 | + S=1024, K=2048 |
| 41 | + out_f32: ['-0.27797085 ', '0.71514565 '], time:0.01415992ms |
| 42 | + out_f32x4: ['-0.27797085 ', '0.71514565 '], time:0.01159716ms |
| 43 | + out_f32_th: ['-0.27797085 ', '0.71514559 '], time:0.02964258ms |
| 44 | +------------------------------------------------------------------------------------- |
| 45 | + out_f16: ['-0.27807617 ', '0.71582031 '], time:0.01473880ms |
| 46 | + out_f16x2: ['-0.27807617 ', '0.71582031 '], time:0.01404881ms |
| 47 | + out_f16x8: ['-0.27807617 ', '0.71582031 '], time:0.01127148ms |
| 48 | + out_f16x8pack: ['-0.27807617 ', '0.71582031 '], time:0.01101518ms |
| 49 | + out_f16_th: ['-0.27807617 ', '0.71533203 '], time:0.02657008ms |
| 50 | +------------------------------------------------------------------------------------- |
| 51 | +------------------------------------------------------------------------------------- |
| 52 | + S=1024, K=4096 |
| 53 | + out_f32: ['0.29988611 ', '-0.2541697 '], time:0.01959276ms |
| 54 | + out_f32x4: ['0.29988611 ', '-0.2541697 '], time:0.01605868ms |
| 55 | + out_f32_th: ['0.29988611 ', '-0.25416973 '], time:0.03745818ms |
| 56 | +------------------------------------------------------------------------------------- |
| 57 | + out_f16: ['0.30004883 ', '-0.25415039 '], time:0.02078271ms |
| 58 | + out_f16x2: ['0.30004883 ', '-0.25415039 '], time:0.01729155ms |
| 59 | + out_f16x8: ['0.30004883 ', '-0.25415039 '], time:0.01489425ms |
| 60 | + out_f16x8pack: ['0.30004883 ', '-0.25415039 '], time:0.01351643ms |
| 61 | + out_f16_th: ['0.29980469 ', '-0.25415039 '], time:0.03149080ms |
| 62 | +------------------------------------------------------------------------------------- |
| 63 | +------------------------------------------------------------------------------------- |
| 64 | + S=2048, K=1024 |
| 65 | + out_f32: ['-0.07777861 ', '-0.27842814 '], time:0.01640201ms |
| 66 | + out_f32x4: ['-0.07777861 ', '-0.27842814 '], time:0.01180029ms |
| 67 | + out_f32_th: ['-0.07777861 ', '-0.27842814 '], time:0.02952218ms |
| 68 | +------------------------------------------------------------------------------------- |
| 69 | + out_f16: ['-0.07775879 ', '-0.27856445 '], time:0.01758027ms |
| 70 | + out_f16x2: ['-0.07775879 ', '-0.27856445 '], time:0.01236153ms |
| 71 | + out_f16x8: ['-0.07775879 ', '-0.27856445 '], time:0.01109338ms |
| 72 | + out_f16x8pack: ['-0.07775879 ', '-0.27856445 '], time:0.01091790ms |
| 73 | + out_f16_th: ['-0.07775879 ', '-0.27856445 '], time:0.02657914ms |
| 74 | +------------------------------------------------------------------------------------- |
| 75 | +------------------------------------------------------------------------------------- |
| 76 | + S=2048, K=2048 |
| 77 | + out_f32: ['-0.14754841 ', '-0.21989606 '], time:0.01957679ms |
| 78 | + out_f32x4: ['-0.14754841 ', '-0.21989606 '], time:0.01496792ms |
| 79 | + out_f32_th: ['-0.14754841 ', '-0.21989603 '], time:0.03751612ms |
| 80 | +------------------------------------------------------------------------------------- |
| 81 | + out_f16: ['-0.14758301 ', '-0.21984863 '], time:0.02085924ms |
| 82 | + out_f16x2: ['-0.14758301 ', '-0.21984863 '], time:0.01961517ms |
| 83 | + out_f16x8: ['-0.14758301 ', '-0.21984863 '], time:0.01386237ms |
| 84 | + out_f16x8pack: ['-0.14758301 ', '-0.21984863 '], time:0.01334929ms |
| 85 | + out_f16_th: ['-0.14758301 ', '-0.21984863 '], time:0.03151488ms |
| 86 | +------------------------------------------------------------------------------------- |
| 87 | +------------------------------------------------------------------------------------- |
| 88 | + S=2048, K=4096 |
| 89 | + out_f32: ['1.07876182 ', '-0.27844051 '], time:0.03036070ms |
| 90 | + out_f32x4: ['1.07876182 ', '-0.27844051 '], time:0.02339220ms |
| 91 | + out_f32_th: ['1.07876182 ', '-0.27844048 '], time:0.05310464ms |
| 92 | +------------------------------------------------------------------------------------- |
| 93 | + out_f16: ['1.078125 ', '-0.27832031 '], time:0.03291988ms |
| 94 | + out_f16x2: ['1.078125 ', '-0.27832031 '], time:0.02590466ms |
| 95 | + out_f16x8: ['1.078125 ', '-0.27832031 '], time:0.02027988ms |
| 96 | + out_f16x8pack: ['1.078125 ', '-0.27832031 '], time:0.01811814ms |
| 97 | + out_f16_th: ['1.07910156 ', '-0.27832031 '], time:0.04083204ms |
| 98 | +------------------------------------------------------------------------------------- |
| 99 | +------------------------------------------------------------------------------------- |
| 100 | + S=4096, K=1024 |
| 101 | + out_f32: ['0.31169948 ', '-0.18232882 '], time:0.02427077ms |
| 102 | + out_f32x4: ['0.31169948 ', '-0.18232882 '], time:0.01515222ms |
| 103 | + out_f32_th: ['0.31169948 ', '-0.18232881 '], time:0.03754425ms |
| 104 | +------------------------------------------------------------------------------------- |
| 105 | + out_f16: ['0.31152344 ', '-0.18237305 '], time:0.02679300ms |
| 106 | + out_f16x2: ['0.31152344 ', '-0.18237305 '], time:0.01617312ms |
| 107 | + out_f16x8: ['0.31152344 ', '-0.18237305 '], time:0.01357770ms |
| 108 | + out_f16x8pack: ['0.31152344 ', '-0.18237305 '], time:0.01324248ms |
| 109 | + out_f16_th: ['0.31152344 ', '-0.18225098 '], time:0.03149295ms |
| 110 | +------------------------------------------------------------------------------------- |
| 111 | +------------------------------------------------------------------------------------- |
| 112 | + S=4096, K=2048 |
| 113 | + out_f32: ['1.5033319 ', '0.17473438 '], time:0.03030729ms |
| 114 | + out_f32x4: ['1.5033319 ', '0.17473438 '], time:0.02150083ms |
| 115 | + out_f32_th: ['1.5033319 ', '0.17473438 '], time:0.05257607ms |
| 116 | +------------------------------------------------------------------------------------- |
| 117 | + out_f16: ['1.50390625 ', '0.17468262 '], time:0.03289509ms |
| 118 | + out_f16x2: ['1.50390625 ', '0.17468262 '], time:0.03073120ms |
| 119 | + out_f16x8: ['1.50390625 ', '0.17468262 '], time:0.01862860ms |
| 120 | + out_f16x8pack: ['1.50390625 ', '0.17468262 '], time:0.01772857ms |
| 121 | + out_f16_th: ['1.50390625 ', '0.17468262 '], time:0.04082441ms |
| 122 | +------------------------------------------------------------------------------------- |
| 123 | +------------------------------------------------------------------------------------- |
| 124 | + S=4096, K=4096 |
| 125 | + out_f32: ['-0.05288643 ', '-0.14218464 '], time:0.19254756ms |
| 126 | + out_f32x4: ['-0.05288643 ', '-0.14218464 '], time:0.19258785ms |
| 127 | + out_f32_th: ['-0.05288643 ', '-0.14218464 '], time:0.48660636ms |
| 128 | +------------------------------------------------------------------------------------- |
| 129 | + out_f16: ['-0.052948 ', '-0.14221191 '], time:0.05689216ms |
| 130 | + out_f16x2: ['-0.052948 ', '-0.14221191 '], time:0.04335928ms |
| 131 | + out_f16x8: ['-0.052948 ', '-0.14221191 '], time:0.03096652ms |
| 132 | + out_f16x8pack: ['-0.052948 ', '-0.14221191 '], time:0.02706647ms |
| 133 | + out_f16_th: ['-0.05288696 ', '-0.14221191 '], time:0.05971408ms |
| 134 | +------------------------------------------------------------------------------------- |
| 135 | + |
| 136 | +``` |
0 commit comments