4
4
5
5
| CUDA Cores| Sliced K(Loop over K)| Tile Block| Tile Thread|
6
6
| :---:| :---:| :---:| :---:|
7
- | ✅ | ✅ | ✅ | ✅ |
8
- | ** WMMA(m16n16k16)** | ** MMA(m16n8k16)** | ** Pack LDST** | ** SMEM Padding** |
9
- | ✅ | ✅ | ✅ | ✅ |
7
+ | ✔️ | ✔️ | ✔️ | ✔️ |
8
+ | ** WMMA(m16n16k16)** | ** MMA(m16n8k16)** | ** Pack LDST(128 bits) ** | ** SMEM Padding** |
9
+ | ✔️ | ✔️ | ✔️ | ✔️ |
10
10
| ** Copy Async** | ** Tile MMA(More Threads)** | ** Tile Warp(More Values)** | ** Multi Stages** |
11
- | ✅ | ✅ | ✅ | ✅ |
12
- | ** Reg Double Buffers** | ** Block Swizzle** | ** Warp Swizzle** | ** Collective Store(Shuffle)** |
13
- | ✅ | ✅ | ✅ | ✅ |
11
+ | ✔️ | ✔️ | ✔️ | ✔️ |
12
+ | ** Reg Double Buffers** | ** Block Swizzle** | ** Warp Swizzle** | ** Collective Store(Reg Reuse&Warp Shuffle)** |
13
+ | ✔️ | ✔️ | ✔️ | ✔️ |
14
14
| ** Row Major(NN)** | ** Col Major(TN)** | ** SMEM Swizzle** | ...|
15
- | ✅ | ✅ | ❔| ...|
15
+ | ✔️ | ✔️ | ❔| ...|
16
16
17
17
<details >
18
18
<summary > 🔑️ 点击查看所有支持的HGEMM Kernels! </summary >
@@ -167,8 +167,15 @@ python3 hgemm.py --M 4096 --N 4096 --K 4096 --mma-all --wmma-all --cuda-all
167
167
168
168
### PyTorch HGEMM Profile
169
169
170
- 在Ada架构下,PyTorch 2.4对FP16使用matmul时,会调用ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_stages_32x1_nn kernel,内部实际使用HMMA(Tensor Cores)进行计算,在3080上profile发现使用sm80_xmma_gemm_f16f16_f16f32_f32_nn_n_tilesize96x64x32_stage3_warpsize2x2x1_tensor16x8x16_kernel。因此,只有实现使用Tensor Cores的HGEMM,才有可能接近PyTorch/cuBLAS的性能。
171
-
170
+ 在Ada架构下,PyTorch 2.4对FP16使用matmul时,会调用:
171
+ ``` C++
172
+ ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_stages_32x1_nn_kernel
173
+ ```
174
+ 内部实际使用HMMA(Tensor Cores)进行计算,在3080上profile发现使用:
175
+ ``` C++
176
+ sm80_xmma_gemm_f16f16_f16f32_f32_nn_n_tilesize96x64x32_stage3_warpsize2x2x1_tensor16x8x16_kernel
177
+ ```
178
+ 因此,只有实现使用Tensor Cores的HGEMM,才有可能接近PyTorch/cuBLAS的性能。
172
179
``` bash
173
180
ncu -o hgemm.prof -f python3 prof.py
174
181
nsys profile --stats=true -t cuda,osrt,nvtx -o hgemm.prof --force-overwrite true python3 prof.py
@@ -183,8 +190,10 @@ nsys profile --stats=true -t cuda,osrt,nvtx -o hgemm.prof --force-overwrite true
183
190
...
184
191
```
185
192
186
- ### 共享内存 Bank Conflicts
193
+ ### SMEM Padding
187
194
195
+ #### Bank Conflicts的产生
196
+
188
197
含义:在访问shared memory时,因多个线程读写同一个Bank中的不同数据地址时,导致shared memory 并发读写 退化 成顺序读写的现象叫做Bank Conflict;
189
198
190
199
![ ] ( https://github.com/PaddleJitLab/CUDATutorial/blob/develop/docs/09_optimize_reduce/02_bank_conflict/images/ef322be7c3e5b6b9be69d2b90e88083f50569a58a97129f348e483b946ab4edf.png )
@@ -206,6 +215,18 @@ cudaDeviceSetSharedMemConfig(cudaSharedMemBankSizeEightByte);
206
215
207
216
本仓库实现的HGEMM Double Buffers策略如下:1)主循环从bk = 1 开始,第一次数据加载在主循环之前,最后一次计算在主循环之后,这是pipeline 的特点决定的;2)由于计算和下一次访存使用的Shared Memory不同,因此主循环中每次循环只需要一次__syncthreads()即可,对比非double buffers版本,总共节省了 ((K + BK - 1) / BK) - 1 次block内的同步操作。比如,bk=1时,HFMA计算使用的是s_a[0]和s_b[0],因此,和s_a[1]和s_b[1]的加载是没有依赖关系的。HFMA计算,从global内存到s_a[1]和s_b[1]和HFMA计算可以并行。s_a[1]和s_b[1]用于加载下一块BK需要的数据到共享内存;3)由于GPU不能向CPU那样支持乱序执行,主循环中需要先将下一次循环计算需要的Gloabal Memory中的数据load 到寄存器,然后进行本次计算,之后再将load到寄存器中的数据写到Shared Memory,这样在LDG指令向Global Memory做load时,不会影响后续HFMA及其它运算指令的 launch 执行,也就达到了Double Buffers的目的,具体代码见[hgemm.cu](./hgemm.cu)。
208
217
218
+ ### Tile Block
219
+
220
+ TODO
221
+
222
+ ### Tile Thread
223
+
224
+ TODO
225
+
226
+ ### Pack LDST 128 bits
227
+
228
+ TODO
229
+
209
230
### Async Copy
210
231
211
232
TODO
@@ -214,16 +235,33 @@ TODO
214
235
215
236
TODO
216
237
238
+ ### Tensor Cores(WMMA/MMA)
239
+
240
+ TODO
241
+
242
+ ### Tile MMA/Warp
243
+
244
+ TODO
245
+
217
246
### Thread Block Swizze
218
247
219
248
TODO
220
249
221
250
### Warp Swizzle
222
251
252
+ TODO
253
+
223
254
### Reg Double Buffers
224
255
225
256
TODO
226
257
258
+ ### Collective Store(Reg Reuse&Warp Shuffle)
259
+
260
+ TODO
261
+
262
+ ### SMEM Swizzle/Permuted
263
+
264
+ TODO
227
265
228
266
## 参考文献
229
267
0 commit comments