
Commit 28c12bd

[HGEMM] Add NVIDIA RTX 4090 benchmark (#119)
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
1 parent 2167679 commit 28c12bd

File tree

2 files changed: +51, -1 lines changed


README.md

Lines changed: 3 additions & 1 deletion
@@ -24,9 +24,11 @@
 |✔️|✔️|✔️|✔️|
 |**Reg Double Buffers**|**Block Swizzle**|**Warp Swizzle**|**Collective Store(Shfl)**|
 |✔️|✔️|✔️|✔️|
-|**Row Major(NN)**|**Col Major(TN)**|**SGEMM TF32**|**SMEM Swizzle/Permuted**|
+|**Row Major(NN)**|**Col Major(TN)**|**SGEMM TF32**|**SMEM Swizzle(Permute)**|
 |✔️|✔️|✔️||
 
+Currently, on NVIDIA L20, RTX 4090 and RTX 3090 Laptop, compared with cuBLAS's default Tensor Cores math algorithm `CUBLAS_GEMM_DEFAULT_TENSOR_OP`, the `HGEMM (WMMA and MMA)` implemented in this repo can achieve approximately `95%~99%` of its performance. Please check [hgemm benchmark](./hgemm) for more details.
+
 ## 📖 CUDA Kernel Directory (Common Interview Topics)
 - / = not supported now.
 - ✔️ = known to work and already supported.
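For context on the baseline named in the added paragraph: `CUBLAS_GEMM_DEFAULT_TENSOR_OP` is the algorithm selector passed to `cublasGemmEx`. Below is a minimal sketch of such a baseline call, assuming the usual row-major-via-transpose convention; the wrapper name and dimension handling are illustrative, not code from this commit.

```cuda
#include <cublas_v2.h>
#include <cuda_fp16.h>

// Hypothetical baseline wrapper: FP16 GEMM C = A * B with row-major
// M x K, K x N, M x N matrices, using the default Tensor Core algorithm.
void cublas_hgemm_baseline(cublasHandle_t handle, const half* A,
                           const half* B, half* C, int M, int N, int K) {
  const half alpha = __float2half(1.0f);
  const half beta  = __float2half(0.0f);
  // cuBLAS is column-major, so compute C^T = B^T * A^T to get row-major C.
  cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K,
               &alpha, B, CUDA_R_16F, N,
               A, CUDA_R_16F, K,
               &beta, C, CUDA_R_16F, N,
               CUBLAS_COMPUTE_16F,
               CUBLAS_GEMM_DEFAULT_TENSOR_OP);  // the algorithm benchmarked against
}
```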

hgemm/README.md

Lines changed: 48 additions & 0 deletions
@@ -80,6 +80,53 @@ python3 hgemm.py --M 16384 --N 16384 --K 8192 --mma-all --wmma-all --cuda-all
 python3 hgemm.py --mma-all --wmma-all --cuda-all
 ```
 
+### NVIDIA GeForce RTX 4090
+On the NVIDIA RTX 4090 (FP16 Tensor Core peak of 330 TFLOPS), WMMA (m16n16k16) performs better than MMA (m16n8k16). For most MNK shapes, this repo's implementation reaches 95%~99% of cuBLAS performance, and some cases exceed cuBLAS. With this repo's kernels on the RTX 4090, WMMA is the better choice for large GEMMs (MNK >= 8192), while MMA does better for small ones.
+```bash
+----------------------------------------------------------------------------------------------------------------------------------
+M=16384, N=16384, K=8192, Warmup=2, Iters=10, 1/1
+----------------------------------------------------------------------------------------------------------------------------------
+--------------------------------------------------------------------WMMA----------------------------------------------------------
+(wmma4x2): ['-137.375 ', '53.65625 '], time:90.05668ms, swizzle: NOOP, TFLOPS: 48.84 (+0.00%)
+(wmma4x2+warp2x4): ['-137.375 ', '53.65625 '], time:37.53635ms, swizzle: NOOP, TFLOPS: 117.17(+139.92%)
+(wmma4x2+warp2x4+stage3): ['-137.375 ', '53.65625 '], time:25.96564ms, swizzle: NOOP, TFLOPS: 169.38(+44.56%)
+(wmma4x2+warp2x4+stage2): ['-137.375 ', '53.65625 '], time:25.21226ms, swizzle: NOOP, TFLOPS: 174.44(+2.99%)
+(wmma4x2+warp2x4+stage3+swizzle): ['-137.375 ', '53.65625 '], time:22.99013ms, swizzle: 4096, TFLOPS: 191.30(+9.67%)
+(wmma4x2+warp2x4+stage2+swizzle): ['-137.375 ', '53.65625 '], time:22.91676ms, swizzle: 4096, TFLOPS: 191.91(+0.32%)
+(wmma4x2+warp2x4+stage2+dsmem+swizzle): ['-137.375 ', '53.65625 '], time:22.78118ms, swizzle: 4096, TFLOPS: 193.06(+0.60%)
+(wmma4x4+warp4x4+stage3+dsmem): ['-137.375 ', '53.65625 '], time:18.66145ms, swizzle: NOOP, TFLOPS: 235.68(+22.08%)
+(wmma4x4+warp4x4+stage3+dsmem+swizzle): ['-137.375 ', '53.65625 '], time:18.16847ms, swizzle: 4096, TFLOPS: 242.07(+2.71%)
+(wmma4x4+warp4x4+stage2+dsmem+swizzle): ['-137.375 ', '53.65625 '], time:18.11864ms, swizzle: 4096, TFLOPS: 242.74(+0.28%)
+(cublas): ['-137.375 ', '53.65625 '], time:18.07777ms, swizzle: NOOP, TFLOPS: 243.28(+0.23%)
+----------------------------------------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------------------------------------
+M=8192, N=8192, K=8192, Warmup=2, Iters=10, 1/1
+----------------------------------------------------------------------------------------------------------------------------------
+--------------------------------------------------------------------WMMA----------------------------------------------------------
+(wmma4x2): ['11.453125 ', '-1.0664062'], time:18.48518ms, swizzle: NOOP, TFLOPS: 59.48 (+0.00%)
+(wmma4x2+warp2x4): ['11.453125 ', '-1.0664062'], time:9.354352ms, swizzle: NOOP, TFLOPS: 117.54(+97.61%)
+(wmma4x2+warp2x4+stage3): ['11.453125 ', '-1.0664062'], time:5.835342ms, swizzle: NOOP, TFLOPS: 188.42(+60.31%)
+(wmma4x2+warp2x4+stage2): ['11.453125 ', '-1.0664062'], time:5.795311ms, swizzle: NOOP, TFLOPS: 189.72(+0.69%)
+(wmma4x2+warp2x4+stage3+dsmem): ['11.453125 ', '-1.0664062'], time:5.795168ms, swizzle: NOOP, TFLOPS: 189.73(+0.00%)
+(wmma4x2+warp2x4+stage3+swizzle): ['11.453125 ', '-1.0664062'], time:5.384325ms, swizzle: 2048, TFLOPS: 204.21(+7.63%)
+(wmma4x4+warp4x4+stage3+dsmem): ['11.453125 ', '-1.0664062'], time:4.254937ms, swizzle: NOOP, TFLOPS: 258.41(+26.54%)
+(cublas): ['11.421875 ', '-1.3203125'], time:4.288864ms, swizzle: NOOP, TFLOPS: 256.36
+----------------------------------------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------------------------------------
+M=4096, N=4096, K=4096, Warmup=2, Iters=10, 1/1
+----------------------------------------------------------------------------------------------------------------------------------
+--------------------------------------------------------------------WMMA----------------------------------------------------------
+(wmma4x2): ['-9.0 ', '-144.875 '], time:2.341437ms, swizzle: NOOP, TFLOPS: 58.70 (+0.00%)
+(wmma4x2+warp2x4): ['-9.0 ', '-144.875 '], time:1.237440ms, swizzle: NOOP, TFLOPS: 111.07(+89.22%)
+(wmma4x2+warp2x4+stage3): ['-9.0 ', '-144.875 '], time:0.725293ms, swizzle: NOOP, TFLOPS: 189.49(+70.61%)
+(wmma4x2+warp2x4+stage3+dsmem): ['-9.0 ', '-144.875 '], time:0.723266ms, swizzle: NOOP, TFLOPS: 190.03(+0.28%)
+(wmma4x2+warp2x4+stage3+swizzle): ['-9.0 ', '-144.875 '], time:0.702548ms, swizzle: 2048, TFLOPS: 195.63(+2.95%)
+(wmma4x2+warp2x4+stage3+dsmem+swizzle): ['-9.0 ', '-144.875 '], time:0.702190ms, swizzle: 2048, TFLOPS: 195.73(+0.05%)
+(wmma4x4+warp4x4+stage3+dsmem): ['-9.0 ', '-144.875 '], time:0.556564ms, swizzle: NOOP, TFLOPS: 246.94(+26.17%)
+(cublas): ['-9.0 ', '-144.875 '], time:0.539851ms, swizzle: NOOP, TFLOPS: 254.59(+3.10%)
+----------------------------------------------------------------------------------------------------------------------------------
+```
+
 ### NVIDIA GeForce RTX 3080 Laptop
 
 Tested on an NVIDIA GeForce RTX 3080 Laptop using mma4x4_warp4x4 (16 WMMA m16n16k16 ops, warp tile 64x64) plus thread block swizzle: most cases match or even beat cuBLAS. Laptop performance numbers are unstable, though, so treat this section as a rough reference rather than hard data.
@@ -96,6 +143,7 @@ python3 hgemm.py --wmma-all
 (cublas): ['68.375 ', '-2.234375 '], time:104.2092ms, swizzle: NOOP, TFLOPS: 42.20
 ----------------------------------------------------------------------------------------------------------------------------------
 ```
+
 ## Test Commands
 
 ```bash
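The WMMA (m16n16k16) shape mentioned in the RTX 4090 notes refers to the `nvcuda::wmma` fragment size: each warp multiplies 16x16 half tiles over a K-depth of 16 per `mma_sync`. Here is a stripped-down single-tile sketch of that primitive, without the warp tiling, multi-stage pipelining, or swizzling that the repo's tuned kernels add:

```cuda
#include <mma.h>
#include <cuda_fp16.h>
using namespace nvcuda;

// One warp computes a single 16x16 tile of C = A * B (row-major),
// accumulating in half as in HGEMM. A: 16 x K, B: K x 16, C: 16 x 16.
__global__ void wmma_m16n16k16_tile(const half* A, const half* B, half* C, int K) {
  wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
  wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag;
  wmma::fragment<wmma::accumulator, 16, 16, 16, half> c_frag;
  wmma::fill_fragment(c_frag, __float2half(0.0f));
  for (int k = 0; k < K; k += 16) {
    wmma::load_matrix_sync(a_frag, A + k, K);        // A tile at (row 0, col k)
    wmma::load_matrix_sync(b_frag, B + k * 16, 16);  // B tile at (row k, col 0)
    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);  // c += a * b on Tensor Cores
  }
  wmma::store_matrix_sync(C, c_frag, 16, wmma::mem_row_major);
}
// Launch with a single warp: wmma_m16n16k16_tile<<<1, 32>>>(dA, dB, dC, K);
```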

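The TFLOPS column in the logs above follows the standard GEMM flop count of `2 * M * N * K` (one multiply and one add per accumulated term). Assuming that convention, the cuBLAS row of the first RTX 4090 table can be reproduced by hand:

```cuda
#include <cstdio>

// Recompute the cuBLAS TFLOPS figure from the M=16384, N=16384, K=8192 log.
int main() {
  double M = 16384, N = 16384, K = 8192;
  double time_s = 18.07777e-3;             // measured time from the log
  double flops  = 2.0 * M * N * K;         // one multiply + one add per MAC
  printf("TFLOPS: %.2f\n", flops / time_s / 1e12);  // prints ~243.28
  return 0;
}
```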
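Lastly, the `swizzle: 2048/4096` entries above correspond to the thread block swizzle feature listed in the README table. A common form of this technique remaps the linear block index so that blocks covering nearby C tiles launch together and reuse L2-resident A/B data; the grouped mapping below is a sketch under that assumption, and the repo's exact scheme may differ:

```cuda
// Hypothetical grouped block-swizzle: map a linear block id to a
// (tile_m, tile_n) coordinate so that `group_m` rows of C tiles are
// walked together along N instead of scanning one full row at a time.
__host__ __device__ void swizzle_tile(int pid, int tiles_m, int tiles_n,
                                      int group_m, int* tile_m, int* tile_n) {
  int pids_per_group = group_m * tiles_n;       // blocks in one group of rows
  int group_id = pid / pids_per_group;          // which group this block is in
  int first_m  = group_id * group_m;            // first tile row of the group
  int rem_m    = tiles_m - first_m;             // the last group may be short
  int size_m   = rem_m < group_m ? rem_m : group_m;
  *tile_m = first_m + (pid % pids_per_group) % size_m;
  *tile_n = (pid % pids_per_group) / size_m;    // advance along N within group
}
```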