Commit ed1d100

[HGEMM] Update HGEMM L20/4090 Bench (#137)
* Update README.md
1 parent 2d08428 commit ed1d100

File tree

2 files changed: +31 −61 lines changed

README.md

Lines changed: 20 additions & 2 deletions
@@ -21,8 +21,8 @@
<div id="hgemm-sgemm"></div>

<div align='left'>
- <img src='https://github.com/user-attachments/assets/a0039200-cd9e-4ae6-be13-422fff75dd2b' height="225px" width="403px">
- <img src='https://github.com/user-attachments/assets/c7d65fe5-9fb9-49a8-b962-a6c09bcc030a' height="225px" width="403px">
+ <img src='https://github.com/user-attachments/assets/89bac543-7272-44cd-b616-54df8ca23a91' height="225px" width="403px">
+ <img src='https://github.com/user-attachments/assets/d8d7380b-4271-41f6-964a-ac3fa81f7f4c' height="225px" width="403px">
</div>

Currently, on NVIDIA L20, RTX 4090 and RTX 3090 Laptop, compared with cuBLAS's default Tensor Cores math algorithm `CUBLAS_GEMM_DEFAULT_TENSOR_OP`, the `HGEMM (WMMA/MMA)` implemented in this repo (`blue`🔵) can achieve `95%~99%` of its (`orange`🟠) performance. Please check [hgemm benchmark](./hgemm) for more details.
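
For reference, below is a minimal sketch of what such an FP16 cuBLAS baseline call with the default Tensor Core algorithm (`CUBLAS_GEMM_DEFAULT_TENSOR_OP`) typically looks like. The function name, pointers, and row-major handling are illustrative assumptions, not code from this repo.

```C++
#include <cuda_fp16.h>
#include <cublas_v2.h>

// Hedged sketch of an FP16 GEMM baseline via cublasGemmEx with the default
// Tensor Core algorithm. d_a, d_b, d_c are device pointers to row-major
// A[M,K], B[K,N], C[M,N]; all names here are illustrative.
void cublas_hgemm_baseline(cublasHandle_t handle, int M, int N, int K,
                           const half* d_a, const half* d_b, half* d_c) {
  const half alpha = __float2half(1.0f);
  const half beta  = __float2half(0.0f);
  // cuBLAS is column-major, so a row-major C = A * B is commonly issued as
  // the equivalent column-major C^T = B^T * A^T (operands and dims swapped).
  cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
               N, M, K,
               &alpha,
               d_b, CUDA_R_16F, N,
               d_a, CUDA_R_16F, K,
               &beta,
               d_c, CUDA_R_16F, N,
               CUBLAS_COMPUTE_16F,
               CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
```
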
@@ -42,6 +42,24 @@ Currently, on NVIDIA L20, RTX 4090 and RTX 3090 Laptop, compared with cuBLAS's d

<!---
+ ![NVIDIA_L20_NN+TN](https://github.com/user-attachments/assets/89bac543-7272-44cd-b616-54df8ca23a91)
+ ![NVIDIA_GeForce_RTX_4090_NN+TN](https://github.com/user-attachments/assets/d8d7380b-4271-41f6-964a-ac3fa81f7f4c)
+
+ <div align='left'>
+ <img src='https://github.com/user-attachments/assets/89bac543-7272-44cd-b616-54df8ca23a91' width="805px">
+ </div>
+
+ <div align='left'>
+ <img src='https://github.com/user-attachments/assets/89bac543-7272-44cd-b616-54df8ca23a91' height="225px" width="403px">
+ <img src='https://github.com/user-attachments/assets/d8d7380b-4271-41f6-964a-ac3fa81f7f4c' height="225px" width="403px">
+ </div>
+
+ <div align='left'>
+ <img src='https://github.com/user-attachments/assets/a0039200-cd9e-4ae6-be13-422fff75dd2b' height="225px" width="403px">
+ <img src='https://github.com/user-attachments/assets/c7d65fe5-9fb9-49a8-b962-a6c09bcc030a' height="225px" width="403px">
+ </div>
+
+
![cuda-learn-notes](https://github.com/DefTruth/CUDA-Learn-Note/assets/31974251/882271fe-ab60-4b0e-9440-2e0fa3c0fb6f)
![cuda-learn-notes](https://github.com/user-attachments/assets/b2578723-b7a7-4d8f-bcd1-5008947b808a)
![L20](https://github.com/user-attachments/assets/a0039200-cd9e-4ae6-be13-422fff75dd2b)

hgemm/README.md

Lines changed: 11 additions & 59 deletions
@@ -52,7 +52,7 @@
git submodule update --init --recursive --force
```

- **Python**: supports direct testing via python scripts
+ **Python**: supports direct testing via Python scripts

```bash
# Test only the Ada architecture; if not specified, all architectures are compiled by default, which takes longer: Volta, Ampere, Ada, Hopper, ...
@@ -75,7 +75,7 @@ python3 hgemm.py --mma-all --plot --topk 8
python3 hgemm.py --cute-tn --mma --plot
```

- **C++**: C++ testing currently only supports CuTe HGEMM and cuBLAS HGEMM. Performance numbers from the C++ binary are slightly better than those from the Python tests, possibly because the torch binding introduces some overhead.
+ **C++**: The HGEMM benchmark also supports C++ testing, but it currently only compares this repo's CuTe HGEMM TN implementation against cuBLAS HGEMM TN. Performance numbers from the C++ binary are slightly better than those from the Python tests, possibly because the PyTorch Python binding introduces some extra overhead.
```bash
make
./hgemm_cute.bin
@@ -132,16 +132,18 @@ M N K = 12800 12800 12800, Time = 0.03695514 0.03705610 0.03711386 s, A

### NVIDIA L20

- The current best implementation on the L20 (theoretical Tensor Cores FP16 throughput: 119.5 TFLOPS) reaches roughly 99% of cuBLAS performance overall. With the WMMA API it reaches about 95%~98% of cuBLAS (105-113 TFLOPS vs 105-115 TFLOPS); with the MMA API it reaches 115 TFLOPS and beats cuBLAS in some cases. The CuTe version of HGEMM is basically on par with cuBLAS, beats it in some cases, and reaches 116-117 TFLOPS. Bank conflicts are currently mitigated with padding and smem swizzle: the NN layout uses smem padding to mitigate bank conflicts, while the TN layout eliminates them with cutlass cute's smem swizzle/permuted layouts.
+ The current best implementation on the L20 (theoretical Tensor Cores FP16 throughput: 119.5 TFLOPS) reaches roughly 99% of cuBLAS performance overall. With the WMMA API it reaches about 95%~98% of cuBLAS (105-113 TFLOPS vs 105-115 TFLOPS); with the MMA API it reaches 115 TFLOPS and beats cuBLAS in some cases. The CuTe version of HGEMM is basically on par with cuBLAS, beats it in some cases, and reaches 116-117 TFLOPS. Bank conflicts are currently mitigated with SMEM Padding and SMEM Swizzle: the NN layout uses SMEM Padding to mitigate bank conflicts, while the TN layout eliminates them with cutlass cute's SMEM Swizzle.
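
As a rough illustration of the SMEM Padding idea mentioned above for the NN path, the sketch below pads the inner dimension of the shared-memory tiles. The tile sizes, PAD value, and kernel name are illustrative assumptions, not this repo's exact configuration.

```C++
#include <cuda_fp16.h>

// Hedged sketch of shared-memory padding against bank conflicts.
// BM/BN/BK and PAD are illustrative, not this repo's exact tile config.
constexpr int BM = 128, BN = 128, BK = 16;
constexpr int PAD = 8;  // extra half elements per row

__global__ void hgemm_nn_padded_smem_sketch(const half* A, const half* B,
                                            half* C, int M, int N, int K) {
  // Padding the fastest-moving dimension shifts each row's starting shared
  // memory bank, so a warp reading a "column" of the tile no longer has all
  // 32 lanes hitting the same bank.
  __shared__ half s_a[BM][BK + PAD];
  __shared__ half s_b[BK][BN + PAD];
  // ... global->shared loads, WMMA/MMA tile compute, and epilogue go here ...
}
```
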

<div id="NV-L20"></div>


<!---
![L20](https://github.com/user-attachments/assets/a0039200-cd9e-4ae6-be13-422fff75dd2b)
+ ![L20](./NVIDIA_L20.png)
+
--->
+ ![NVIDIA_L20_NN+TN](https://github.com/user-attachments/assets/89bac543-7272-44cd-b616-54df8ca23a91)

- ![L20](./NVIDIA_L20.png)

- WMMA: Up to 113.76 TFLOPS, 113.83/119.5=95.25% TFLOPS utilization, 113.83/116.25=97.91% cuBLAS performance.
- MMA: Up to 115.12 TFLOPS, 115.12/119.5=96.33% TFLOPS utilization, 115.12/116.25=99.03% cuBLAS performance.
@@ -168,7 +170,7 @@ python3 hgemm.py --M 16384 --N 16384 --K 8192 --mma-all --wmma-all --cuda-all
```
Full MNK sweep test command (tip: performance numbers are more accurate when each MNK is tested individually)
```bash
- python3 hgemm.py --mma-all --wmma-all --cuda-all
+ python3 hgemm.py --cute-tn --mma --plot --dir tmp --tag NN+TN --i 20 --wmma-all
```

### NVIDIA GeForce RTX 4090
@@ -178,51 +180,10 @@ python3 hgemm.py --mma-all --wmma-all --cuda-all
![4090](https://github.com/user-attachments/assets/c7d65fe5-9fb9-49a8-b962-a6c09bcc030a)
--->

- ![4090](./NVIDIA_GeForce_RTX_4090.png)
+ ![NVIDIA_GeForce_RTX_4090_NN+TN](https://github.com/user-attachments/assets/d8d7380b-4271-41f6-964a-ac3fa81f7f4c)

```bash
- ----------------------------------------------------------------------------------------------------------------------------------
- M=16384, N=16384, K=8192, Warmup=2, Iters=10, 1/1
- ----------------------------------------------------------------------------------------------------------------------------------
- --------------------------------------------------------------------WMMA----------------------------------------------------------
- (wmma4x2): ['-137.375 ', '53.65625 '], time:90.05668ms, swizzle: NOOP, TFLOPS: 48.84 (+0.00%)
- (wmma4x2+warp2x4): ['-137.375 ', '53.65625 '], time:37.53635ms, swizzle: NOOP, TFLOPS: 117.17(+139.92%)
- (wmma4x2+warp2x4+stage3): ['-137.375 ', '53.65625 '], time:25.96564ms, swizzle: NOOP, TFLOPS: 169.38(+44.56%)
- (wmma4x2+warp2x4+stage2): ['-137.375 ', '53.65625 '], time:25.21226ms, swizzle: NOOP, TFLOPS: 174.44(+2.99%)
- (wmma4x2+warp2x4+stage3+swizzle): ['-137.375 ', '53.65625 '], time:22.99013ms, swizzle: 4096, TFLOPS: 191.30(+9.67%)
- (wmma4x2+warp2x4+stage2+swizzle): ['-137.375 ', '53.65625 '], time:22.91676ms, swizzle: 4096, TFLOPS: 191.91(+0.32%)
- (wmma4x2+warp2x4+stage2+dsmem+swizzle): ['-137.375 ', '53.65625 '], time:22.78118ms, swizzle: 4096, TFLOPS: 193.06(+0.60%)
- (wmma4x4+warp4x4+stage3+dsmem): ['-137.375 ', '53.65625 '], time:18.66145ms, swizzle: NOOP, TFLOPS: 235.68(+22.08%)
- (wmma4x4+warp4x4+stage3+dsmem+swizzle): ['-137.375 ', '53.65625 '], time:18.16847ms, swizzle: 4096, TFLOPS: 242.07(+2.71%)
- (wmma4x4+warp4x4+stage2+dsmem+swizzle): ['-137.375 ', '53.65625 '], time:18.11864ms, swizzle: 4096, TFLOPS: 242.74(+0.28%)
- (cublas): ['-137.375 ', '53.65625 '], time:18.07777ms, swizzle: NOOP, TFLOPS: 243.28(+0.23%)
- ----------------------------------------------------------------------------------------------------------------------------------
- ----------------------------------------------------------------------------------------------------------------------------------
- M=8192, N=8192, K=8192, Warmup=2, Iters=10, 1/1
- ----------------------------------------------------------------------------------------------------------------------------------
- --------------------------------------------------------------------WMMA----------------------------------------------------------
- (wmma4x2): ['11.453125 ', '-1.0664062'], time:18.48518ms, swizzle: NOOP, TFLOPS: 59.48 (+0.00%)
- (wmma4x2+warp2x4): ['11.453125 ', '-1.0664062'], time:9.354352ms, swizzle: NOOP, TFLOPS: 117.54(+97.61%)
- (wmma4x2+warp2x4+stage3): ['11.453125 ', '-1.0664062'], time:5.835342ms, swizzle: NOOP, TFLOPS: 188.42(+60.31%)
- (wmma4x2+warp2x4+stage2): ['11.453125 ', '-1.0664062'], time:5.795311ms, swizzle: NOOP, TFLOPS: 189.72(+0.69%)
- (wmma4x2+warp2x4+stage3+dsmem): ['11.453125 ', '-1.0664062'], time:5.795168ms, swizzle: NOOP, TFLOPS: 189.73(+0.00%)
- (wmma4x2+warp2x4+stage3+swizzle): ['11.453125 ', '-1.0664062'], time:5.384325ms, swizzle: 2048, TFLOPS: 204.21(+7.63%)
- (wmma4x4+warp4x4+stage3+dsmem): ['11.453125 ', '-1.0664062'], time:4.254937ms, swizzle: NOOP, TFLOPS: 258.41(+26.54%)
- (cublas): ['11.421875 ', '-1.3203125'], time:4.288864ms, swizzle: NOOP, TFLOPS: 256.36
- ----------------------------------------------------------------------------------------------------------------------------------
- ----------------------------------------------------------------------------------------------------------------------------------
- M=4096, N=4096, K=4096, Warmup=2, Iters=10, 1/1
- ----------------------------------------------------------------------------------------------------------------------------------
- --------------------------------------------------------------------WMMA----------------------------------------------------------
- (wmma4x2): ['-9.0 ', '-144.875 '], time:2.341437ms, swizzle: NOOP, TFLOPS: 58.70 (+0.00%)
- (wmma4x2+warp2x4): ['-9.0 ', '-144.875 '], time:1.237440ms, swizzle: NOOP, TFLOPS: 111.07(+89.22%)
- (wmma4x2+warp2x4+stage3): ['-9.0 ', '-144.875 '], time:0.725293ms, swizzle: NOOP, TFLOPS: 189.49(+70.61%)
- (wmma4x2+warp2x4+stage3+dsmem): ['-9.0 ', '-144.875 '], time:0.723266ms, swizzle: NOOP, TFLOPS: 190.03(+0.28%)
- (wmma4x2+warp2x4+stage3+swizzle): ['-9.0 ', '-144.875 '], time:0.702548ms, swizzle: 2048, TFLOPS: 195.63(+2.95%)
- (wmma4x2+warp2x4+stage3+dsmem+swizzle): ['-9.0 ', '-144.875 '], time:0.702190ms, swizzle: 2048, TFLOPS: 195.73(+0.05%)
- (wmma4x4+warp4x4+stage3+dsmem): ['-9.0 ', '-144.875 '], time:0.556564ms, swizzle: NOOP, TFLOPS: 246.94(+26.17%)
- (cublas): ['-9.0 ', '-144.875 '], time:0.539851ms, swizzle: NOOP, TFLOPS: 254.59(+3.10%)
- ----------------------------------------------------------------------------------------------------------------------------------
+ python3 hgemm.py --cute-tn --mma --plot --dir tmp --tag NN+TN --i 20 --wmma-all
```
### NVIDIA GeForce RTX 3080 Laptop
@@ -232,16 +193,7 @@ python3 hgemm.py --mma-all --wmma-all --cuda-all
![](./NVIDIA_GeForce_RTX_3080_Laptop_GPU_WSL2.png)

```bash
- python3 hgemm.py --wmma-all
- ----------------------------------------------------------------------------------------------------------------------------------
- M=16384, N=16384, K=8192, Warmup=5, Iters=20, 27/27
- ----------------------------------------------------------------------------------------------------------------------------------
- (wmma4x4+warp4x4+stage3+dsmem): ['68.375 ', '-2.234375 '], time:96.91984ms, swizzle: NOOP, TFLOPS: 45.38 (+0.00%)
- (wmma4x4+warp4x4+stage2+dsmem): ['68.375 ', '-2.234375 '], time:102.8722ms, swizzle: NOOP, TFLOPS: 42.75
- (wmma4x4+warp4x4+stage3+dsmem+swizzle): ['68.375 ', '-2.234375 '], time:85.65800ms, swizzle: 4096, TFLOPS: 51.34 (+13.15%)
- (wmma4x4+warp4x4+stage2+dsmem+swizzle): ['68.375 ', '-2.234375 '], time:95.70884ms, swizzle: 4096, TFLOPS: 45.95
- (cublas): ['68.375 ', '-2.234375 '], time:104.2092ms, swizzle: NOOP, TFLOPS: 42.20
- ----------------------------------------------------------------------------------------------------------------------------------
+ python3 hgemm.py --wmma-all --plot --dir tmp
```

@@ -291,7 +243,7 @@ NVIDIA's [article](https://developer.nvidia.com/blog/using-shared-memory-cuda-cc/
```C
cudaDeviceSetSharedMemConfig(cudaSharedMemBankSizeEightByte);
```
- Bank conflicts are currently mitigated with padding and smem swizzle: the NN layout uses smem padding to mitigate bank conflicts, while the TN layout eliminates them with cutlass cute's smem swizzle/permuted layouts.
+ Bank conflicts are currently mitigated with SMEM Padding and SMEM Swizzle: the NN layout uses SMEM Padding to mitigate bank conflicts, while the TN layout eliminates them with cutlass cute's SMEM Swizzle.
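
For the TN path, below is a hedged sketch of the kind of CuTe swizzled shared-memory layout this refers to. The `Swizzle<3,3,3>` parameters and the 8x64 half-precision atom are illustrative choices in the style of CUTLASS CuTe examples, not necessarily this repo's exact configuration.

```C++
#include <cute/tensor.hpp>
using namespace cute;

// Hedged sketch: a swizzled shared-memory layout atom for FP16 tiles, in the
// style of CuTe examples. The swizzle XORs selected row bits into the column
// bits of the offset, spreading a warp's accesses across the 32 banks without
// padding (so no shared memory is wasted).
using SmemLayoutAtom = decltype(composition(
    Swizzle<3, 3, 3>{},
    make_layout(make_shape(Int<8>{}, Int<64>{}),
                make_stride(Int<64>{}, Int<1>{}))));

// The atom is then tiled to the full shared-memory tile, e.g. (illustrative):
// using SmemLayoutA = decltype(tile_to_shape(SmemLayoutAtom{},
//                                            make_shape(Int<128>{}, Int<64>{})));
```
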

### Double Buffers
