@@ -226,14 +226,16 @@ def run_benchmark(perf_func: callable,
                 args.enable_mma, args.enable_mma_all, args.enable_wmma, args.enable_wmma_all,
                 args.enable_cuda, args.enable_cuda_all, args.enable_torch)):
             run_benchmark(lib.hgemm_cublas_tensor_op_nn, a, b, "(cublas)", c)
+        if args.enable_torch:
+            run_benchmark(partial(torch.matmul, out=c), a, b, "(torch)")
         if args.enable_mma_tn:
+            MAX_TFLOPS = -1
+            print("-" * 68 + "MMA(TN)" + "-" * 55)
             run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn, a, b.transpose(1, 0), "tn(mma2x4+warp4x4+stage3+dsmem)", c, stages=3)
             run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn, a, b.transpose(1, 0), "tn(mma2x4+warp4x4+stage2+dsmem)", c, stages=2)
             run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn, a, b.transpose(1, 0), "tn(mma2x4+warp4x4+stage3+dsmem+swizzle)", c, stages=3, swizzle=True)
             run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn, a, b.transpose(1, 0), "tn(mma2x4+warp4x4+stage2+dsmem+swizzle)", c, stages=2, swizzle=True)
             if not args.disable_cublas_tn:
                 run_benchmark(lib.hgemm_cublas_tensor_op_tn, a, b.transpose(1, 0), "tn(cublas)", c)
-        if args.enable_torch:
-            run_benchmark(partial(torch.matmul, out=c), a, b, "(torch)")
         torch.cuda.synchronize()
         print("-" * 130)
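A minimal, standalone sketch (not the repository's run_benchmark) of the two call patterns this hunk touches: the "(torch)" baseline built with functools.partial(torch.matmul, out=c) so results land in a preallocated half-precision buffer, and the TN layout where B is passed as an (N, K) view via b.transpose(1, 0). The shapes and the bench() helper below are illustrative assumptions, not code from this PR.

import time
from functools import partial

import torch

M, N, K = 4096, 4096, 4096
a = torch.randn((M, K), dtype=torch.half, device="cuda")
b = torch.randn((K, N), dtype=torch.half, device="cuda")
c = torch.zeros((M, N), dtype=torch.half, device="cuda")


def bench(perf_func, *args, tag="", warmup=2, iters=20):
    # Hypothetical stand-in for run_benchmark: warm up, time, report TFLOPS.
    for _ in range(warmup):
        perf_func(*args)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        perf_func(*args)
    torch.cuda.synchronize()
    mean_s = (time.perf_counter() - start) / iters
    tflops = 2 * M * N * K / mean_s / 1e12
    print(f"{tag:<12} {mean_s * 1e3:8.3f} ms, {tflops:6.2f} TFLOPS")


# "(torch)" baseline: partial binds out=c, so torch.matmul writes into the
# same preallocated buffer the custom kernels use.
bench(partial(torch.matmul, out=c), a, b, tag="(torch)")

# TN layout: b.transpose(1, 0) is a zero-copy (N, K) view of B, which is the
# shape the tn(...) kernel variants and the cublas TN baseline take as input.
b_tn = b.transpose(1, 0)
print(b_tn.shape, b_tn.is_contiguous())  # torch.Size([4096, 4096]) False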