Commit f75d8f6

[HGEMM] Add CuTe HGEMM with SMEM Swizzle (#134)
* Update hgemm.py
* Update hgemm.cu
* Create utils.h
* Create hgemm_mma_stage_tn_cute.cu
* Create makefile
* Update hgemm_mma_stage_tn_cute.cu
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update .gitignore
* Update README.md
1 parent aabee15 commit f75d8f6

File tree: 8 files changed, +746 −25 lines

README.md

Lines changed: 6 additions & 4 deletions
@@ -25,7 +25,7 @@
@@ -25,7 +25,7 @@
 <img src='https://github.com/user-attachments/assets/c7d65fe5-9fb9-49a8-b962-a6c09bcc030a' height="225px" width="403px">
 </div>

-Currently, on NVIDIA L20, RTX 4090 and RTX 3090 Laptop, compared with cuBLAS's default Tensor Cores math algorithm `CUBLAS_GEMM_DEFAULT_TENSOR_OP`, the `HGEMM (WMMA/MMA)` implemented in this repo (`sky blue`🔵) can achieve `95%~99%` of its (`orange`🟠) performance. Please check [hgemm benchmark](./hgemm) for more details.
+Currently, on NVIDIA L20, RTX 4090 and RTX 3090 Laptop, compared with cuBLAS's default Tensor Cores math algorithm `CUBLAS_GEMM_DEFAULT_TENSOR_OP`, the `HGEMM (WMMA/MMA)` implemented in this repo (`blue`🔵) can achieve `95%~99%` of its (`orange`🟠) performance. Please check [hgemm benchmark](./hgemm) for more details.

 |CUDA Cores|Sliced K(Loop over K)|Tile Block|Tile Thread|
 |:---:|:---:|:---:|:---:|
@@ -36,8 +36,8 @@ Currently, on NVIDIA L20, RTX 4090 and RTX 3090 Laptop, compared with cuBLAS's d
 |✔️|✔️|✔️|✔️|
 |Reg Double Buffers|Block Swizzle|Warp Swizzle|Collective Store(Warp Shfl)|
 |✔️|✔️|✔️|✔️|
-|Row Major(NN)|Col Major(TN)|SGEMM TF32|SMEM Swizzle(Permuted)|
-|✔️|✔️|✔️|...|
+|Row Major(NN)|Col Major(TN)|SGEMM TF32|SMEM Swizzle(CuTe)|
+|✔️|✔️|✔️|✔️|



@@ -201,6 +201,7 @@ Currently, on NVIDIA L20, RTX 4090 and RTX 3090 Laptop, compared with cuBLAS's d
 | ✔️ [hgemm_mma_m16n8k16...mma2x4*](./hgemm/hgemm_mma.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
 | ✔️ [hgemm_mma_m16n8k16...stages*](./hgemm/hgemm_mma_stage.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
 | ✔️ [hgemm_mma_m16n8k16...swizzle*](./hgemm/hgemm_mma_stage.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
+| ✔️ [hgemm_mma_stages_tn_cute*](./hgemm/hgemm_mma_stage_tn_cute.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
 | ✔️ [sgemv_k32_f32](./sgemv/sgemv.cu)|f32|f32|[link](./sgemv/)|⭐️⭐️⭐️|
 | ✔️ [sgemv_k128_f32x4](./sgemv/sgemv.cu)|f32|f32|[link](./sgemv/)|⭐️⭐️⭐️|
 | ✔️ [sgemv_k16_f32](./sgemv/sgemv.cu)|f32|f32|[link](./sgemv/)|⭐️⭐️⭐️|
@@ -397,5 +398,6 @@ How to contribute? please check [🌤🌤CONTRIBUTE🎉🎉](https://github.com/
 - [cuda_hgemm](https://github.com/Bruce-Lee-LY/cuda_hgemm)
 - [cuda-tensorcore-hgemm](https://github.com/nicolaswilde/cuda-tensorcore-hgemm)
 - [How_to_optimize_in_GPU](https://github.com/Liu-xiandong/How_to_optimize_in_GPU/tree/master/sgemv)
-
+- [cute_gemm](https://github.com/weishengying/cute_gemm)
+
 </details>

hgemm/.gitignore

Lines changed: 2 additions & 0 deletions
@@ -16,3 +16,5 @@ __pycache__
 *.ncu*
 *.sqlite*
 *.engine
+*.bin
+*.out

hgemm/README.md

Lines changed: 40 additions & 5 deletions
@@ -11,8 +11,8 @@
 |✔️|✔️|✔️|✔️|
 |**Reg Double Buffers**|**Block Swizzle**|**Warp Swizzle**|**Collective Store(Reg Reuse&Warp Shfl)**|
 |✔️|✔️|✔️|✔️|
-|**Row Major(NN)**|**Col Major(TN)**|**SGEMM TF32**|**SMEM Swizzle/Permuted**|
-|✔️|✔️|✔️||
+|**Row Major(NN)**|**Col Major(TN)**|**SGEMM TF32**|**SMEM Swizzle(CuTe)**|
+|✔️|✔️|✔️|✔️|

 <details>
 <summary> 🔑️ Click to see all supported HGEMM Kernels! </summary>
@@ -46,7 +46,10 @@
 
 ## Test Commands
 
+**Python**: testing directly from the python script is supported.
+
 ```bash
+git submodule update --init --recursive --force
 # Only test the Ada architecture. If unset, all architectures (Volta, Ampere, Ada, Hopper, ...) are compiled by default, which takes much longer.
 export TORCH_CUDA_ARCH_LIST=Ada
 python3 hgemm.py --wmma # test default wmma kernels for all MNK
@@ -56,19 +59,51 @@ python3 hgemm.py --M 16384 --N 16384 --K 8192 --i 10 --mma # test default mma ke
 python3 hgemm.py --wmma-all # test all wmma kernels for all MNK
 python3 hgemm.py --mma-all # test all mma kernels for all MNK
 python3 hgemm.py --cuda-all --wmma-all --mma-all # test all kernels for all MNK
+python3 hgemm.py --cute-tn --no-default # test cute hgemm with smem swizzle for all MNK
 ```
 To plot TFLOPS curves, install matplotlib first and pass the --plot-flops (or --plot) option:
 ```bash
 python3 -m pip install matplotlib
 # topk: only plot the top-k best-performing kernels
-python3 hgemm.py --mma-all --plot --topk 8
+python3 hgemm.py --mma-all --plot --topk 8
+python3 hgemm.py --cute-tn --no-default --plot # test cute hgemm with smem swizzle for all MNK
+```
+
+**C++**: the C++ tests currently only cover CuTe HGEMM. The numbers measured from the C++ binary are slightly better than those from the python tests, probably because the torch binding introduces some overhead.
+```bash
+make
+./hgemm_cute.bin
+# NVIDIA L20
+algo = CUTE HGEMM Stages 2
+M N K = 256 256 256, Time = 0.00001946 0.00002007 0.00002048 s, AVG Performance = 1.6718 Tflops
+M N K = 512 512 512, Time = 0.00003174 0.00003277 0.00003379 s, AVG Performance = 8.1920 Tflops
+M N K = 768 768 768, Time = 0.00004506 0.00004608 0.00004710 s, AVG Performance = 19.6608 Tflops
+M N K = 1024 1024 1024, Time = 0.00005837 0.00005929 0.00006042 s, AVG Performance = 36.2202 Tflops
+M N K = 9216 9216 9216, Time = 0.01371546 0.01371679 0.01371853 s, AVG Performance = 114.1314 Tflops
+M N K = 9472 9472 9472, Time = 0.01458586 0.01458924 0.01460531 s, AVG Performance = 116.4991 Tflops
+M N K = 9728 9728 9728, Time = 0.01597747 0.01597931 0.01598157 s, AVG Performance = 115.2239 Tflops
+M N K = 9984 9984 9984, Time = 0.01741721 0.01742008 0.01743462 s, AVG Performance = 114.2598 Tflops
+M N K = 10240 10240 10240, Time = 0.01839923 0.01840046 0.01840230 s, AVG Performance = 116.7081 Tflops
+M N K = 10496 10496 10496, Time = 0.01993421 0.01993523 0.01993728 s, AVG Performance = 116.0059 Tflops
+M N K = 10752 10752 10752, Time = 0.02151629 0.02151956 0.02153472 s, AVG Performance = 115.5219 Tflops
+M N K = 11008 11008 11008, Time = 0.02315571 0.02315663 0.02315878 s, AVG Performance = 115.2073 Tflops
+M N K = 11264 11264 11264, Time = 0.02484634 0.02484808 0.02484941 s, AVG Performance = 115.0311 Tflops
+M N K = 11520 11520 11520, Time = 0.02659226 0.02659430 0.02659840 s, AVG Performance = 114.9738 Tflops
+M N K = 11776 11776 11776, Time = 0.02780057 0.02780426 0.02781082 s, AVG Performance = 117.4660 Tflops
+M N K = 12032 12032 12032, Time = 0.03024179 0.03024701 0.03025818 s, AVG Performance = 115.1757 Tflops
+M N K = 12288 12288 12288, Time = 0.03214848 0.03215698 0.03217306 s, AVG Performance = 115.3980 Tflops
+M N K = 12544 12544 12544, Time = 0.03410842 0.03411661 0.03412173 s, AVG Performance = 115.7104 Tflops
+M N K = 12800 12800 12800, Time = 0.03612979 0.03613184 0.03613491 s, AVG Performance = 116.0833 Tflops
+M N K = 13056 13056 13056, Time = 0.03820134 0.03820769 0.03821671 s, AVG Performance = 116.4956 Tflops
+M N K = 15872 15872 15872, Time = 0.06917632 0.06927145 0.06936883 s, AVG Performance = 115.4438 Tflops
+M N K = 16128 16128 16128, Time = 0.07299379 0.07302472 0.07304806 s, AVG Performance = 114.8951 Tflops
 ```
 
 ## Current Performance
 
 ### NVIDIA L20
 
-With the current best implementation on L20 (theoretical Tensor Cores FP16 throughput 119.5 TFLOPS), the WMMA API reaches roughly 95%~98% of cuBLAS performance (105-113 TFLOPS vs 105-115 TFLOPS) and the MMA API reaches 115 TFLOPS, beating cuBLAS in some cases. A known issue is that bank conflicts are not fully eliminated: mitigating them via padding wastes shared memory and also hurts SM occupancy, and smem swizzle/permute has not yet been implemented by hand (limited by the flexibility of the WMMA API and the row major layout); a later attempt will implement smem swizzle/permute via MMA PTX.
+With the current best implementation on L20 (theoretical Tensor Cores FP16 throughput 119.5 TFLOPS), the WMMA API reaches roughly 95%~98% of cuBLAS performance (105-113 TFLOPS vs 105-115 TFLOPS) and the MMA API reaches 115 TFLOPS, beating cuBLAS in some cases. Bank conflicts are now mitigated via padding and smem swizzle: the NN layout uses smem padding to mitigate bank conflicts, while the TN layout eliminates them with the smem swizzle/permuted layout from cutlass cute.
 
 <div id="NV-L20"></div>

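For orientation, the `AVG Performance` column in the C++ benchmark output above is just the standard GEMM throughput formula, 2·M·N·K floating-point operations divided by the averaged runtime; the minimal sketch below (not the benchmark's actual code) reproduces the figures:

```c++
#include <cstdio>

// An MxNxK GEMM performs M*N*K fused multiply-adds, i.e. 2*M*N*K FLOPs.
// Dividing by the measured average time (seconds) and 1e12 yields TFLOPS.
double gemm_tflops(int M, int N, int K, double avg_seconds) {
  double flops = 2.0 * static_cast<double>(M) * N * K;
  return flops / avg_seconds / 1e12;
}

int main() {
  // M=N=K=256 at an average of 0.00002007 s -> ~1.67 TFLOPS,
  // in line with the first row of the table above.
  std::printf("%.4f Tflops\n", gemm_tflops(256, 256, 256, 0.00002007));
  return 0;
}
```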
@@ -227,7 +262,7 @@ NVIDIA's [article](https://developer.nvidia.com/blog/using-shared-memory-cuda-cc/
 ```C
 cudaDeviceSetSharedMemConfig(cudaSharedMemBankSizeEightByte);
 ```
-This project currently mitigates bank conflicts via padding, which wastes shared memory and also hurts SM occupancy; smem swizzle/permute has not yet been implemented by hand (limited by the flexibility of the WMMA API and the row major layout), and a later attempt will implement smem swizzle/permute via MMA PTX.
+Bank conflicts are now mitigated via padding and smem swizzle: the NN layout uses smem padding to mitigate bank conflicts, while the TN layout eliminates them with the smem swizzle/permuted layout from cutlass cute.
 
 ### Double Buffers

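To make the swizzle change above concrete, the sketch below shows how a swizzled shared-memory layout is typically declared with CuTe, in the style of the CUTLASS CuTe GEMM examples this kernel follows. The tile sizes (BM=128, BK=32, 2 stages) and the `Swizzle<3, 3, 3>` parameters are illustrative assumptions, not necessarily the exact values used in `hgemm_mma_stage_tn_cute.cu`:

```c++
#include <cute/tensor.hpp>
using namespace cute;

// 8x32 half-precision layout atom: Swizzle<3,3,3> XORs row bits into the
// column index, so rows that would otherwise land in the same shared-memory
// banks are spread across different banks -- no padding, no wasted SMEM.
using SmemLayoutAtom = decltype(composition(
    Swizzle<3, 3, 3>{},
    make_layout(make_shape(Int<8>{}, Int<32>{}),
                make_stride(Int<32>{}, Int<1>{}))));

// Replicate the atom over the whole (BM, BK, Stages) staging buffer.
using SmemLayoutA = decltype(tile_to_shape(
    SmemLayoutAtom{},
    make_shape(Int<128>{}, Int<32>{}, Int<2>{})));

// Inside the kernel the swizzled view is built over raw dynamic SMEM, e.g.:
//   extern __shared__ half_t smem[];
//   auto sA = make_tensor(make_smem_ptr(smem), SmemLayoutA{});
```

Because the permutation lives in the layout itself, both the global-to-shared copies and the `ldmatrix` loads go through the same conflict-free mapping, which is how the TN path avoids the padding still used on the NN path.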
hgemm/hgemm.cu

Lines changed: 8 additions & 1 deletion
@@ -1023,7 +1023,10 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_x4(torch::Tensor a, torch:
 void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_rr(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
 // from hgemm_mma_stage_tn.cu
 void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
-
+#ifdef ENBLE_CUTE_HGEMM
+// from hgemm_mma_stage_tn_cute.cu
+void hgemm_mma_stages_tn_cute(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
+#endif
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   // CUDA Cores FP16
@@ -1067,5 +1070,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_rr)
   // TN: A row major MxK, B col major NxK, C row major MxN
   TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn)
+  // cute hgemm
+#ifdef ENBLE_CUTE_HGEMM
+  TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_stages_tn_cute)
+#endif
 }

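The `ENBLE_CUTE_HGEMM` guard above is switched on from `hgemm.py` (via `-DENBLE_CUTE_HGEMM` in `get_build_cuda_cflags()`), so the binding only exists when the CuTe sources are actually compiled in. `TORCH_BINDING_COMMON_EXTENSION` is a repo-local macro; with plain pybind11 the same conditional registration would look roughly like the sketch below, where the macro and kernel names are hypothetical:

```c++
#include <torch/extension.h>

// Hypothetical kernel, only declared/bound when the macro is defined,
// e.g. by passing "-DENABLE_MY_CUTE_HGEMM" through extra_cuda_cflags.
#ifdef ENABLE_MY_CUTE_HGEMM
void my_cute_hgemm(torch::Tensor a, torch::Tensor b, torch::Tensor c,
                   int stages, bool swizzle, int swizzle_stride);
#endif

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#ifdef ENABLE_MY_CUTE_HGEMM
  m.def("my_cute_hgemm", &my_cute_hgemm, "optional CuTe HGEMM binding");
#endif
}
```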
hgemm/hgemm.py

Lines changed: 52 additions & 15 deletions
@@ -1,3 +1,4 @@
+import os
 import torch
 import time
 from torch.utils.cpp_extension import load
@@ -30,6 +31,8 @@ def get_args():
     parser.add_argument("--enable-wmma-all", "--wmma-all", action="store_true", help="Enable all WMMA kernel tests")
     parser.add_argument("--enable-cuda-all", "--cuda-all", action="store_true", help="Enable all CUDA kernel tests")
     parser.add_argument("--enable-torch", "--torch", action="store_true", help="Enable torch matmul")
+    parser.add_argument("--enable-cute-tn", "--cute-tn", action="store_true", help="Enable cute hgemm matmul")
+    parser.add_argument("--enable-cute", "--cute", action="store_true", help="Enable cute hgemm matmul")
     parser.add_argument("--disable-cublas", "--no-cublas", action="store_true", help="Disable cublas hgemm")
     parser.add_argument("--disable-cublas-tn", "--no-cublas-tn", action="store_true", help="Disable cublas TN hgemm")
     parser.add_argument("--sleep-duration", "--sleep", type=float, default=0.1, help="Sleep duration")
@@ -42,6 +45,7 @@ def get_args():
     parser.add_argument("--save-dir", "--dir", type=str, default="./", help="Save dir for plot")
     return parser.parse_args()
 
+
 args = get_args()
 print(args)
 
@@ -58,15 +62,25 @@ def get_device_capability():
     return torch.cuda.get_device_capability(torch.cuda.current_device())
 
 
-# Load the CUDA kernel as a python module
-print(f"Loading hgemm lib on device: {get_device_name()}, capability: {get_device_capability()} ...")
+def get_build_sources():
+    build_sources = [
+        'hgemm.cu', 'hgemm_async.cu', 'hgemm_wmma.cu',
+        'hgemm_wmma_stage.cu', 'hgemm_cublas.cu',
+        'hgemm_mma.cu', 'hgemm_mma_stage.cu',
+        'hgemm_mma_stage_tn.cu'
+    ]
+    # if args.enable_cute_tn:
+    #     build_sources.append('hgemm_mma_stage_tn_cute.cu')
+    build_sources.append('hgemm_mma_stage_tn_cute.cu')
+    return build_sources
 
-lib = load(name='hgemm_lib',
-           sources=['hgemm.cu', 'hgemm_async.cu', 'hgemm_wmma.cu',
-                    'hgemm_wmma_stage.cu', 'hgemm_cublas.cu',
-                    'hgemm_mma.cu', 'hgemm_mma_stage.cu',
-                    'hgemm_mma_stage_tn.cu'],
-           extra_cuda_cflags=[
+
+def get_project_dir():
+    return os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+
+
+def get_build_cuda_cflags():
+    extra_cuda_cflags=[
         "-O3",
         "-U__CUDA_NO_HALF_OPERATORS__",
         "-U__CUDA_NO_HALF_CONVERSIONS__",
@@ -94,7 +108,23 @@ def get_device_capability():
         # spill loads: reload data previously spilled to the stack back into registers.
         "-Xptxas -v",
         # "-maxrregcount=128 -Xptxas -dlcm=cg" if args.reduce_reg else ""
-    ],
+    ]
+    # extra cuda flags for cute hgemm
+    project_dir = get_project_dir()
+    extra_cuda_cflags.append('-DNO_CUTE_HGEMM_BIN')
+    extra_cuda_cflags.append('-DENBLE_CUTE_HGEMM')
+    extra_cuda_cflags.append(f'-I {project_dir}')
+    extra_cuda_cflags.append(f'-I {project_dir}/third-party/cutlass/include')
+    extra_cuda_cflags.append(f'-I {project_dir}/third-party/cutlass/tools/util/include')
+
+    return extra_cuda_cflags
+
+# Load the CUDA kernel as a python module
+print(f"Loading hgemm lib on device: {get_device_name()}, capability: {get_device_capability()} ...")
+
+lib = load(name='hgemm_lib',
+           sources=get_build_sources(),
+           extra_cuda_cflags=get_build_cuda_cflags(),
            extra_cflags=['-std=c++17'],
            verbose=args.verbose)

@@ -254,6 +284,7 @@ def plot_tflops():
     STATIS_INFO["(best)"] = get_best_tflops()
     draw_tags = topk_tflops
     draw_tags.append("(cublas)")
+    draw_tags.append("tn(cublas)")
     draw_tags.append("(best)")
 
     def skip_it(tag: str) -> bool:
@@ -269,10 +300,10 @@ def skip_it(tag: str) -> bool:
         if skip_it(tag):
             continue
         if "cublas" in tag:
-            ax.plot(tflops, label=tag, linewidth=3)
+            ax.plot(tflops, label=tag, linewidth=3, color='orange')
         else:
             if "best" in tag and not args.no_plot_best:
-                ax.plot(tflops, label=tag, linewidth=4)
+                ax.plot(tflops, label=tag, linewidth=4, color='blue')
             else:
                 ax.plot(tflops, label=tag, linestyle='--')
 
@@ -400,15 +431,21 @@ def get_mnk(sep: int = args.SEP):
     run_benchmark(lib.hgemm_cublas_tensor_op_nn, a, b, "(cublas)", c)
     if args.enable_torch:
         run_benchmark(partial(torch.matmul, out=c), a, b, "(torch)")
-    if args.enable_mma_tn:
+    # TN layout: A row major with shape [M,K], B col major with shape [N,K]
+    if any((args.enable_mma_tn, args.enable_cute_tn)):
         MAX_TFLOPS = -1
-        print("-" * 68 + "MMA(TN)" + "-" * 55)
+        print("-" * 68 + "TN" + "-" * 60)
+    if args.enable_mma_tn:
         run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn, a, b.transpose(1, 0), "tn(mma2x4+warp4x4+stage3+dsmem)", c, stages=3)
         run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn, a, b.transpose(1, 0), "tn(mma2x4+warp4x4+stage2+dsmem)", c, stages=2)
         run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn, a, b.transpose(1, 0), "tn(mma2x4+warp4x4+stage3+dsmem+swizzle)", c, stages=3, swizzle=True)
         run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn, a, b.transpose(1, 0), "tn(mma2x4+warp4x4+stage2+dsmem+swizzle)", c, stages=2, swizzle=True)
-        if not args.disable_cublas_tn:
-            run_benchmark(lib.hgemm_cublas_tensor_op_tn, a, b.transpose(1, 0), "tn(cublas)", c)
+    if args.enable_cute_tn:
+        run_benchmark(lib.hgemm_mma_stages_tn_cute, a, b.transpose(1, 0), "tn(cute+swizzle<smem>+stage4)", c, stages=4)
+        run_benchmark(lib.hgemm_mma_stages_tn_cute, a, b.transpose(1, 0), "tn(cute+swizzle<smem>+stage3)", c, stages=3)
+        run_benchmark(lib.hgemm_mma_stages_tn_cute, a, b.transpose(1, 0), "tn(cute+swizzle<smem>+stage2)", c, stages=2)
+    if not args.disable_cublas_tn and any((args.enable_mma_tn, args.enable_cute_tn)):
+        run_benchmark(lib.hgemm_cublas_tensor_op_tn, a, b.transpose(1, 0), "tn(cublas)", c)
     torch.cuda.synchronize()
     print("-" * 130)

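The TN convention noted in the comment above (A row major `[M,K]`, B passed as `b.transpose(1, 0)`, i.e. stored as `[N,K]`) means both operands are traversed along their contiguous K dimension. A naive CUDA kernel with that indexing, for illustration only and not the repo's tensor-core implementation:

```c++
#include <cuda_fp16.h>

// Illustrative TN HGEMM: A row major [M,K], B stored as [N,K] (column major
// from the GEMM's point of view), C row major [M,N]. One thread per element;
// the real kernels tile this with MMA/CuTe, stages and swizzled SMEM.
__global__ void naive_hgemm_tn(const half* A, const half* B, half* C,
                               int M, int N, int K) {
  int m = blockIdx.y * blockDim.y + threadIdx.y;
  int n = blockIdx.x * blockDim.x + threadIdx.x;
  if (m >= M || n >= N) return;
  float acc = 0.0f;
  for (int k = 0; k < K; ++k) {
    // Both A and B are read along their innermost, contiguous K dimension.
    acc += __half2float(A[m * K + k]) * __half2float(B[n * K + k]);
  }
  C[m * N + n] = __float2half(acc);
}
```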