
Commit ca24b10

davidberard98 authored and ThomasRaoux committed
[TUTORIAL][03] use float8_e4m3fn(uz) instead of e5m2 and add PyTorch comparison (triton-lang#6850)
**Motivation**: Add a baseline from PyTorch (scaled_mm) for comparison. scaled_mm, the fp8 matmul implementation in PyTorch, supports only float8_e4m3fn on NVIDIA; on AMD the equivalent is float8_e4m3fnuz. To my knowledge, e4m3fn and e5m2 should have similar performance behavior on NVIDIA and AMD. Co-authored-by: Thomas Raoux <[email protected]>
1 parent 6bc0661 commit ca24b10
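
For context, here is a minimal standalone sketch (not part of the commit) of the PyTorch baseline this change benchmarks against: pick float8_e4m3fn on NVIDIA or float8_e4m3fnuz on AMD, then call torch._scaled_mm with unit per-tensor scales, mirroring the ref_fn added in the diff below. The dtype-selection check (torch.version.hip), device string, and shapes are illustrative assumptions; details such as _scaled_mm's return type and use_fast_accum support vary across PyTorch versions.

# Sketch of the scaled_mm fp8 baseline; assumes a PyTorch build with fp8 support on a CUDA/ROCm device.
import torch

device = "cuda"
# e4m3fn on NVIDIA; the AMD/ROCm equivalent is e4m3fnuz.
fp8_dtype = torch.float8_e4m3fnuz if torch.version.hip else torch.float8_e4m3fn

M = N = K = 512
a = torch.randn((M, K), device=device, dtype=torch.float16).to(fp8_dtype)
# Pre-transpose the second operand so it is column-major, as the tutorial does with `b = b.T`.
b = torch.randn((N, K), device=device, dtype=torch.float16).to(fp8_dtype).T

one = torch.tensor(1.0, device=device, dtype=torch.float32)  # unit per-tensor scales
out = torch._scaled_mm(a, b, scale_a=one, scale_b=one, out_dtype=torch.float16, use_fast_accum=True)
print(out.shape)  # torch.Size([512, 512])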

File tree

1 file changed (+19, -10 lines changed)


python/tutorials/03-matrix-multiplication.py

Lines changed: 19 additions & 10 deletions
@@ -372,15 +372,17 @@ def matmul(a, b, activation=""):
 else:
     print("❌ Triton and Torch differ")
 
-TORCH_HAS_FP8 = hasattr(torch, "float8_e5m2")
-if TORCH_HAS_FP8 and is_cuda():
+TORCH_HAS_FP8 = hasattr(torch, "float8_e4m3fn") or hasattr(torch, "float8_e4m3fnuz")
+
+if TORCH_HAS_FP8:
+    fp8_dtype = torch.float8_e4m3fn if is_cuda() else torch.float8_e4m3fnuz
     torch.manual_seed(0)
     a = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
     b = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
-    a = a.to(torch.float8_e5m2)
+    a = a.to(fp8_dtype)
     # pre-transpose b for efficiency.
     b = b.T
-    b = b.to(torch.float8_e5m2)
+    b = b.to(fp8_dtype)
     triton_output = matmul(a, b)
     torch_output = torch.matmul(a.to(torch.float16), b.to(torch.float16))
     print(f"triton_output_with_fp8_inputs={triton_output}")
@@ -404,7 +406,7 @@ def matmul(a, b, activation=""):
 
 configs = []
 for fp8_inputs in [False, True]:
-    if fp8_inputs and (not TORCH_HAS_FP8 or not is_cuda()):
+    if fp8_inputs and (not TORCH_HAS_FP8):
         continue
     configs.append(
         triton.testing.Benchmark(
@@ -413,8 +415,8 @@ def matmul(a, b, activation=""):
             line_arg="provider",  # Argument name whose value corresponds to a different line in the plot
             # Possible values for `line_arg`
             # Don't compare to cublas for fp8 cases as torch.matmul doesn't support fp8 at the moment.
-            line_vals=["triton"] if fp8_inputs else [ref_lib.lower(), "triton"],  # Label name for the lines
-            line_names=["Triton"] if fp8_inputs else [ref_lib, "Triton"],  # Line styles
+            line_vals=[ref_lib.lower(), "triton"],  # Label name for the lines
+            line_names=[ref_lib, "Triton"],  # Line styles
             styles=[("green", "-"), ("blue", "-")],
             ylabel="TFLOPS",  # Label name for the y-axis
             plot_name="matmul-performance-" +
@@ -428,12 +430,19 @@ def benchmark(M, N, K, provider, fp8_inputs):
     a = torch.randn((M, K), device=DEVICE, dtype=torch.float16)
     b = torch.randn((K, N), device=DEVICE, dtype=torch.float16)
     if TORCH_HAS_FP8 and fp8_inputs:
-        a = a.to(torch.float8_e5m2)
+        fp8_dtype = torch.float8_e4m3fn if is_cuda() else torch.float8_e4m3fnuz
+        a = a.to(fp8_dtype)
         b = b.T
-        b = b.to(torch.float8_e5m2)
+        b = b.to(fp8_dtype)
     quantiles = [0.5, 0.2, 0.8]
     if provider == ref_lib.lower():
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
+        if fp8_inputs:
+            one_device = torch.tensor(1., device=a.device, dtype=torch.float32)
+            ref_fn = lambda: torch._scaled_mm(a, b, scale_a=one_device, scale_b=one_device, out_dtype=torch.float16,
+                                              use_fast_accum=True)
+        else:
+            ref_fn = lambda: torch.matmul(a, b)
+        ms, min_ms, max_ms = triton.testing.do_bench(ref_fn, quantiles=quantiles)
     if provider == 'triton':
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles)
     perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
