Skip to content

Commit abb34fa

Browse files
authored
[HGEMM] Clear tensor cache to avoid OOM (#141)
* Update hgemm.py * Update README.md * Update README.md
1 parent 82b2898 commit abb34fa

File tree

2 files changed

+25
-18
lines changed

2 files changed

+25
-18
lines changed

hgemm/README.md

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
- [X] hgemm_mma_m16n8k16_mma2x4_warp4x4(MMA, Tile MMA/Warp, pack)
4141
- [X] hgemm_mma_m16n8k16_mma2x4_warp4x4_stages(MMA, Tile MMA/Warp, Copy Async, Stages, Pad, Block swizzle)
4242
- [X] hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages(MMA, Tile MMA/Warp, Copy Async, Stages, Pad, Block swizzle, Warp swizzle, Reg Double Buffers, Collective Store with Reg Reuse & Warp Shuffle)
43-
- [X] hgemm_mma_stages_tn_cute(MMA, Tile MMA/Warp, Copy Async, Stages, SMEM Swizzle)
43+
- [X] hgemm_mma_stages_block_swizzle_tn_cute(MMA, Tile MMA/Warp, Copy Async, Stages, Block Swizzle, SMEM Swizzle, Collective Store with SMEM)
4444
- [X] PyTorch bindings
4545

4646
</details>
@@ -100,23 +100,23 @@ M N K = 16384 16384 16384, Time = 0.07663001 0.07663534 0.07664947 s, A
100100

101101
./hgemm_cute.bin
102102
# NVIDIA L20
103-
ALGO = CuTe HGEMM TN STAGES=2
104-
M N K = 12544 12544 12544, Time = 0.03410432 0.03411466 0.03412787 s, AVG Performance = 115.7170 Tflops
105-
M N K = 12800 12800 12800, Time = 0.03612774 0.03613839 0.03614515 s, AVG Performance = 116.0623 Tflops
106-
M N K = 13056 13056 13056, Time = 0.03820646 0.03821117 0.03821466 s, AVG Performance = 116.4850 Tflops
107-
M N K = 13312 13312 13312, Time = 0.04039987 0.04054825 0.04059136 s, AVG Performance = 116.3557 Tflops
108-
M N K = 13568 13568 13568, Time = 0.04315751 0.04316447 0.04318515 s, AVG Performance = 115.7314 Tflops
109-
M N K = 13824 13824 13824, Time = 0.04540928 0.04541317 0.04541542 s, AVG Performance = 116.3454 Tflops
110-
M N K = 14080 14080 14080, Time = 0.04774707 0.04775066 0.04775833 s, AVG Performance = 116.9119 Tflops
111-
M N K = 14336 14336 14336, Time = 0.05077197 0.05078108 0.05079654 s, AVG Performance = 116.0412 Tflops
112-
M N K = 14592 14592 14592, Time = 0.05325619 0.05326203 0.05326848 s, AVG Performance = 116.6693 Tflops
113-
M N K = 14848 14848 14848, Time = 0.05650432 0.05652460 0.05653504 s, AVG Performance = 115.8234 Tflops
114-
M N K = 15104 15104 15104, Time = 0.05913191 0.05915228 0.05917798 s, AVG Performance = 116.5023 Tflops
115-
M N K = 15360 15360 15360, Time = 0.06275584 0.06281114 0.06284800 s, AVG Performance = 115.3897 Tflops
116-
M N K = 15616 15616 15616, Time = 0.06540698 0.06549893 0.06558515 s, AVG Performance = 116.2800 Tflops
117-
M N K = 15872 15872 15872, Time = 0.06917018 0.06926930 0.06936780 s, AVG Performance = 115.4474 Tflops
118-
M N K = 16128 16128 16128, Time = 0.07299482 0.07302656 0.07305421 s, AVG Performance = 114.8922 Tflops
119-
M N K = 16384 16384 16384, Time = 0.07693209 0.07698473 0.07704780 s, AVG Performance = 114.2576 Tflops
103+
ALGO = CuTe HGEMM, TN, STAGES=2, SMEM SWIZZLE=<3, 3, 3>, BLOCK SWIZZLE=2048
104+
M N K = 12544 12544 12544, Time = 0.03413504 0.03414354 0.03415450 s, AVG Performance = 115.6191 Tflops
105+
M N K = 12800 12800 12800, Time = 0.03615642 0.03616481 0.03617178 s, AVG Performance = 115.9775 Tflops
106+
M N K = 13056 13056 13056, Time = 0.03821158 0.03821455 0.03821671 s, AVG Performance = 116.4747 Tflops
107+
M N K = 13312 13312 13312, Time = 0.04033536 0.04033894 0.04034560 s, AVG Performance = 116.9595 Tflops
108+
M N K = 13568 13568 13568, Time = 0.04318720 0.04319130 0.04319949 s, AVG Performance = 115.6595 Tflops
109+
M N K = 13824 13824 13824, Time = 0.04541542 0.04541942 0.04542157 s, AVG Performance = 116.3294 Tflops
110+
M N K = 14080 14080 14080, Time = 0.04770918 0.04772137 0.04772761 s, AVG Performance = 116.9836 Tflops
111+
M N K = 14336 14336 14336, Time = 0.05077402 0.05077955 0.05078426 s, AVG Performance = 116.0447 Tflops
112+
M N K = 14592 14592 14592, Time = 0.05324902 0.05326633 0.05327872 s, AVG Performance = 116.6599 Tflops
113+
M N K = 14848 14848 14848, Time = 0.05638758 0.05640591 0.05643162 s, AVG Performance = 116.0671 Tflops
114+
M N K = 15104 15104 15104, Time = 0.05892505 0.05893622 0.05894246 s, AVG Performance = 116.9294 Tflops
115+
M N K = 15360 15360 15360, Time = 0.06227354 0.06228111 0.06228992 s, AVG Performance = 116.3717 Tflops
116+
M N K = 15616 15616 15616, Time = 0.06492467 0.06493727 0.06496666 s, AVG Performance = 117.2858 Tflops
117+
M N K = 15872 15872 15872, Time = 0.06843085 0.06843873 0.06844723 s, AVG Performance = 116.8485 Tflops
118+
M N K = 16128 16128 16128, Time = 0.07200256 0.07200881 0.07201792 s, AVG Performance = 116.5161 Tflops
119+
M N K = 16384 16384 16384, Time = 0.07564493 0.07565752 0.07567462 s, AVG Performance = 116.2620 Tflops
120120

121121
./hgemm_cublas.bin
122122
# NVIDIA L20

hgemm/hgemm.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,13 @@ def row2col(x: torch.Tensor):
479479
if args.enable_torch:
480480
run_benchmark(partial(torch.matmul, out=c), a, b, "(torch)")
481481
torch.cuda.synchronize()
482+
# Avoid OOM: drop the tensor references, then release PyTorch's cached CUDA memory
483+
del a; a = None
484+
del b; b = None
485+
del c; c = None
486+
del b_col_major;
487+
b_col_major = None
488+
torch.cuda.empty_cache()
482489
pretty_print_line()
483490

484491
if args.plot_flops:

0 commit comments

Comments
 (0)