
Commit 6c89595

[HGEMM] collective store via warp shfl&reg reuse (#101)
* Update hgemm_mma_stage.cu
* Update hgemm.py
* Update hgemm_mma_stage.cu
* Create hgemm_mma_stage_col_major.cu
* Update hgemm_mma_stage_col_major.cu
* Update README.md
* Update hgemm.py
* Update README.md
* Update README.md
* Update hgemm.py
1 parent bcd12bd

File tree: 4 files changed (+262 -120 lines)

hgemm/README.md

Lines changed: 105 additions & 3 deletions
@@ -26,18 +26,18 @@
- [X] hgemm_mma_m16n8k16_naive(MMA)
- [X] hgemm_mma_m16n8k16_mma2x4_warp4x4(MMA, Tile MMA/Warp, pack)
- [X] hgemm_mma_m16n8k16_mma2x4_warp4x4_stages(MMA, Tile MMA/Warp, Copy Async, Stages, Pad, Block swizzle)
-- [X] hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages(MMA, Tile MMA/Warp, Copy Async, Stages, Pad, Block swizzle, Reg Double Buffers)
+- [X] hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages(MMA, Tile MMA/Warp, Copy Async, Stages, Pad, Block swizzle, Warp swizzle, Reg Double Buffers, Collective Store with Reg Reuse & Warp Shuffle)
- [X] PyTorch bindings
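
As context for the "Collective Store with Reg Reuse & Warp Shuffle" item added above: the general idea is to exchange accumulator fragments inside a warp with `__shfl_sync` so that one lane per group can issue a single wide vectorized store instead of several narrow ones. A minimal sketch of that pattern follows (illustrative only; the helper name, the quad grouping, and the fragment packing are assumptions, not the exact code in hgemm_mma_stage.cu):

```cuda
#include <cuda_fp16.h>

// Illustrative collective store: assume each lane in a quad (lanes 4g..4g+3)
// holds one packed half2 (viewed as 32 bits) of the same 8-half row of C.
// The quad shuffles its chunks to the leading lane, which issues one 16-byte
// store instead of four 4-byte stores. C_row must be 16-byte aligned.
__device__ __forceinline__ void collective_store_row(half* __restrict__ C_row,
                                                      unsigned frag_c) {
  const int lane   = threadIdx.x & 31;
  const int leader = lane & ~3;  // first lane of this 4-lane quad

  // Register reuse: the accumulator registers themselves act as the staging
  // buffer; no round trip through shared memory is needed.
  unsigned c0 = __shfl_sync(0xffffffffu, frag_c, leader + 0);
  unsigned c1 = __shfl_sync(0xffffffffu, frag_c, leader + 1);
  unsigned c2 = __shfl_sync(0xffffffffu, frag_c, leader + 2);
  unsigned c3 = __shfl_sync(0xffffffffu, frag_c, leader + 3);

  if (lane == leader) {
    // Single vectorized 128-bit store of the whole 8-half row segment.
    *reinterpret_cast<uint4*>(C_row) = make_uint4(c0, c1, c2, c3);
  }
}
```

Compared with staging the tile through shared memory before the store, the fragments stay in registers and the only extra cost is a few shuffle instructions per row.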

## Current Performance

- NVIDIA L20

-For the current best implementation, on the L20 (theoretical Tensor Core FP16 throughput of 119.5 TFLOPS), the WMMA API reaches roughly 95%~98% of cuBLAS performance (105-113 TFLOPS vs 105-115 TFLOPS), and the MMA API reaches 115 TFLOPS, beating cuBLAS in some cases. A known issue is that bank conflicts are not fully eliminated; the current padding-based mitigation wastes shared memory and also hurts SM occupancy. Warp swizzle has not yet been implemented by hand (limited by the flexibility of the WMMA API and my own ability); a later attempt will implement warp swizzle via MMA PTX. [Click for performance data](#NV-L20)
+For the current best implementation, on the L20 (theoretical Tensor Core FP16 throughput of 119.5 TFLOPS), the WMMA API reaches roughly 95%~98% of cuBLAS performance (105-113 TFLOPS vs 105-115 TFLOPS), and the MMA API reaches 115 TFLOPS, beating cuBLAS in some cases. A known issue is that bank conflicts are not fully eliminated; the current padding-based mitigation wastes shared memory and also hurts SM occupancy. Smem swizzle has not yet been implemented by hand (limited by the flexibility of the WMMA API and the row-major layout); a later attempt will implement smem swizzle via MMA PTX with a col-major layout. [Click for performance data](#NV-L20)
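
To make the padding-vs-swizzle trade-off described above concrete, here is an illustrative sketch: padding widens each shared-memory row at the cost of extra smem, while a swizzled layout permutes the column index inside a row so conflicts are avoided with no extra footprint. The tile sizes `BM`/`BK`, the pad width, and the XOR mapping are assumptions for the sketch, not the exact configuration or the layout planned for the col-major MMA kernels.

```cuda
#include <cuda_fp16.h>

constexpr int BM  = 128;  // assumed M-tile height
constexpr int BK  = 32;   // assumed K-tile width, in halves
constexpr int PAD = 8;    // assumed padding width, in halves

// Permute which 16-byte chunk of a row an element lands in, based on low bits
// of the row index, so the threads of one ldmatrix access hit distinct banks.
__device__ __forceinline__ int swizzle_col(int row, int col) {
  const int chunk = (col >> 3) ^ (row & ((BK >> 3) - 1));
  return (chunk << 3) | (col & 7);
}

__global__ void smem_layout_sketch() {
  // Padding-based mitigation: rows are PAD halves wider than the data, which
  // removes most conflicts but wastes BM*PAD halves of smem per tile and can
  // lower SM occupancy.
  __shared__ half s_a_padded[BM][BK + PAD];

  // Swizzled layout: same footprint as the data; conflicts are avoided by
  // permuting the column index instead of adding padding.
  __shared__ half s_a_swizzled[BM][BK];

  const int row = threadIdx.x / 4;
  const int col = (threadIdx.x % 4) * 8;  // each thread owns one 8-half chunk
  s_a_padded[row][col]                     = __float2half(0.0f);
  s_a_swizzled[row][swizzle_col(row, col)] = __float2half(0.0f);
}
```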

- NVIDIA GeForce RTX 3080 Laptop

-Tested on an NVIDIA GeForce RTX 3080 Laptop, using mma4x4_warp4x4 (16 MMA m16n16k16 ops, warp tile 64x64) plus Thread block swizzle, most cases match or even beat cuBLAS. [Click for performance data](#NV-RTX-3080)
+Tested on an NVIDIA GeForce RTX 3080 Laptop, using mma4x4_warp4x4 (16 WMMA m16n16k16 ops, warp tile 64x64) plus Thread block swizzle, most cases match or even beat cuBLAS. [Click for performance data](#NV-RTX-3080)
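
Thread block swizzle, mentioned above, remaps the launch-order block index over C so that blocks resident at the same time work on a narrow band of columns and reuse the same A/B tiles from L2. A minimal sketch of one common mapping follows (illustrative only; the function name, the panel-width parameter, and the divisibility assumption are placeholders, not this repo's launch logic):

```cuda
// Remap the launch-order block id so that C is walked in column panels of
// `swizzle_stride` blocks, improving L2 reuse of the A/B tiles shared by
// neighbouring blocks. Assumes the grid's N extent is a multiple of
// swizzle_stride.
__device__ __forceinline__ void block_swizzle(int grid_m, int swizzle_stride,
                                              int* block_m, int* block_n) {
  const int bid    = blockIdx.y * gridDim.x + blockIdx.x;  // launch-order id
  const int panel  = bid / (grid_m * swizzle_stride);      // which column panel
  const int within = bid % (grid_m * swizzle_stride);
  *block_m = within % grid_m;                  // walk M fast inside the panel
  *block_n = panel * swizzle_stride + within / grid_m;
}
```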

## Shared Memory Bank Conflicts

@@ -244,6 +244,108 @@ python3 hgemm.py --wmma --wmma-all # test all wmma kernels for all MNK
python3 hgemm.py --mma --mma-all # test all mma kernels for all MNK
```

Example 1:
```bash
python3 hgemm.py --M 16384 --N 16384 --K 8192 --i 10 --mma
Namespace(M=16384, N=16384, K=8192, warmup=2, iters=10, show_all=False, enable_mma=True, enable_wmma=False, enable_cuda=False, enable_mma_all=False, enable_wmma_all=False, enable_cuda_all=False, enable_torch=False, disable_cublas=False, sleep_duration=0.1, swizzle_factor=0.25)
Loading hgemm lib ...
pre allocate for fast profiling start, MAX_M=16384, MAX_N=16384, MAX_K=8192
pre allocate for fast profiling done, time: 21829.665184020996 ms
----------------------------------------------------------------------------------------------------------------------------------
M=16384, N=16384, K=8192, Warmup=2, Iters=10, 1/1
----------------------------------------------------------------------------------------------------------------------------------
--------------------------------------------------------------------MMA-----------------------------------------------------------
(mma2x4+warp4x4+stage3+swizzle): ['55.53125 ', '-34.4375 '], time:39.08894ms, swizzle: 4096, TFLOPS: 112.51(+0.00%)
(mma2x4+warp4x4+stage2+swizzle): ['55.53125 ', '-34.4375 '], time:38.40720ms, swizzle: 4096, TFLOPS: 114.51(+1.78%)
(mma2x4+warp4x4+stage3+dsmem+swizzle): ['55.53125 ', '-34.4375 '], time:39.23299ms, swizzle: 4096, TFLOPS: 112.10
(mma2x4+warp4x4+stage2+dsmem+swizzle): ['55.53125 ', '-34.4375 '], time:38.20564ms, swizzle: 4096, TFLOPS: 115.12(+0.53%)
(mma2x4+warp4x4x2+stage4+dsmem+swizzle): ['55.53125 ', '-34.4375 '], time:38.67657ms, swizzle: 4096, TFLOPS: 113.71
(mma2x4+warp4x4x2+stage3+dsmem+swizzle): ['55.53125 ', '-34.4375 '], time:40.10882ms, swizzle: 4096, TFLOPS: 109.65
(mma2x4+warp4x4x2+stage2+dsmem+swizzle): ['55.53125 ', '-34.4375 '], time:38.44747ms, swizzle: 4096, TFLOPS: 114.39
(cublas): ['55.53125 ', '-34.4375 '], time:37.43820ms, swizzle: NOOP, TFLOPS: 117.47(+2.05%)
----------------------------------------------------------------------------------------------------------------------------------
```
Example 2:
```bash
python3 hgemm.py --M 4096 --N 4096 --K 4096 --mma-all
Namespace(M=4096, N=4096, K=4096, warmup=2, iters=10, show_all=False, enable_mma=False, enable_wmma=False, enable_cuda=False, enable_mma_all=True, enable_wmma_all=False, enable_cuda_all=False, enable_torch=False, disable_cublas=False, sleep_duration=0.1, swizzle_factor=0.25)
Loading hgemm lib ...
pre allocate for fast profiling start, MAX_M=4096, MAX_N=4096, MAX_K=4096
pre allocate for fast profiling done, time: 2056.009292602539 ms
----------------------------------------------------------------------------------------------------------------------------------
M=4096, N=4096, K=4096, Warmup=2, Iters=10, 1/1
----------------------------------------------------------------------------------------------------------------------------------
--------------------------------------------------------------------MMA-----------------------------------------------------------
(mma2x4+warp4x4): ['131.625 ', '23.59375 '], time:1.412987ms, swizzle: NOOP, TFLOPS: 97.27 (+0.00%)
(mma2x4+warp4x4+stage3): ['131.625 ', '23.59375 '], time:1.343512ms, swizzle: NOOP, TFLOPS: 102.30(+5.17%)
(mma2x4+warp4x4+stage2): ['131.625 ', '23.59375 '], time:1.326799ms, swizzle: NOOP, TFLOPS: 103.59(+1.26%)
(mma2x4+warp4x4+stage3+dsmem): ['131.625 ', '23.59375 '], time:1.350784ms, swizzle: NOOP, TFLOPS: 101.75
(mma2x4+warp4x4+stage2+dsmem): ['131.625 ', '23.59375 '], time:1.326084ms, swizzle: NOOP, TFLOPS: 103.64(+0.05%)
(mma2x4+warp4x4x2+stage4+dsmem): ['131.625 ', '23.59375 '], time:1.324439ms, swizzle: NOOP, TFLOPS: 103.77(+0.12%)
(mma2x4+warp4x4x2+stage3+dsmem): ['131.625 ', '23.59375 '], time:1.369738ms, swizzle: NOOP, TFLOPS: 100.34
(mma2x4+warp4x4x2+stage2+dsmem): ['131.625 ', '23.59375 '], time:1.299858ms, swizzle: NOOP, TFLOPS: 105.73(+1.89%)
(mma2x4+warp4x4+stage3+swizzle): ['131.625 ', '23.59375 '], time:1.344513ms, swizzle: 1024, TFLOPS: 102.22
(mma2x4+warp4x4+stage2+swizzle): ['131.625 ', '23.59375 '], time:1.324009ms, swizzle: 1024, TFLOPS: 103.81
(mma2x4+warp4x4+stage3+dsmem+swizzle): ['131.625 ', '23.59375 '], time:1.349854ms, swizzle: 1024, TFLOPS: 101.82
(mma2x4+warp4x4+stage2+dsmem+swizzle): ['131.625 ', '23.59375 '], time:1.318955ms, swizzle: 1024, TFLOPS: 104.20
(mma2x4+warp4x4x2+stage4+dsmem+swizzle): ['131.625 ', '23.59375 '], time:1.318430ms, swizzle: 1024, TFLOPS: 104.24
(mma2x4+warp4x4x2+stage3+dsmem+swizzle): ['131.625 ', '23.59375 '], time:1.371240ms, swizzle: 1024, TFLOPS: 100.23
(mma2x4+warp4x4x2+stage2+dsmem+swizzle): ['131.625 ', '23.59375 '], time:1.300096ms, swizzle: 1024, TFLOPS: 105.71
(cublas): ['131.625 ', '23.59375 '], time:1.289224ms, swizzle: NOOP, TFLOPS: 106.61(+0.82%)
----------------------------------------------------------------------------------------------------------------------------------
```
Example 3:
```bash
python3 hgemm.py --M 4096 --N 4096 --K 4096 --mma-all --wmma-all --cuda-all
Namespace(M=4096, N=4096, K=4096, warmup=2, iters=10, show_all=False, enable_mma=False, enable_wmma=False, enable_cuda=False, enable_mma_all=True, enable_wmma_all=True, enable_cuda_all=True, enable_torch=False, disable_cublas=False, sleep_duration=0.1, swizzle_factor=0.25)
Loading hgemm lib ...
pre allocate for fast profiling start, MAX_M=4096, MAX_N=4096, MAX_K=4096
pre allocate for fast profiling done, time: 2048.0010509490967 ms
----------------------------------------------------------------------------------------------------------------------------------
M=4096, N=4096, K=4096, Warmup=2, Iters=10, 1/1
----------------------------------------------------------------------------------------------------------------------------------
(naive): ['-3.5371093', '-101.0 '], time:37.66887ms, swizzle: NOOP, TFLOPS: 3.65 (+0.00%)
(f16x8pack+t8x8+bcf): ['-3.5371093', '-101.0 '], time:2.811360ms, swizzle: NOOP, TFLOPS: 48.89 (+1239.88%)
(f16x8pack+t8x8+dbuf): ['-3.5371093', '-101.0 '], time:2.815437ms, swizzle: NOOP, TFLOPS: 48.82
(f16x8pack+t8x8+k16+dbuf): ['-3.5371093', '-101.0 '], time:2.634835ms, swizzle: NOOP, TFLOPS: 52.16 (+6.70%)
--------------------------------------------------------------------WMMA----------------------------------------------------------
(mma4x2): ['-3.3847656', '-101.375 '], time:2.942705ms, swizzle: NOOP, TFLOPS: 46.70
(mma4x2+warp2x4): ['-3.3847656', '-101.375 '], time:1.817488ms, swizzle: NOOP, TFLOPS: 75.62 (+44.97%)
(mma4x2+warp2x4+stage3): ['-3.3847656', '-101.375 '], time:1.355123ms, swizzle: NOOP, TFLOPS: 101.42(+34.12%)
(mma4x2+warp2x4+stage2): ['-3.3847656', '-101.375 '], time:1.343965ms, swizzle: NOOP, TFLOPS: 102.26(+0.83%)
(mma4x2+warp2x4+stage3+dsmem): ['-3.3847656', '-101.375 '], time:1.342964ms, swizzle: NOOP, TFLOPS: 102.34(+0.07%)
(mma4x2+warp2x4+stage2+dsmem): ['-3.3847656', '-101.375 '], time:1.343178ms, swizzle: NOOP, TFLOPS: 102.32
(mma4x2+warp2x4+stage3+swizzle): ['-3.3847656', '-101.375 '], time:1.345729ms, swizzle: 1024, TFLOPS: 102.13
(mma4x2+warp2x4+stage2+swizzle): ['-3.3847656', '-101.375 '], time:1.324367ms, swizzle: 1024, TFLOPS: 103.78(+1.40%)
(mma4x2+warp2x4+stage3+dsmem+swizzle): ['-3.3847656', '-101.375 '], time:1.351284ms, swizzle: 1024, TFLOPS: 101.71
(mma4x2+warp2x4+stage2+dsmem+swizzle): ['-3.3847656', '-101.375 '], time:1.324582ms, swizzle: 1024, TFLOPS: 103.76
(mma4x4+warp4x4+stage3+dsmem): ['-3.3847656', '-101.375 '], time:1.387619ms, swizzle: NOOP, TFLOPS: 99.05
(mma4x4+warp4x4+stage2+dsmem): ['-3.3847656', '-101.375 '], time:1.490569ms, swizzle: NOOP, TFLOPS: 92.21
(mma4x2+warp4x4+stage3+dsmem): ['-3.3847656', '-101.375 '], time:1.376056ms, swizzle: NOOP, TFLOPS: 99.88
(mma4x2+warp4x4+stage2+dsmem): ['-3.3847656', '-101.375 '], time:1.425576ms, swizzle: NOOP, TFLOPS: 96.41
(mma4x4+warp4x4+stage3+dsmem+swizzle): ['-3.3847656', '-101.375 '], time:1.395106ms, swizzle: 1024, TFLOPS: 98.52
(mma4x4+warp4x4+stage2+dsmem+swizzle): ['-3.3847656', '-101.375 '], time:1.414942ms, swizzle: 1024, TFLOPS: 97.13
(mma4x2+warp4x4+stage3+dsmem+swizzle): ['-3.3847656', '-101.375 '], time:1.377010ms, swizzle: 1024, TFLOPS: 99.81
(mma4x2+warp4x4+stage2+dsmem+swizzle): ['-3.3847656', '-101.375 '], time:1.604509ms, swizzle: 1024, TFLOPS: 85.66
--------------------------------------------------------------------MMA-----------------------------------------------------------
(mma2x4+warp4x4): ['-3.3847656', '-101.375 '], time:1.412653ms, swizzle: NOOP, TFLOPS: 97.29
(mma2x4+warp4x4+stage3): ['-3.3847656', '-101.375 '], time:1.343774ms, swizzle: NOOP, TFLOPS: 102.28
(mma2x4+warp4x4+stage2): ['-3.3847656', '-101.375 '], time:1.326417ms, swizzle: NOOP, TFLOPS: 103.62
(mma2x4+warp4x4+stage3+dsmem): ['-3.3847656', '-101.375 '], time:1.351308ms, swizzle: NOOP, TFLOPS: 101.71
(mma2x4+warp4x4+stage2+dsmem): ['-3.3847656', '-101.375 '], time:1.326489ms, swizzle: NOOP, TFLOPS: 103.61
(mma2x4+warp4x4x2+stage4+dsmem): ['-3.3847656', '-101.375 '], time:1.324319ms, swizzle: NOOP, TFLOPS: 103.78(+0.00%)
(mma2x4+warp4x4x2+stage3+dsmem): ['-3.3847656', '-101.375 '], time:1.369786ms, swizzle: NOOP, TFLOPS: 100.34
(mma2x4+warp4x4x2+stage2+dsmem): ['-3.3847656', '-101.375 '], time:1.299762ms, swizzle: NOOP, TFLOPS: 105.74(+1.89%)
(mma2x4+warp4x4+stage3+swizzle): ['-3.3847656', '-101.375 '], time:1.344013ms, swizzle: 1024, TFLOPS: 102.26
(mma2x4+warp4x4+stage2+swizzle): ['-3.3847656', '-101.375 '], time:1.324701ms, swizzle: 1024, TFLOPS: 103.75
(mma2x4+warp4x4+stage3+dsmem+swizzle): ['-3.3847656', '-101.375 '], time:1.348972ms, swizzle: 1024, TFLOPS: 101.88
(mma2x4+warp4x4+stage2+dsmem+swizzle): ['-3.3847656', '-101.375 '], time:1.318597ms, swizzle: 1024, TFLOPS: 104.23
(mma2x4+warp4x4x2+stage4+dsmem+swizzle): ['-3.3847656', '-101.375 '], time:1.318240ms, swizzle: 1024, TFLOPS: 104.26
(mma2x4+warp4x4x2+stage3+dsmem+swizzle): ['-3.3847656', '-101.375 '], time:1.370477ms, swizzle: 1024, TFLOPS: 100.29
(mma2x4+warp4x4x2+stage2+dsmem+swizzle): ['-3.3847656', '-101.375 '], time:1.300477ms, swizzle: 1024, TFLOPS: 105.68
(cublas): ['-3.3847656', '-101.375 '], time:1.289367ms, swizzle: NOOP, TFLOPS: 106.59(+0.81%)
----------------------------------------------------------------------------------------------------------------------------------
```

## NVIDIA L20
<div id="NV-L20"></div>
