
Commit 3f5ace3

[HGEMM] Add PyTorch HGEMM profile (#59)

* Create prof.py
* Update .gitignore
* Update .gitignore
* Update README.md

1 parent cb869e2 commit 3f5ace3

File tree

4 files changed (+170, -1 lines)

.gitignore

Lines changed: 2 additions & 1 deletion

```diff
@@ -13,5 +13,6 @@ __pycache__
 *.pt
 *.pth
 *.nsys*
-*.sqlite
+*.ncu*
+*.sqlite*
 *.engine
```

hgemm/.gitignore

Lines changed: 18 additions & 0 deletions

```diff
@@ -0,0 +1,18 @@
+*.so
+*.a
+*.dylib
+*.dll
+*.lib
+.DS_Store
+build
+*.whl
+tmp
+__pycache__
+*.onnx
+*.engine
+*.pt
+*.pth
+*.nsys*
+*.ncu*
+*.sqlite*
+*.engine
```

hgemm/README.md

Lines changed: 62 additions & 0 deletions

New section, inserted after the code block ending at line 117 and before the existing "## References" (参考文献) heading:

## PyTorch HGEMM Profile

On the Ada architecture, when PyTorch 2.4 performs an FP16 matmul, it calls the ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_stages_32x1_nn kernel, which internally uses HMMA (Tensor Core) instructions for the computation.
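
For reference, a minimal snippet that exercises this code path (a sketch; the exact kernel cuBLAS selects can vary with GPU, driver, and problem shape):

```python
import torch

# FP16 operands on the GPU; torch.matmul dispatches to a cuBLAS HGEMM
# kernel that uses HMMA (Tensor Core) instructions internally.
a = torch.randn(1024, 256, device="cuda", dtype=torch.half)
b = torch.randn(256, 1024, device="cuda", dtype=torch.half)
c = torch.matmul(a, b)
torch.cuda.synchronize()  # make sure the kernel has actually run
```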

```bash
ncu -o hgemm.prof -f python3 prof.py
nsys profile --stats=true -t cuda,osrt,nvtx -o hgemm.prof --force-overwrite true python3 prof.py
```
- Log

```bash
==PROF== Connected to process 367502 (/usr/bin/python3.10)
==PROF== Profiling "unrolled_elementwise_kernel" - 0: 0%....50%....100% - 8 passes
==PROF== Profiling "unrolled_elementwise_kernel" - 1: 0%....50%....100% - 8 passes
==PROF== Profiling "unrolled_elementwise_kernel" - 2: 0%....50%....100% - 8 passes
==PROF== Profiling "ampere_fp16_s1688gemm_fp16_12..." - 3: 0%....50%....100% - 8 passes
==PROF== Profiling "ampere_fp16_s1688gemm_fp16_12..." - 4: 0%....50%....100% - 8 passes
==PROF== Profiling "ampere_fp16_s1688gemm_fp16_12..." - 5: 0%....50%....100% - 8 passes
==PROF== Profiling "ampere_fp16_s1688gemm_fp16_12..." - 6: 0%....50%....100% - 8 passes
==PROF== Profiling "ampere_fp16_s1688gemm_fp16_12..." - 7: 0%....50%....100% - 8 passes
==PROF== Profiling "ampere_fp16_s1688gemm_fp16_12..." - 8: 0%....50%....100% - 8 passes
==PROF== Profiling "ampere_fp16_s1688gemm_fp16_12..." - 9: 0%....50%....100% - 8 passes
==PROF== Profiling "ampere_fp16_s1688gemm_fp16_12..." - 10: 0%....50%....100% - 8 passes
==PROF== Profiling "ampere_fp16_s1688gemm_fp16_12..." - 11: 0%....50%....100% - 8 passes
==PROF== Profiling "ampere_fp16_s1688gemm_fp16_12..." - 12: 0%....50%....100% - 8 passes
==PROF== Profiling "ampere_fp16_s1688gemm_fp16_12..." - 13: 0%....50%....100% - 8 passes
```

- SASS

```C
310 00007f41 37d5b850 LDSM.16.M88.4 R192, [R169+UR8+0x2000]
311 00007f41 37d5b860 LDSM.16.M88.4 R196, [R169+UR8+0x2800]
312 00007f41 37d5b870 @!P0 BRA.U 0x7f4137d5c3f0
313 00007f41 37d5b880 HMMA.1688.F32 R0, R176, R192, R0
314 00007f41 37d5b890 LDSM.16.MT88.4 R184, [R167+UR8+0x400]
315 00007f41 37d5b8a0 HMMA.1688.F32 R32, R178, R192, R32
316 00007f41 37d5b8b0 LDSM.16.M88.4 R200, [R170+UR8+0x2000]
317 00007f41 37d5b8c0 HMMA.1688.F32 R64, R180, R192, R64
318 00007f41 37d5b8d0 LDSM.16.MT88.4 R188, [R168+UR8+0x400]
319 00007f41 37d5b8e0 HMMA.1688.F32 R96, R182, R192, R96
320 00007f41 37d5b8f0 LDSM.16.M88.4 R204, [R170+UR8+0x2800]
321 00007f41 37d5b900 HMMA.1688.F32 R100, R182, R193, R100
322 00007f41 37d5b910 HMMA.1688.F32 R68, R180, R193, R68
323 00007f41 37d5b920 HMMA.1688.F32 R36, R178, R193, R36
324 00007f41 37d5b930 HMMA.1688.F32 R4, R176, R193, R4
325 00007f41 37d5b940 HMMA.1688.F32 R8, R176, R194, R8
326 00007f41 37d5b950 HMMA.1688.F32 R40, R178, R194, R40
327 00007f41 37d5b960 HMMA.1688.F32 R72, R180, R194, R72
328 00007f41 37d5b970 HMMA.1688.F32 R104, R182, R194, R104
329 00007f41 37d5b980 HMMA.1688.F32 R108, R182, R195, R108
330 00007f41 37d5b990 HMMA.1688.F32 R76, R180, R195, R76
331 00007f41 37d5b9a0 HMMA.1688.F32 R44, R178, R195, R44
332 00007f41 37d5b9b0 HMMA.1688.F32 R12, R176, R195, R12
333 00007f41 37d5b9c0 HMMA.1688.F32 R16, R176, R196, R16
334 00007f41 37d5b9d0 HMMA.1688.F32 R48, R178, R196, R48
335 00007f41 37d5b9e0 HMMA.1688.F32 R80, R180, R196, R80
336 00007f41 37d5b9f0 HMMA.1688.F32 R112, R182, R196, R112
337 00007f41 37d5ba00 HMMA.1688.F32 R116, R182, R197, R116
```
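
As a quicker check than a full ncu/nsys run, torch.profiler can also surface the dispatched kernel name (a sketch; kernel names vary across architectures and PyTorch builds):

```python
import torch
from torch.profiler import profile, ProfilerActivity

a = torch.randn(1024, 256, device="cuda", dtype=torch.half)
b = torch.randn(256, 1024, device="cuda", dtype=torch.half)

with profile(activities=[ProfilerActivity.CUDA]) as prof:
    torch.matmul(a, b)
    torch.cuda.synchronize()

# The GEMM kernel (e.g. ampere_fp16_s1688gemm_...) shows up by name.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=5))
```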

hgemm/prof.py

Lines changed: 88 additions & 0 deletions

New file:

```python
import torch
import time
from torch.utils.cpp_extension import load
from functools import partial
from typing import Optional

torch.set_grad_enabled(False)

# # Load the CUDA kernel as a python module
# lib = load(name='hgemm_lib',
#            sources=['hgemm.cu'],
#            extra_cuda_cflags=[
#                "-O3",
#                "-U__CUDA_NO_HALF_OPERATORS__",
#                "-U__CUDA_NO_HALF_CONVERSIONS__",
#                "-U__CUDA_NO_HALF2_OPERATORS__",
#                "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
#                "--expt-relaxed-constexpr",
#                "--expt-extended-lambda",
#                "--use_fast_math"
#            ],
#            extra_cflags=['-std=c++17'])


def run_benchmark(perf_func: callable,
                  a: torch.Tensor, b: torch.Tensor,
                  tag: str, out: Optional[torch.Tensor] = None,
                  warmup: int = 1, iters: int = 10,
                  show_all: bool = False):
    # Warmup (reuse the preallocated output tensor when one is given).
    if out is not None:
        out.fill_(0)
        for i in range(warmup):
            perf_func(a, b, out)
    else:
        for i in range(warmup):
            _ = perf_func(a, b)

    torch.cuda.synchronize()
    start = time.time()
    # Timed iterations.
    if out is not None:
        for i in range(iters):
            perf_func(a, b, out)
    else:
        for i in range(iters):
            out = perf_func(a, b)
    torch.cuda.synchronize()
    end = time.time()
    total_time = (end - start) * 1000  # ms
    mean_time = total_time / iters
    # Print the first three output values as a quick sanity check.
    out_info = f"out_{tag}"
    out_val = out.flatten().detach().cpu().numpy().tolist()[:3]
    out_val = [round(v, 8) for v in out_val]
    out_val = [f"{v:<12}" for v in out_val]
    print(f"{out_info:>32}: {out_val}, time:{mean_time:.6f}ms")
    if show_all:
        print(out)
    return out.clone(), mean_time


# Ms = [1024, 2048, 4096]
# Ns = [1024, 2048, 4096]
# Ks = [256, 512, 1024]
Ms = [1024]
Ns = [1024]
Ks = [256]
MNKs = [(M, N, K) for M in Ms for N in Ns for K in Ks]
for (M, N, K) in MNKs:
    print("-" * 110)
    print(" " * 45 + f"M={M}, N={N}, K={K}")
    a = torch.randn((M, K)).cuda().half().contiguous()
    b = torch.randn((K, N)).cuda().half().contiguous()
    c = torch.randn((M, N)).cuda().half().contiguous()
    # run_benchmark(lib.hgemm_naive_f16, a, b, "f16", c)
    # run_benchmark(lib.hgemm_sliced_k_f16, a, b, "f16(sk)", c)
    # run_benchmark(lib.hgemm_t_4x4_sliced_k_f16x4_pack_bcf, a, b, "f16x4pack(t4x4bcf)", c)
    # run_benchmark(lib.hgemm_t_4x4_sliced_k_f16x4_pack_bcf_offset, a, b, "f16x4pack(t4x4offset)", c)
    # run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x4, a, b, "f16x4(t8x8sk)", c)
    # run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x4_bcf, a, b, "f16x4(t8x8bcf)", c)
    # run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x4_pack, a, b, "f16x4pack(t8x8sk)", c)
    # run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x4_pack_bcf, a, b, "f16x4pack(bcf)", c)
    # run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x4_pack_bcf_offset, a, b, "f16x4pack(bcf+offset)", c)
    # run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x8_pack_bcf, a, b, "f16x8pack(bcf)", c)
    # run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x8_pack_bcf_offset, a, b, "f16x8pack(bcf+offset)", c)
    # run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x8_pack_bcf_dbuf, a, b, "f16x8pack(dbuf)", c)
    run_benchmark(partial(torch.matmul, out=c), a, b, "f16_th")
    print("-" * 110)
```
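
One possible refinement of the timing in run_benchmark: CUDA events measure device time directly and avoid host-timer jitter around kernel launches. A sketch under that assumption (time_cuda is a hypothetical helper, not part of this commit):

```python
import torch

def time_cuda(fn, iters: int = 10, warmup: int = 1) -> float:
    # Hypothetical helper: times fn() with CUDA events rather than
    # time.time(), so only device execution time is measured.
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # mean ms per iteration
```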
