git submodule update --init --recursive --force
```

- **Python**: supports direct testing via python scripts
+ **Python**: supports direct testing via Python scripts

```bash
# Only test the Ada architecture; if no arch is specified, all architectures are compiled by default, which takes a long time: Volta, Ampere, Ada, Hopper, ...
@@ -75,7 +75,7 @@ python3 hgemm.py --mma-all --plot --topk 8
python3 hgemm.py --cute-tn --mma --plot
```

- **C++**: C++ testing currently only supports CuTe HGEMM and cuBLAS HGEMM. The performance numbers from the C++ binary are slightly better than from the python tests, possibly because the torch binding introduces some overhead.
+ **C++**: The HGEMM benchmark also supports C++ testing, but it currently only compares this repo's CuTe HGEMM TN implementation against cuBLAS HGEMM TN. The performance numbers from the C++ binary are slightly better than from the Python tests, possibly because the PyTorch Python binding introduces some extra overhead.

```bash
make
./hgemm_cute.bin
@@ -132,16 +132,18 @@ M N K = 12800 12800 12800, Time = 0.03695514 0.03705610 0.03711386 s, A

### NVIDIA L20

- On the L20 (theoretical Tensor Cores FP16 throughput: 119.5 TFLOPS), the current best implementation reaches roughly 99% of cuBLAS performance overall. The WMMA API reaches roughly 95%~98% of cuBLAS (105-113 TFLOPS vs 105-115 TFLOPS); the MMA API reaches 115 TFLOPS and outperforms cuBLAS in some cases. The CuTe HGEMM is essentially on par with cuBLAS, outperforms it in some cases, and reaches 116-117 TFLOPS. Bank conflicts are currently mitigated with padding and smem swizzle: smem padding for the NN layout, and cutlass cute's smem swizzle/permuted layout for the TN layout, which eliminates bank conflicts.
+ On the L20 (theoretical Tensor Cores FP16 throughput: 119.5 TFLOPS), the current best implementation reaches roughly 99% of cuBLAS performance overall. The WMMA API reaches roughly 95%~98% of cuBLAS (105-113 TFLOPS vs 105-115 TFLOPS); the MMA API reaches 115 TFLOPS and outperforms cuBLAS in some cases. The CuTe HGEMM is essentially on par with cuBLAS, outperforms it in some cases, and reaches 116-117 TFLOPS. Bank conflicts are currently mitigated with SMEM Padding and SMEM Swizzle: SMEM Padding for the NN layout, and cutlass cute's SMEM Swizzle for the TN layout, which eliminates bank conflicts.
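As a rough illustration of the NN-layout trick, here is a minimal CUDA sketch of SMEM Padding. The tile sizes, pad width, and names (`BM`, `BK`, `A_PAD`, `load_tile_with_padding`) are assumptions for illustration, not this repo's actual kernel parameters: widening each shared-memory row by a few extra halves shifts successive rows onto different banks, so column-wise accesses across rows no longer collide.

```C++
// Hypothetical illustration of SMEM padding; BM/BK/A_PAD are assumed values,
// not necessarily what the hgemm kernels in this repo use.
#include <cuda_fp16.h>

#define BM 128   // block tile rows
#define BK 16    // block tile depth in half elements
#define A_PAD 8  // 8 halves = 16 extra bytes per row

__global__ void load_tile_with_padding(const half* __restrict__ A, int lda) {
  // Each logical row occupies BK + A_PAD halves; the extra A_PAD halves are
  // never read, they only shift which 4-byte bank each (row, col) maps to.
  __shared__ half s_a[BM][BK + A_PAD];

  // Simple cooperative load: each thread copies a strided subset of the tile.
  for (int idx = threadIdx.x; idx < BM * BK; idx += blockDim.x) {
    int row = idx / BK;
    int col = idx % BK;
    s_a[row][col] = A[row * lda + col];
  }
  __syncthreads();
  // ... WMMA/MMA compute on s_a would follow here ...
}
```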

<div id="NV-L20"></div>

<!---

--->
+ ![NVIDIA_L20_NN+TN](https://github.com/user-attachments/assets/89bac543-7272-44cd-b616-54df8ca23a91)

- ![L20](./NVIDIA_L20.png)

- WMMA: Up to 113.83 TFLOPS, 113.83/119.5=95.25% TFLOPS utilization, 113.83/116.25=97.91% cuBLAS performance.
- MMA: Up to 115.12 TFLOPS, 115.12/119.5=96.33% TFLOPS utilization, 115.12/116.25=99.03% cuBLAS performance.
@@ -168,7 +170,7 @@ python3 hgemm.py --M 16384 --N 16384 --K 8192 --mma-all --wmma-all --cuda-all
```
Full MNK sweep command (tip: per-MNK individual runs give more accurate performance data)
```bash
- python3 hgemm.py --mma-all --wmma-all --cuda-all
+ python3 hgemm.py --cute-tn --mma --plot --dir tmp --tag NN+TN --i 20 --wmma-all
```

### NVIDIA GeForce RTX 4090
@@ -178,51 +180,10 @@ python3 hgemm.py --mma-all --wmma-all --cuda-all

--->

- ![4090](./NVIDIA_GeForce_RTX_4090.png)
+ ![NVIDIA_GeForce_RTX_4090_NN+TN](https://github.com/user-attachments/assets/d8d7380b-4271-41f6-964a-ac3fa81f7f4c)

```bash
- ----------------------------------------------------------------------------------------------------------------------------------
- M=16384, N=16384, K=8192, Warmup=2, Iters=10, 1/1
- ----------------------------------------------------------------------------------------------------------------------------------
- --------------------------------------------------------------------WMMA----------------------------------------------------------
- (wmma4x2): ['-137.375', '53.65625'], time:90.05668ms, swizzle: NOOP, TFLOPS: 48.84 (+0.00%)
- (wmma4x2+warp2x4): ['-137.375', '53.65625'], time:37.53635ms, swizzle: NOOP, TFLOPS: 117.17(+139.92%)
- (wmma4x2+warp2x4+stage3): ['-137.375', '53.65625'], time:25.96564ms, swizzle: NOOP, TFLOPS: 169.38(+44.56%)
- (wmma4x2+warp2x4+stage2): ['-137.375', '53.65625'], time:25.21226ms, swizzle: NOOP, TFLOPS: 174.44(+2.99%)
- (wmma4x2+warp2x4+stage3+swizzle): ['-137.375', '53.65625'], time:22.99013ms, swizzle: 4096, TFLOPS: 191.30(+9.67%)
- (wmma4x2+warp2x4+stage2+swizzle): ['-137.375', '53.65625'], time:22.91676ms, swizzle: 4096, TFLOPS: 191.91(+0.32%)
- (wmma4x2+warp2x4+stage2+dsmem+swizzle): ['-137.375', '53.65625'], time:22.78118ms, swizzle: 4096, TFLOPS: 193.06(+0.60%)
- (wmma4x4+warp4x4+stage3+dsmem): ['-137.375', '53.65625'], time:18.66145ms, swizzle: NOOP, TFLOPS: 235.68(+22.08%)
- (wmma4x4+warp4x4+stage3+dsmem+swizzle): ['-137.375', '53.65625'], time:18.16847ms, swizzle: 4096, TFLOPS: 242.07(+2.71%)
- (wmma4x4+warp4x4+stage2+dsmem+swizzle): ['-137.375', '53.65625'], time:18.11864ms, swizzle: 4096, TFLOPS: 242.74(+0.28%)
- (cublas): ['-137.375', '53.65625'], time:18.07777ms, swizzle: NOOP, TFLOPS: 243.28(+0.23%)
- ----------------------------------------------------------------------------------------------------------------------------------
- ----------------------------------------------------------------------------------------------------------------------------------
- M=8192, N=8192, K=8192, Warmup=2, Iters=10, 1/1
- ----------------------------------------------------------------------------------------------------------------------------------
- --------------------------------------------------------------------WMMA----------------------------------------------------------
- (wmma4x2): ['11.453125', '-1.0664062'], time:18.48518ms, swizzle: NOOP, TFLOPS: 59.48 (+0.00%)
- (wmma4x2+warp2x4): ['11.453125', '-1.0664062'], time:9.354352ms, swizzle: NOOP, TFLOPS: 117.54(+97.61%)
- (wmma4x2+warp2x4+stage3): ['11.453125', '-1.0664062'], time:5.835342ms, swizzle: NOOP, TFLOPS: 188.42(+60.31%)
- (wmma4x2+warp2x4+stage2): ['11.453125', '-1.0664062'], time:5.795311ms, swizzle: NOOP, TFLOPS: 189.72(+0.69%)
- (wmma4x2+warp2x4+stage3+dsmem): ['11.453125', '-1.0664062'], time:5.795168ms, swizzle: NOOP, TFLOPS: 189.73(+0.00%)
- (wmma4x2+warp2x4+stage3+swizzle): ['11.453125', '-1.0664062'], time:5.384325ms, swizzle: 2048, TFLOPS: 204.21(+7.63%)
- (wmma4x4+warp4x4+stage3+dsmem): ['11.453125', '-1.0664062'], time:4.254937ms, swizzle: NOOP, TFLOPS: 258.41(+26.54%)
- (cublas): ['11.421875', '-1.3203125'], time:4.288864ms, swizzle: NOOP, TFLOPS: 256.36
- ----------------------------------------------------------------------------------------------------------------------------------
- ----------------------------------------------------------------------------------------------------------------------------------
- M=4096, N=4096, K=4096, Warmup=2, Iters=10, 1/1
- ----------------------------------------------------------------------------------------------------------------------------------
- --------------------------------------------------------------------WMMA----------------------------------------------------------
- (wmma4x2): ['-9.0', '-144.875'], time:2.341437ms, swizzle: NOOP, TFLOPS: 58.70 (+0.00%)
- (wmma4x2+warp2x4): ['-9.0', '-144.875'], time:1.237440ms, swizzle: NOOP, TFLOPS: 111.07(+89.22%)
- (wmma4x2+warp2x4+stage3): ['-9.0', '-144.875'], time:0.725293ms, swizzle: NOOP, TFLOPS: 189.49(+70.61%)
- (wmma4x2+warp2x4+stage3+dsmem): ['-9.0', '-144.875'], time:0.723266ms, swizzle: NOOP, TFLOPS: 190.03(+0.28%)
- (wmma4x2+warp2x4+stage3+swizzle): ['-9.0', '-144.875'], time:0.702548ms, swizzle: 2048, TFLOPS: 195.63(+2.95%)
- (wmma4x2+warp2x4+stage3+dsmem+swizzle): ['-9.0', '-144.875'], time:0.702190ms, swizzle: 2048, TFLOPS: 195.73(+0.05%)
- (wmma4x4+warp4x4+stage3+dsmem): ['-9.0', '-144.875'], time:0.556564ms, swizzle: NOOP, TFLOPS: 246.94(+26.17%)
- (cublas): ['-9.0', '-144.875'], time:0.539851ms, swizzle: NOOP, TFLOPS: 254.59(+3.10%)
- ----------------------------------------------------------------------------------------------------------------------------------
+ python3 hgemm.py --cute-tn --mma --plot --dir tmp --tag NN+TN --i 20 --wmma-all
```

### NVIDIA GeForce RTX 3080 Laptop
@@ -232,16 +193,7 @@ python3 hgemm.py --mma-all --wmma-all --cuda-all
![](./NVIDIA_GeForce_RTX_3080_Laptop_GPU_WSL2.png)

```bash
- python3 hgemm.py --wmma-all
- ----------------------------------------------------------------------------------------------------------------------------------
- M=16384, N=16384, K=8192, Warmup=5, Iters=20, 27/27
- ----------------------------------------------------------------------------------------------------------------------------------
- (wmma4x4+warp4x4+stage3+dsmem): ['68.375', '-2.234375'], time:96.91984ms, swizzle: NOOP, TFLOPS: 45.38 (+0.00%)
- (wmma4x4+warp4x4+stage2+dsmem): ['68.375', '-2.234375'], time:102.8722ms, swizzle: NOOP, TFLOPS: 42.75
- (wmma4x4+warp4x4+stage3+dsmem+swizzle): ['68.375', '-2.234375'], time:85.65800ms, swizzle: 4096, TFLOPS: 51.34 (+13.15%)
- (wmma4x4+warp4x4+stage2+dsmem+swizzle): ['68.375', '-2.234375'], time:95.70884ms, swizzle: 4096, TFLOPS: 45.95
- (cublas): ['68.375', '-2.234375'], time:104.2092ms, swizzle: NOOP, TFLOPS: 42.20
- ----------------------------------------------------------------------------------------------------------------------------------
+ python3 hgemm.py --wmma-all --plot --dir tmp
```

@@ -291,7 +243,7 @@ NVIDIA's [article](https://developer.nvidia.com/blog/using-shared-memory-cuda-cc/
```C
cudaDeviceSetSharedMemConfig(cudaSharedMemBankSizeEightByte); // switch shared memory to 8-byte bank mode
```
- Bank conflicts are currently mitigated with padding and smem swizzle: smem padding for the NN layout, and cutlass cute's smem swizzle/permuted layout for the TN layout, which eliminates bank conflicts.
+ Bank conflicts are currently mitigated with SMEM Padding and SMEM Swizzle: SMEM Padding for the NN layout, and cutlass cute's SMEM Swizzle for the TN layout, which eliminates bank conflicts.
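For the TN path, here is a hedged sketch of how a swizzled shared-memory layout is typically composed with CUTLASS CuTe. The `Swizzle<3, 3, 3>` parameters, the 8x64 atom shape, the 128x64 tile, and the type names `SmemLayoutAtomA`/`SmemLayoutA` are illustrative assumptions, not necessarily this repo's exact configuration.

```C++
// Hypothetical CuTe layout sketch; the swizzle parameters and shapes are
// common fp16 choices, not necessarily what this repo's kernels use.
#include <cute/tensor.hpp>
using namespace cute;

// Row-major 8x64 half-precision shared-memory atom composed with a swizzle:
// the swizzle XORs a few bits of the row index into bits of the column index,
// so the elements touched by a warp land in different banks without padding.
using SmemLayoutAtomA = decltype(composition(
    Swizzle<3, 3, 3>{},
    make_layout(make_shape(Int<8>{}, Int<64>{}),
                make_stride(Int<64>{}, Int<1>{}))));

// Tile the atom up to the full block tile, e.g. (BM, BK) = (128, 64):
using SmemLayoutA = decltype(tile_to_shape(
    SmemLayoutAtomA{}, make_shape(Int<128>{}, Int<64>{})));
```

Compared with padding, a swizzled layout wastes no shared memory and keeps the tile contiguous; the trade-off is a more involved index mapping, which CuTe's layout algebra handles statically.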

### Double Buffers