1
1
import os
2
+ import gc
2
3
import torch
3
4
import time
4
5
from torch .utils .cpp_extension import load
@@ -147,7 +148,6 @@ def get_build_cuda_cflags():
147
148
CUBLAS_TOTAL_TFLOPS = 0
148
149
CUBLAS_TN_TOTAL_TFLOPS = 0
149
150
150
-
151
151
def make_block_swizzle_stride (N : int , K : int ):
152
152
# Make the swizzle stride one of N/8, N/4 or N/2, and a multiple of 256.
153
153
if args .swizzle_factor is None :
@@ -162,7 +162,7 @@ def make_block_swizzle_stride(N: int, K: int):
162
162
163
163
return swizzle_stride
164
164
165
-
165
+ @ torch . no_grad
166
166
def run_benchmark (perf_func : callable ,
167
167
a : torch .Tensor , b : torch .Tensor ,
168
168
tag : str , out : Optional [torch .Tensor ] = None ,
@@ -216,8 +216,9 @@ def run_benchmark(perf_func: callable,
216
216
total_time = (end - start ) * 1000 # ms
217
217
mean_time = total_time / iters
218
218
out_info = f"{ tag } "
219
- out_val_first = out .flatten ()[:2 ].detach ().cpu ().numpy ().tolist ()
220
- out_val_last = out .flatten ()[- 2 :].detach ().cpu ().numpy ().tolist ()
219
+ out_flat = out .flatten ()
220
+ out_val_first = out_flat [:2 ].detach ().cpu ().numpy ().tolist ()
221
+ out_val_last = out_flat [- 2 :].detach ().cpu ().numpy ().tolist ()
221
222
out_val = [out_val_first [0 ], out_val_last [- 1 ]]
222
223
out_val = [round (v , 8 ) for v in out_val ]
223
224
out_val = [f"{ v :<12} " [:10 ] for v in out_val ]
@@ -254,6 +255,10 @@ def run_benchmark(perf_func: callable,
254
255
CUBLAS_TOTAL_TFLOPS += TFLOPS
255
256
256
257
torch .cuda .synchronize ()
258
+ del out_flat
259
+ out_flat = None
260
+ gc .collect ()
261
+ torch .cuda .empty_cache ()
257
262
time .sleep (args .sleep_duration )
258
263
return out , mean_time
259
264
@@ -274,6 +279,7 @@ def get_topk_tflops():
274
279
return list (dict (topk_tflops [:args .plot_topk ]).keys ())
275
280
276
281
282
+ @torch .no_grad
277
283
def get_best_tflops ():
278
284
all_tflops = []
279
285
for tag , tflops in STATIS_INFO .items ():
@@ -345,6 +351,7 @@ def get_mnk(sep: int = args.SEP):
345
351
return Ms , Ns , Ks
346
352
347
353
354
+ @torch .no_grad
348
355
def row2col (x : torch .Tensor ):
349
356
# Convert a row-major tensor to column-major with contiguous storage.
350
357
x_trans = x .t ()
@@ -485,8 +492,14 @@ def row2col(x: torch.Tensor):
485
492
del c ; c = None
486
493
del b_col_major ;
487
494
b_col_major = None
495
+ gc .collect ()
488
496
torch .cuda .empty_cache ()
497
+ gc .collect ()
489
498
pretty_print_line ()
490
499
500
+ pretty_print_line ()
501
+ print (torch .cuda .memory_summary ())
502
+ pretty_print_line ()
503
+
491
504
if args .plot_flops :
492
505
plot_tflops ()
0 commit comments