Skip to content

Commit 48af93d

Browse files
authored
[HGEMM] Add show_memory option to bench (#143)
1 parent 9bd2268 commit 48af93d

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

hgemm/hgemm.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@ def get_args():
2323
parser.add_argument("--verbose", "--v", action="store_true", help="Verbose")
2424
parser.add_argument("--show-matrix", "--show-m", action="store_true", help="Show output matrix values")
2525
parser.add_argument("--show-all-info", "--show-a", action="store_true", help="Show all the profile info")
26+
parser.add_argument("--show-memory", "--show-mm", action="store_true", help="Show gpu memory info")
2627
parser.add_argument("--enable-mma", "--mma", action="store_true", help="Enable MMA kernel tests")
2728
parser.add_argument("--enable-mma-tn", "--mma-tn", action="store_true", help="Enable TN MMA kernel tests")
2829
parser.add_argument("--enable-wmma", "--wmma", action="store_true", help="Enable WMMA kernel tests")
@@ -497,9 +498,10 @@ def row2col(x: torch.Tensor):
497498
gc.collect()
498499
pretty_print_line()
499500

500-
pretty_print_line()
501-
print(torch.cuda.memory_summary())
502-
pretty_print_line()
501+
if args.show_memory:
502+
pretty_print_line()
503+
print(torch.cuda.memory_summary())
504+
pretty_print_line()
503505

504506
if args.plot_flops:
505507
plot_tflops()

0 commit comments

Comments (0)