Commit c718049
[HGEMM] Update HGEMM benchmark scripts (#105)
* update hgemm benchmark scripts
1 parent: a0daf10

File tree: 1 file changed (+4 −2)

hgemm/hgemm.py

@@ -226,14 +226,16 @@ def run_benchmark(perf_func: callable,
                 args.enable_mma, args.enable_mma_all, args.enable_wmma, args.enable_wmma_all,
                 args.enable_cuda, args.enable_cuda_all, args.enable_torch)):
         run_benchmark(lib.hgemm_cublas_tensor_op_nn, a, b, "(cublas)", c)
+        if args.enable_torch:
+            run_benchmark(partial(torch.matmul, out=c), a, b, "(torch)")
     if args.enable_mma_tn:
+        MAX_TFLOPS = -1
+        print("-" * 68 + "MMA(TN)" + "-" * 55)
         run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn, a, b.transpose(1, 0), "tn(mma2x4+warp4x4+stage3+dsmem)", c, stages=3)
         run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn, a, b.transpose(1, 0), "tn(mma2x4+warp4x4+stage2+dsmem)", c, stages=2)
         run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn, a, b.transpose(1, 0), "tn(mma2x4+warp4x4+stage3+dsmem+swizzle)", c, stages=3, swizzle=True)
         run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn, a, b.transpose(1, 0), "tn(mma2x4+warp4x4+stage2+dsmem+swizzle)", c, stages=2, swizzle=True)
         if not args.disable_cublas_tn:
             run_benchmark(lib.hgemm_cublas_tensor_op_tn, a, b.transpose(1, 0), "tn(cublas)", c)
-        if args.enable_torch:
-            run_benchmark(partial(torch.matmul, out=c), a, b, "(torch)")
     torch.cuda.synchronize()
     print("-" * 130)
