Skip to content

Commit 9bd2268

Browse files
authored
[HGEMM] Add gc.collect to HGEMM bench script (#142)
* Update hgemm.py * Update hgemm.py
1 parent abb34fa commit 9bd2268

File tree

1 file changed

+17
-4
lines changed

1 file changed

+17
-4
lines changed

hgemm/hgemm.py

Lines changed: 17 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
import os
2+
import gc
23
import torch
34
import time
45
from torch.utils.cpp_extension import load
@@ -147,7 +148,6 @@ def get_build_cuda_cflags():
147148
CUBLAS_TOTAL_TFLOPS = 0
148149
CUBLAS_TN_TOTAL_TFLOPS = 0
149150

150-
151151
def make_block_swizzle_stride(N: int, K: int):
152152
# make swizzle stride as N/8,N/4,N/2 and multiples of 256
153153
if args.swizzle_factor is None:
@@ -162,7 +162,7 @@ def make_block_swizzle_stride(N: int, K: int):
162162

163163
return swizzle_stride
164164

165-
165+
@torch.no_grad
166166
def run_benchmark(perf_func: callable,
167167
a: torch.Tensor, b: torch.Tensor,
168168
tag: str, out: Optional[torch.Tensor] = None,
@@ -216,8 +216,9 @@ def run_benchmark(perf_func: callable,
216216
total_time = (end - start) * 1000 # ms
217217
mean_time = total_time / iters
218218
out_info = f"{tag}"
219-
out_val_first = out.flatten()[:2].detach().cpu().numpy().tolist()
220-
out_val_last = out.flatten()[-2:].detach().cpu().numpy().tolist()
219+
out_flat = out.flatten()
220+
out_val_first = out_flat[:2].detach().cpu().numpy().tolist()
221+
out_val_last = out_flat[-2:].detach().cpu().numpy().tolist()
221222
out_val = [out_val_first[0], out_val_last[-1]]
222223
out_val = [round(v, 8) for v in out_val]
223224
out_val = [f"{v:<12}"[:10] for v in out_val]
@@ -254,6 +255,10 @@ def run_benchmark(perf_func: callable,
254255
CUBLAS_TOTAL_TFLOPS += TFLOPS
255256

256257
torch.cuda.synchronize()
258+
del out_flat
259+
out_flat = None
260+
gc.collect()
261+
torch.cuda.empty_cache()
257262
time.sleep(args.sleep_duration)
258263
return out, mean_time
259264

@@ -274,6 +279,7 @@ def get_topk_tflops():
274279
return list(dict(topk_tflops[:args.plot_topk]).keys())
275280

276281

282+
@torch.no_grad
277283
def get_best_tflops():
278284
all_tflops = []
279285
for tag, tflops in STATIS_INFO.items():
@@ -345,6 +351,7 @@ def get_mnk(sep: int = args.SEP):
345351
return Ms, Ns, Ks
346352

347353

354+
@torch.no_grad
348355
def row2col(x: torch.Tensor):
349356
# convert a row major tensor -> col major with contiguous storage
350357
x_trans = x.t()
@@ -485,8 +492,14 @@ def row2col(x: torch.Tensor):
485492
del c; c = None
486493
del b_col_major;
487494
b_col_major = None
495+
gc.collect()
488496
torch.cuda.empty_cache()
497+
gc.collect()
489498
pretty_print_line()
490499

500+
pretty_print_line()
501+
print(torch.cuda.memory_summary())
502+
pretty_print_line()
503+
491504
if args.plot_flops:
492505
plot_tflops()

0 commit comments

Comments (0)