@@ -80,6 +80,53 @@ python3 hgemm.py --M 16384 --N 16384 --K 8192 --mma-all --wmma-all --cuda-all
python3 hgemm.py --mma-all --wmma-all --cuda-all
```
+ ### NVIDIA GeForce RTX 4090
+ On an NVIDIA RTX 4090 (330 TFLOPS peak FP16 Tensor Core throughput), the WMMA (m16n16k16) kernels outperform the MMA (m16n8k16) ones. For most MNK shapes, the implementations in this repo reach 95%~99% of cuBLAS performance, and some cases even beat cuBLAS. As far as this repo's kernels are concerned, on the RTX 4090 WMMA wins for large GEMMs (MNK>=8192) while MMA wins for small ones; a minimal WMMA sketch follows the log below.
+ ``` bash
+ ----------------------------------------------------------------------------------------------------------------------------------
+ M=16384, N=16384, K=8192, Warmup=2, Iters=10, 1/1
+ ----------------------------------------------------------------------------------------------------------------------------------
+ --------------------------------------------------------------------WMMA----------------------------------------------------------
+ (wmma4x2): [' -137.375 ' , ' 53.65625 ' ], time:90.05668ms, swizzle: NOOP, TFLOPS: 48.84 (+0.00%)
+ (wmma4x2+warp2x4): [' -137.375 ' , ' 53.65625 ' ], time:37.53635ms, swizzle: NOOP, TFLOPS: 117.17(+139.92%)
+ (wmma4x2+warp2x4+stage3): [' -137.375 ' , ' 53.65625 ' ], time:25.96564ms, swizzle: NOOP, TFLOPS: 169.38(+44.56%)
+ (wmma4x2+warp2x4+stage2): [' -137.375 ' , ' 53.65625 ' ], time:25.21226ms, swizzle: NOOP, TFLOPS: 174.44(+2.99%)
+ (wmma4x2+warp2x4+stage3+swizzle): [' -137.375 ' , ' 53.65625 ' ], time:22.99013ms, swizzle: 4096, TFLOPS: 191.30(+9.67%)
+ (wmma4x2+warp2x4+stage2+swizzle): [' -137.375 ' , ' 53.65625 ' ], time:22.91676ms, swizzle: 4096, TFLOPS: 191.91(+0.32%)
+ (wmma4x2+warp2x4+stage2+dsmem+swizzle): [' -137.375 ' , ' 53.65625 ' ], time:22.78118ms, swizzle: 4096, TFLOPS: 193.06(+0.60%)
+ (wmma4x4+warp4x4+stage3+dsmem): [' -137.375 ' , ' 53.65625 ' ], time:18.66145ms, swizzle: NOOP, TFLOPS: 235.68(+22.08%)
+ (wmma4x4+warp4x4+stage3+dsmem+swizzle): [' -137.375 ' , ' 53.65625 ' ], time:18.16847ms, swizzle: 4096, TFLOPS: 242.07(+2.71%)
+ (wmma4x4+warp4x4+stage2+dsmem+swizzle): [' -137.375 ' , ' 53.65625 ' ], time:18.11864ms, swizzle: 4096, TFLOPS: 242.74(+0.28%)
+ (cublas): [' -137.375 ' , ' 53.65625 ' ], time:18.07777ms, swizzle: NOOP, TFLOPS: 243.28(+0.23%)
+ ----------------------------------------------------------------------------------------------------------------------------------
+ ----------------------------------------------------------------------------------------------------------------------------------
+ M=8192, N=8192, K=8192, Warmup=2, Iters=10, 1/1
+ ----------------------------------------------------------------------------------------------------------------------------------
+ --------------------------------------------------------------------WMMA----------------------------------------------------------
+ (wmma4x2): [' 11.453125 ' , ' -1.0664062' ], time:18.48518ms, swizzle: NOOP, TFLOPS: 59.48 (+0.00%)
+ (wmma4x2+warp2x4): [' 11.453125 ' , ' -1.0664062' ], time:9.354352ms, swizzle: NOOP, TFLOPS: 117.54(+97.61%)
+ (wmma4x2+warp2x4+stage3): [' 11.453125 ' , ' -1.0664062' ], time:5.835342ms, swizzle: NOOP, TFLOPS: 188.42(+60.31%)
+ (wmma4x2+warp2x4+stage2): [' 11.453125 ' , ' -1.0664062' ], time:5.795311ms, swizzle: NOOP, TFLOPS: 189.72(+0.69%)
+ (wmma4x2+warp2x4+stage3+dsmem): [' 11.453125 ' , ' -1.0664062' ], time:5.795168ms, swizzle: NOOP, TFLOPS: 189.73(+0.00%)
+ (wmma4x2+warp2x4+stage3+swizzle): [' 11.453125 ' , ' -1.0664062' ], time:5.384325ms, swizzle: 2048, TFLOPS: 204.21(+7.63%)
+ (wmma4x4+warp4x4+stage3+dsmem): [' 11.453125 ' , ' -1.0664062' ], time:4.254937ms, swizzle: NOOP, TFLOPS: 258.41(+26.54%)
+ (cublas): [' 11.421875 ' , ' -1.3203125' ], time:4.288864ms, swizzle: NOOP, TFLOPS: 256.36
+ ----------------------------------------------------------------------------------------------------------------------------------
+ ----------------------------------------------------------------------------------------------------------------------------------
+ M=4096, N=4096, K=4096, Warmup=2, Iters=10, 1/1
+ ----------------------------------------------------------------------------------------------------------------------------------
+ --------------------------------------------------------------------WMMA----------------------------------------------------------
+ (wmma4x2): [' -9.0 ' , ' -144.875 ' ], time:2.341437ms, swizzle: NOOP, TFLOPS: 58.70 (+0.00%)
+ (wmma4x2+warp2x4): [' -9.0 ' , ' -144.875 ' ], time:1.237440ms, swizzle: NOOP, TFLOPS: 111.07(+89.22%)
+ (wmma4x2+warp2x4+stage3): [' -9.0 ' , ' -144.875 ' ], time:0.725293ms, swizzle: NOOP, TFLOPS: 189.49(+70.61%)
+ (wmma4x2+warp2x4+stage3+dsmem): [' -9.0 ' , ' -144.875 ' ], time:0.723266ms, swizzle: NOOP, TFLOPS: 190.03(+0.28%)
+ (wmma4x2+warp2x4+stage3+swizzle): [' -9.0 ' , ' -144.875 ' ], time:0.702548ms, swizzle: 2048, TFLOPS: 195.63(+2.95%)
+ (wmma4x2+warp2x4+stage3+dsmem+swizzle): [' -9.0 ' , ' -144.875 ' ], time:0.702190ms, swizzle: 2048, TFLOPS: 195.73(+0.05%)
+ (wmma4x4+warp4x4+stage3+dsmem): [' -9.0 ' , ' -144.875 ' ], time:0.556564ms, swizzle: NOOP, TFLOPS: 246.94(+26.17%)
+ (cublas): [' -9.0 ' , ' -144.875 ' ], time:0.539851ms, swizzle: NOOP, TFLOPS: 254.59(+3.10%)
+ ----------------------------------------------------------------------------------------------------------------------------------
+ ```
+
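+ For reference, the TFLOPS column in these logs is just 2*M*N*K / time; e.g. the cuBLAS run at M=N=16384, K=8192 gives 2*16384*16384*8192 / 0.01807777s ≈ 243.3 TFLOPS, matching the value printed above. As a baseline for what the wmma4x2/wmma4x4 kernels build on, here is a minimal one-warp-per-tile WMMA m16n16k16 HGEMM kernel. It is an illustrative sketch, not this repo's tuned implementation: it assumes row-major A, B, C, shapes divisible by the tile and launch sizes, and the kernel name is made up for the example.
+ ``` cuda
+ #include <mma.h>
+ #include <cuda_fp16.h>
+ using namespace nvcuda;
+
+ // One warp computes one 16x16 tile of C = A @ B (FP16 in, FP16 out).
+ // Launch example: dim3 block(32, 8); dim3 grid(N / 16, M / (16 * 8));
+ __global__ void hgemm_wmma_m16n16k16_naive(const half* A, const half* B,
+                                            half* C, int M, int N, int K) {
+   const int tile_n = blockIdx.x;                             // 16-col tile index
+   const int tile_m = blockIdx.y * blockDim.y + threadIdx.y;  // 16-row tile index
+
+   wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
+   wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag;
+   wmma::fragment<wmma::accumulator, 16, 16, 16, half> c_frag;
+   wmma::fill_fragment(c_frag, __float2half(0.0f));
+
+   for (int k = 0; k < K; k += 16) {
+     // The leading dimension is the row stride of the full row-major matrix.
+     wmma::load_matrix_sync(a_frag, A + tile_m * 16 * K + k, K);
+     wmma::load_matrix_sync(b_frag, B + k * N + tile_n * 16, N);
+     wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);  // Tensor Core FMA
+   }
+   wmma::store_matrix_sync(C + tile_m * 16 * N + tile_n * 16, c_frag,
+                           N, wmma::mem_row_major);
+ }
+ ```
+ The wmma4x2/wmma4x4 entries above extend this pattern with shared-memory pipelining (stage2/stage3), dynamic shared memory (dsmem), wider warp tiles (warp2x4/warp4x4), and thread block swizzle.
+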
### NVIDIA GeForce RTX 3080 Laptop
Tested on an NVIDIA GeForce RTX 3080 Laptop, using mma4x4_warp4x4 (16 WMMA m16n16k16 ops, warp tile 64x64) plus thread block swizzle, most cases match or even beat cuBLAS. Laptop performance numbers are unstable, though, so treat these results as a rough reference only. A sketch of the swizzle remapping follows.
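The `swizzle: 2048`/`4096` values in the logs are the width, in columns of C, of one block-scheduling strip. Below is a sketch of the general thread block swizzle technique (names and the exact decode are assumptions for illustration, not necessarily this repo's scheme): launch a 1D grid and remap the linear block id so consecutively scheduled blocks walk all row tiles of one narrow N-strip before moving on, bounding the set of B columns competing for L2 at any moment.

``` cuda
#include <cuda_runtime.h>

// Remap a linear block id to a (tile_m, tile_n) C-tile coordinate so blocks
// are issued strip by strip along N. BN is the C-tile width in columns,
// swizzle_stride is the strip width in columns (e.g. 2048 or 4096 above).
// Illustrative sketch: assumes the N tiles divide evenly into strips.
__device__ __forceinline__ void swizzle_tile(int bid, int grid_m, int BN,
                                             int swizzle_stride,
                                             int* tile_m, int* tile_n) {
  int strip_tiles_n = swizzle_stride / BN;   // C tiles per strip along N
  int strip_size = grid_m * strip_tiles_n;   // C tiles in one whole strip
  int strip = bid / strip_size;              // which N-strip this block serves
  int r = bid % strip_size;
  *tile_m = r % grid_m;                      // M varies fastest inside a strip
  *tile_n = strip * strip_tiles_n + r / grid_m;
}
```

Since the benefit comes from capping the live working set at swizzle_stride columns instead of all of N, the payoff should grow with problem size, consistent with swizzle gaining more at M=N=16384 than at 4096 in the logs above.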
@@ -96,6 +143,7 @@ python3 hgemm.py --wmma-all
(cublas): [' 68.375 ' , ' -2.234375 ' ], time:104.2092ms, swizzle: NOOP, TFLOPS: 42.20
----------------------------------------------------------------------------------------------------------------------------------
```
+
## Test Commands
``` bash