@@ -372,15 +372,17 @@ def matmul(a, b, activation=""):
 else:
     print("❌ Triton and Torch differ")
 
-TORCH_HAS_FP8 = hasattr(torch, "float8_e5m2")
-if TORCH_HAS_FP8 and is_cuda():
+TORCH_HAS_FP8 = hasattr(torch, "float8_e4m3fn") or hasattr(torch, "float8_e4m3fnuz")
+
+if TORCH_HAS_FP8:
+    fp8_dtype = torch.float8_e4m3fn if is_cuda() else torch.float8_e4m3fnuz
     torch.manual_seed(0)
     a = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
     b = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
-    a = a.to(torch.float8_e5m2)
+    a = a.to(fp8_dtype)
     # pre-transpose b for efficiency.
     b = b.T
-    b = b.to(torch.float8_e5m2)
+    b = b.to(fp8_dtype)
     triton_output = matmul(a, b)
     torch_output = torch.matmul(a.to(torch.float16), b.to(torch.float16))
     print(f"triton_output_with_fp8_inputs={triton_output}")
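A minimal standalone sketch of the dtype selection introduced above, assuming a PyTorch build that exposes both e4m3 variants; the backend probe via `torch.version.hip` is an illustration, not part of this diff:

import torch

# Pick the e4m3 flavour matching the backend: plain "fn" on CUDA builds,
# the "fnuz" variant on ROCm builds (e.g. MI300-class GPUs).
if hasattr(torch, "float8_e4m3fn") and hasattr(torch, "float8_e4m3fnuz"):
    on_rocm = torch.version.hip is not None
    fp8_dtype = torch.float8_e4m3fnuz if on_rocm else torch.float8_e4m3fn
    # e4m3 spends its bits on mantissa (finfo.max is 448 for e4m3fn, 240 for
    # e4m3fnuz), whereas the previous e5m2 choice spends them on exponent range.
    print(torch.finfo(fp8_dtype))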
@@ -404,7 +406,7 @@ def matmul(a, b, activation=""):
 
 configs = []
 for fp8_inputs in [False, True]:
-    if fp8_inputs and (not TORCH_HAS_FP8 or not is_cuda()):
+    if fp8_inputs and (not TORCH_HAS_FP8):
         continue
     configs.append(
         triton.testing.Benchmark(
@@ -413,8 +415,8 @@ def matmul(a, b, activation=""):
             line_arg="provider",  # Argument name whose value corresponds to a different line in the plot
             # Possible values for `line_arg`
             # Don't compare to cublas for fp8 cases as torch.matmul doesn't support fp8 at the moment.
-            line_vals=["triton"] if fp8_inputs else [ref_lib.lower(), "triton"],  # Label name for the lines
-            line_names=["Triton"] if fp8_inputs else [ref_lib, "Triton"],  # Line styles
+            line_vals=[ref_lib.lower(), "triton"],  # Label name for the lines
+            line_names=[ref_lib, "Triton"],  # Line styles
             styles=[("green", "-"), ("blue", "-")],
             ylabel="TFLOPS",  # Label name for the y-axis
             plot_name="matmul-performance-" +
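A small sketch of how the resulting configs could be inspected, assuming `ref_lib` is "cuBLAS" on CUDA and "rocBLAS" on ROCm as defined earlier in the tutorial, and that Benchmark keeps its constructor arguments as attributes:

# After the loop above, every config (fp16 and fp8 alike) now benchmarks both
# the reference library and Triton; the fp8 config used to plot Triton only.
for cfg in configs:
    print(cfg.plot_name, cfg.line_vals)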
@@ -428,12 +430,19 @@ def benchmark(M, N, K, provider, fp8_inputs):
     a = torch.randn((M, K), device=DEVICE, dtype=torch.float16)
     b = torch.randn((K, N), device=DEVICE, dtype=torch.float16)
     if TORCH_HAS_FP8 and fp8_inputs:
-        a = a.to(torch.float8_e5m2)
+        fp8_dtype = torch.float8_e4m3fn if is_cuda() else torch.float8_e4m3fnuz
+        a = a.to(fp8_dtype)
         b = b.T
-        b = b.to(torch.float8_e5m2)
+        b = b.to(fp8_dtype)
     quantiles = [0.5, 0.2, 0.8]
     if provider == ref_lib.lower():
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
+        if fp8_inputs:
+            one_device = torch.tensor(1., device=a.device, dtype=torch.float32)
+            ref_fn = lambda: torch._scaled_mm(a, b, scale_a=one_device, scale_b=one_device, out_dtype=torch.float16,
+                                              use_fast_accum=True)
+        else:
+            ref_fn = lambda: torch.matmul(a, b)
+        ms, min_ms, max_ms = triton.testing.do_bench(ref_fn, quantiles=quantiles)
     if provider == 'triton':
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles)
     perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
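A standalone sketch of the fp8 reference path used above, assuming a CUDA device with fp8 support (e.g. H100) and a PyTorch build whose private torch._scaled_mm accepts this keyword form; older releases returned an (out, amax) tuple, which the snippet tolerates:

import torch

if torch.cuda.is_available() and hasattr(torch, "float8_e4m3fn"):
    M = N = K = 512
    a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn)
    # The second operand must be column-major for the cuBLASLt fp8 kernels,
    # which is what the b.T above achieves.
    b = torch.randn((N, K), device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn).T
    one = torch.tensor(1.0, device="cuda", dtype=torch.float32)  # identity scales
    out = torch._scaled_mm(a, b, scale_a=one, scale_b=one, out_dtype=torch.float16,
                           use_fast_accum=True)
    if isinstance(out, tuple):  # older PyTorch versions also returned an amax tensor
        out = out[0]
    # Compare against an fp16 matmul of the dequantized inputs, as the accuracy
    # check earlier in the tutorial does.
    ref = torch.matmul(a.to(torch.float16), b.to(torch.float16))
    print("max abs diff vs fp16 reference:", (out - ref).abs().max().item())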