- [X] hgemm_mma_m16n8k16_naive(MMA)
- [X] hgemm_mma_m16n8k16_mma2x4_warp4x4(MMA, Tile MMA/Warp, pack)
- [X] hgemm_mma_m16n8k16_mma2x4_warp4x4_stages(MMA, Tile MMA/Warp, Copy Async, Stages, Pad, Block swizzle)
- [X] hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages(MMA, Tile MMA/Warp, Copy Async, Stages, Pad, Block swizzle, Warp swizzle, Reg Double Buffers, Collective Store with Reg Reuse & Warp Shuffle)
- [X] PyTorch bindings
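
All the MMA kernels above are built around the m16n8k16 Tensor Core primitive. For reference, here is a minimal sketch of that primitive as an inline-PTX wrapper (an illustration assuming sm_80+; the wrapper name is hypothetical and this is not the repository's actual kernel code):

```cuda
#include <cstdint>

// Minimal m16n8k16 HMMA wrapper (f16 inputs, f16 accumulators), for
// illustration only. Per warp: A fragment = 4 x .b32 registers (8 halves),
// B fragment = 2 registers, C/D accumulators = 2 registers.
__device__ __forceinline__ void hmma_m16n8k16_f16(
    uint32_t &d0, uint32_t &d1,
    uint32_t a0, uint32_t a1, uint32_t a2, uint32_t a3,
    uint32_t b0, uint32_t b1,
    uint32_t c0, uint32_t c1) {
  asm volatile(
      "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
      "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n"
      : "=r"(d0), "=r"(d1)
      : "r"(a0), "r"(a1), "r"(a2), "r"(a3),
        "r"(b0), "r"(b1), "r"(c0), "r"(c1));
}
```

Each call computes one 16x8 tile of C per warp; the warp4x4x2 kernels issue many such calls per warp and double-buffer the A/B register fragments across the K loop.
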
## Current Performance

- NVIDIA L20

Among the current best implementations on an L20 (theoretical Tensor Core FP16 throughput: 119.5 TFLOPS), the WMMA API reaches roughly 95%~98% of cuBLAS performance (105-113 TFLOPS vs 105-115 TFLOPS), while the MMA API reaches 115 TFLOPS and beats cuBLAS in some cases. A known issue is that bank conflicts are not fully eliminated: the current padding-based mitigation wastes shared memory and also hurts SM occupancy. In addition, smem swizzle has not yet been implemented by hand (limited by the flexibility of the WMMA API and the row-major layout); a follow-up will attempt to implement smem swizzle via MMA PTX with a col-major layout. [See the performance data](#NV-L20).
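
To make the trade-off concrete, here is a hedged sketch of the two shared-memory layouts discussed above; the tile sizes, mask, and names are illustrative assumptions rather than the repository's exact parameters:

```cuda
#include <cuda_fp16.h>

// Option 1 (current approach): pad each shared-memory row so consecutive
// rows start at different bank offsets. The 8 extra halves per row reduce
// conflicts but waste capacity, which can lower SM occupancy.
__global__ void padded_smem_demo() {
  __shared__ half s_a[128][32 + 8];  // BK=32 halves plus an 8-half pad
  s_a[threadIdx.y][threadIdx.x] = __float2half(0.0f);
}

// Option 2 (planned): XOR-swizzle the column index instead of padding, so
// no smem is wasted. The mask below is only a placeholder; the real
// permutation must match the MMA/ldmatrix access pattern and tile shape.
__device__ __forceinline__ int swizzle_col(int row, int col) {
  return col ^ ((row & 7) << 2);
}
```
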
- NVIDIA GeForce RTX 3080 Laptop

Tested on an NVIDIA GeForce RTX 3080 Laptop, mma4x4_warp4x4 (16 WMMA m16n16k16 ops, warp tile 64x64) combined with Thread block swizzle matches or even beats cuBLAS in most cases. [See the performance data](#NV-RTX-3080).
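
For intuition, here is a hedged sketch of the thread block swizzle idea; the mapping below is one common L2-locality remapping and only an illustration, not necessarily the exact scheme used by these kernels:

```cuda
// Remap the hardware block scheduling order so that consecutively launched
// blocks compute neighboring C tiles, improving L2 reuse of A and B tiles.
// Assumes gridDim.x counts N tiles, gridDim.y counts M tiles, and that
// `stride` divides gridDim.x; boundary handling is omitted in this sketch.
__device__ __forceinline__ void swizzled_tile_id(int stride, int &bm, int &bn) {
  int linear = blockIdx.y * gridDim.x + blockIdx.x;    // scheduling order
  int strip = stride * gridDim.y;                      // blocks per strip
  bm = (linear % strip) / stride;                      // C tile row
  bn = (linear / strip) * stride + (linear % stride);  // C tile col
}
```

Without the remap, consecutively scheduled blocks sweep an entire row of C tiles and stream a new B tile through L2 at every step; grouping them into narrow strips keeps a small working set of A/B tiles resident.
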
## Shared Memory Bank Conflicts

```bash
python3 hgemm.py --wmma --wmma-all # test all wmma kernels for all MNK
python3 hgemm.py --mma --mma-all # test all mma kernels for all MNK
```

Example 1:
```bash
python3 hgemm.py --M 16384 --N 16384 --K 8192 --i 10 --mma
Namespace(M=16384, N=16384, K=8192, warmup=2, iters=10, show_all=False, enable_mma=True, enable_wmma=False, enable_cuda=False, enable_mma_all=False, enable_wmma_all=False, enable_cuda_all=False, enable_torch=False, disable_cublas=False, sleep_duration=0.1, swizzle_factor=0.25)
Loading hgemm lib ...
pre allocate for fast profiling start, MAX_M=16384, MAX_N=16384, MAX_K=8192
pre allocate for fast profiling done, time: 21829.665184020996 ms
----------------------------------------------------------------------------------------------------------------------------------
M=16384, N=16384, K=8192, Warmup=2, Iters=10, 1/1
----------------------------------------------------------------------------------------------------------------------------------
--------------------------------------------------------------------MMA-----------------------------------------------------------
(mma2x4+warp4x4+stage3+swizzle): ['55.53125 ', '-34.4375 '], time:39.08894ms, swizzle: 4096, TFLOPS: 112.51(+0.00%)
(mma2x4+warp4x4+stage2+swizzle): ['55.53125 ', '-34.4375 '], time:38.40720ms, swizzle: 4096, TFLOPS: 114.51(+1.78%)
(mma2x4+warp4x4+stage3+dsmem+swizzle): ['55.53125 ', '-34.4375 '], time:39.23299ms, swizzle: 4096, TFLOPS: 112.10
(mma2x4+warp4x4+stage2+dsmem+swizzle): ['55.53125 ', '-34.4375 '], time:38.20564ms, swizzle: 4096, TFLOPS: 115.12(+0.53%)
(mma2x4+warp4x4x2+stage4+dsmem+swizzle): ['55.53125 ', '-34.4375 '], time:38.67657ms, swizzle: 4096, TFLOPS: 113.71
(mma2x4+warp4x4x2+stage3+dsmem+swizzle): ['55.53125 ', '-34.4375 '], time:40.10882ms, swizzle: 4096, TFLOPS: 109.65
(mma2x4+warp4x4x2+stage2+dsmem+swizzle): ['55.53125 ', '-34.4375 '], time:38.44747ms, swizzle: 4096, TFLOPS: 114.39
(cublas): ['55.53125 ', '-34.4375 '], time:37.43820ms, swizzle: NOOP, TFLOPS: 117.47(+2.05%)
----------------------------------------------------------------------------------------------------------------------------------
```
Example 2:
```bash
python3 hgemm.py --M 4096 --N 4096 --K 4096 --mma-all
Namespace(M=4096, N=4096, K=4096, warmup=2, iters=10, show_all=False, enable_mma=False, enable_wmma=False, enable_cuda=False, enable_mma_all=True, enable_wmma_all=False, enable_cuda_all=False, enable_torch=False, disable_cublas=False, sleep_duration=0.1, swizzle_factor=0.25)
Loading hgemm lib ...
pre allocate for fast profiling start, MAX_M=4096, MAX_N=4096, MAX_K=4096
pre allocate for fast profiling done, time: 2056.009292602539 ms
----------------------------------------------------------------------------------------------------------------------------------
M=4096, N=4096, K=4096, Warmup=2, Iters=10, 1/1
----------------------------------------------------------------------------------------------------------------------------------
--------------------------------------------------------------------MMA-----------------------------------------------------------
(mma2x4+warp4x4): ['131.625 ', '23.59375 '], time:1.412987ms, swizzle: NOOP, TFLOPS: 97.27 (+0.00%)
(mma2x4+warp4x4+stage3): ['131.625 ', '23.59375 '], time:1.343512ms, swizzle: NOOP, TFLOPS: 102.30(+5.17%)
(mma2x4+warp4x4+stage2): ['131.625 ', '23.59375 '], time:1.326799ms, swizzle: NOOP, TFLOPS: 103.59(+1.26%)
(mma2x4+warp4x4+stage3+dsmem): ['131.625 ', '23.59375 '], time:1.350784ms, swizzle: NOOP, TFLOPS: 101.75
(mma2x4+warp4x4+stage2+dsmem): ['131.625 ', '23.59375 '], time:1.326084ms, swizzle: NOOP, TFLOPS: 103.64(+0.05%)
(mma2x4+warp4x4x2+stage4+dsmem): ['131.625 ', '23.59375 '], time:1.324439ms, swizzle: NOOP, TFLOPS: 103.77(+0.12%)
(mma2x4+warp4x4x2+stage3+dsmem): ['131.625 ', '23.59375 '], time:1.369738ms, swizzle: NOOP, TFLOPS: 100.34
(mma2x4+warp4x4x2+stage2+dsmem): ['131.625 ', '23.59375 '], time:1.299858ms, swizzle: NOOP, TFLOPS: 105.73(+1.89%)
(mma2x4+warp4x4+stage3+swizzle): ['131.625 ', '23.59375 '], time:1.344513ms, swizzle: 1024, TFLOPS: 102.22
(mma2x4+warp4x4+stage2+swizzle): ['131.625 ', '23.59375 '], time:1.324009ms, swizzle: 1024, TFLOPS: 103.81
(mma2x4+warp4x4+stage3+dsmem+swizzle): ['131.625 ', '23.59375 '], time:1.349854ms, swizzle: 1024, TFLOPS: 101.82
(mma2x4+warp4x4+stage2+dsmem+swizzle): ['131.625 ', '23.59375 '], time:1.318955ms, swizzle: 1024, TFLOPS: 104.20
(mma2x4+warp4x4x2+stage4+dsmem+swizzle): ['131.625 ', '23.59375 '], time:1.318430ms, swizzle: 1024, TFLOPS: 104.24
(mma2x4+warp4x4x2+stage3+dsmem+swizzle): ['131.625 ', '23.59375 '], time:1.371240ms, swizzle: 1024, TFLOPS: 100.23
(mma2x4+warp4x4x2+stage2+dsmem+swizzle): ['131.625 ', '23.59375 '], time:1.300096ms, swizzle: 1024, TFLOPS: 105.71
(cublas): ['131.625 ', '23.59375 '], time:1.289224ms, swizzle: NOOP, TFLOPS: 106.61(+0.82%)
----------------------------------------------------------------------------------------------------------------------------------
```
Example 3:
```bash
python3 hgemm.py --M 4096 --N 4096 --K 4096 --mma-all --wmma-all --cuda-all
Namespace(M=4096, N=4096, K=4096, warmup=2, iters=10, show_all=False, enable_mma=False, enable_wmma=False, enable_cuda=False, enable_mma_all=True, enable_wmma_all=True, enable_cuda_all=True, enable_torch=False, disable_cublas=False, sleep_duration=0.1, swizzle_factor=0.25)
Loading hgemm lib ...
pre allocate for fast profiling start, MAX_M=4096, MAX_N=4096, MAX_K=4096
pre allocate for fast profiling done, time: 2048.0010509490967 ms
----------------------------------------------------------------------------------------------------------------------------------
M=4096, N=4096, K=4096, Warmup=2, Iters=10, 1/1
----------------------------------------------------------------------------------------------------------------------------------
(naive): ['-3.5371093', '-101.0 '], time:37.66887ms, swizzle: NOOP, TFLOPS: 3.65 (+0.00%)
(f16x8pack+t8x8+bcf): ['-3.5371093', '-101.0 '], time:2.811360ms, swizzle: NOOP, TFLOPS: 48.89 (+1239.88%)
(f16x8pack+t8x8+dbuf): ['-3.5371093', '-101.0 '], time:2.815437ms, swizzle: NOOP, TFLOPS: 48.82
(f16x8pack+t8x8+k16+dbuf): ['-3.5371093', '-101.0 '], time:2.634835ms, swizzle: NOOP, TFLOPS: 52.16 (+6.70%)
--------------------------------------------------------------------WMMA----------------------------------------------------------
(mma4x2): ['-3.3847656', '-101.375 '], time:2.942705ms, swizzle: NOOP, TFLOPS: 46.70
(mma4x2+warp2x4): ['-3.3847656', '-101.375 '], time:1.817488ms, swizzle: NOOP, TFLOPS: 75.62 (+44.97%)
(mma4x2+warp2x4+stage3): ['-3.3847656', '-101.375 '], time:1.355123ms, swizzle: NOOP, TFLOPS: 101.42(+34.12%)
(mma4x2+warp2x4+stage2): ['-3.3847656', '-101.375 '], time:1.343965ms, swizzle: NOOP, TFLOPS: 102.26(+0.83%)
(mma4x2+warp2x4+stage3+dsmem): ['-3.3847656', '-101.375 '], time:1.342964ms, swizzle: NOOP, TFLOPS: 102.34(+0.07%)
(mma4x2+warp2x4+stage2+dsmem): ['-3.3847656', '-101.375 '], time:1.343178ms, swizzle: NOOP, TFLOPS: 102.32
(mma4x2+warp2x4+stage3+swizzle): ['-3.3847656', '-101.375 '], time:1.345729ms, swizzle: 1024, TFLOPS: 102.13
(mma4x2+warp2x4+stage2+swizzle): ['-3.3847656', '-101.375 '], time:1.324367ms, swizzle: 1024, TFLOPS: 103.78(+1.40%)
(mma4x2+warp2x4+stage3+dsmem+swizzle): ['-3.3847656', '-101.375 '], time:1.351284ms, swizzle: 1024, TFLOPS: 101.71
(mma4x2+warp2x4+stage2+dsmem+swizzle): ['-3.3847656', '-101.375 '], time:1.324582ms, swizzle: 1024, TFLOPS: 103.76
(mma4x4+warp4x4+stage3+dsmem): ['-3.3847656', '-101.375 '], time:1.387619ms, swizzle: NOOP, TFLOPS: 99.05
(mma4x4+warp4x4+stage2+dsmem): ['-3.3847656', '-101.375 '], time:1.490569ms, swizzle: NOOP, TFLOPS: 92.21
(mma4x2+warp4x4+stage3+dsmem): ['-3.3847656', '-101.375 '], time:1.376056ms, swizzle: NOOP, TFLOPS: 99.88
(mma4x2+warp4x4+stage2+dsmem): ['-3.3847656', '-101.375 '], time:1.425576ms, swizzle: NOOP, TFLOPS: 96.41
(mma4x4+warp4x4+stage3+dsmem+swizzle): ['-3.3847656', '-101.375 '], time:1.395106ms, swizzle: 1024, TFLOPS: 98.52
(mma4x4+warp4x4+stage2+dsmem+swizzle): ['-3.3847656', '-101.375 '], time:1.414942ms, swizzle: 1024, TFLOPS: 97.13
(mma4x2+warp4x4+stage3+dsmem+swizzle): ['-3.3847656', '-101.375 '], time:1.377010ms, swizzle: 1024, TFLOPS: 99.81
(mma4x2+warp4x4+stage2+dsmem+swizzle): ['-3.3847656', '-101.375 '], time:1.604509ms, swizzle: 1024, TFLOPS: 85.66
--------------------------------------------------------------------MMA-----------------------------------------------------------
(mma2x4+warp4x4): ['-3.3847656', '-101.375 '], time:1.412653ms, swizzle: NOOP, TFLOPS: 97.29
(mma2x4+warp4x4+stage3): ['-3.3847656', '-101.375 '], time:1.343774ms, swizzle: NOOP, TFLOPS: 102.28
(mma2x4+warp4x4+stage2): ['-3.3847656', '-101.375 '], time:1.326417ms, swizzle: NOOP, TFLOPS: 103.62
(mma2x4+warp4x4+stage3+dsmem): ['-3.3847656', '-101.375 '], time:1.351308ms, swizzle: NOOP, TFLOPS: 101.71
(mma2x4+warp4x4+stage2+dsmem): ['-3.3847656', '-101.375 '], time:1.326489ms, swizzle: NOOP, TFLOPS: 103.61
(mma2x4+warp4x4x2+stage4+dsmem): ['-3.3847656', '-101.375 '], time:1.324319ms, swizzle: NOOP, TFLOPS: 103.78(+0.00%)
(mma2x4+warp4x4x2+stage3+dsmem): ['-3.3847656', '-101.375 '], time:1.369786ms, swizzle: NOOP, TFLOPS: 100.34
(mma2x4+warp4x4x2+stage2+dsmem): ['-3.3847656', '-101.375 '], time:1.299762ms, swizzle: NOOP, TFLOPS: 105.74(+1.89%)
(mma2x4+warp4x4+stage3+swizzle): ['-3.3847656', '-101.375 '], time:1.344013ms, swizzle: 1024, TFLOPS: 102.26
(mma2x4+warp4x4+stage2+swizzle): ['-3.3847656', '-101.375 '], time:1.324701ms, swizzle: 1024, TFLOPS: 103.75
(mma2x4+warp4x4+stage3+dsmem+swizzle): ['-3.3847656', '-101.375 '], time:1.348972ms, swizzle: 1024, TFLOPS: 101.88
(mma2x4+warp4x4+stage2+dsmem+swizzle): ['-3.3847656', '-101.375 '], time:1.318597ms, swizzle: 1024, TFLOPS: 104.23
(mma2x4+warp4x4x2+stage4+dsmem+swizzle): ['-3.3847656', '-101.375 '], time:1.318240ms, swizzle: 1024, TFLOPS: 104.26
(mma2x4+warp4x4x2+stage3+dsmem+swizzle): ['-3.3847656', '-101.375 '], time:1.370477ms, swizzle: 1024, TFLOPS: 100.29
(mma2x4+warp4x4x2+stage2+dsmem+swizzle): ['-3.3847656', '-101.375 '], time:1.300477ms, swizzle: 1024, TFLOPS: 105.68
(cublas): ['-3.3847656', '-101.375 '], time:1.289367ms, swizzle: NOOP, TFLOPS: 106.59(+0.81%)
----------------------------------------------------------------------------------------------------------------------------------
```

## NVIDIA L20

<div id="NV-L20"></div>