
Commit c3aea10

mgoin and yewentao256 authored
[Perf] Use upstream CUTLASS for SM90 Block FP8 kernel (#23280)
Signed-off-by: mgoin <[email protected]>
Co-authored-by: Wentao Ye <[email protected]>
1 parent d4fd276 · commit c3aea10

File tree

13 files changed (+222, -1261 lines)


benchmarks/kernels/bench_block_fp8_gemm.py

Lines changed: 44 additions & 13 deletions
@@ -4,7 +4,10 @@
 import torch
 
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    w8a8_block_fp8_matmul,
+    apply_w8a8_block_fp8_linear,
+)
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    CUTLASS_BLOCK_FP8_SUPPORTED,
 )
 from vllm.platforms import current_platform
 from vllm.triton_utils import triton as vllm_triton
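Note what this import swap does to the benchmark's scope: judging from the hunks below, the old w8a8_block_fp8_matmul path consumed a pre-quantized FP8 activation tensor plus its per-block scales (A, As), while apply_w8a8_block_fp8_linear receives the bf16 activations directly, so the timed region now also covers activation quantization. As a rough illustration of that extra work, here is a minimal per-block quantization sketch; it is a hypothetical stand-in, not vLLM's actual kernel, and assumes K is divisible by block_k:

import torch

def quantize_per_block(a: torch.Tensor, block_k: int):
    # Hypothetical sketch of per-block FP8 activation quantization.
    # Assumes a is contiguous and K % block_k == 0.
    M, K = a.shape
    fp8 = torch.finfo(torch.float8_e4m3fn)
    a_blocks = a.float().view(M, K // block_k, block_k)
    scales = a_blocks.abs().amax(dim=-1) / fp8.max          # (M, k_tiles)
    a_q = (a_blocks / scales.unsqueeze(-1)).clamp(fp8.min, fp8.max)
    return a_q.view(M, K).to(torch.float8_e4m3fn), scales

a = torch.randn(16, 512, dtype=torch.bfloat16)
a_q, a_s = quantize_per_block(a, block_k=128)
print(a_q.dtype, a_s.shape)  # torch.float8_e4m3fn torch.Size([16, 4])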
@@ -29,45 +32,62 @@
 ]
 
 
-def build_w8a8_block_fp8_runner(M, N, K, block_size, device):
+def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass):
     """Build runner function for w8a8 block fp8 matmul."""
     factor_for_scale = 1e-2
 
     fp8_info = torch.finfo(torch.float8_e4m3fn)
     fp8_max, fp8_min = fp8_info.max, fp8_info.min
 
     # Create random FP8 tensors
-    A_fp32 = (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
-    A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
+    A_ref = (torch.rand(M, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max
 
-    B_fp32 = (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
-    B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
+    B_ref = (torch.rand(N, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max
+    B = B_ref.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
 
     # Create scales
     block_n, block_k = block_size[0], block_size[1]
     n_tiles = (N + block_n - 1) // block_n
     k_tiles = (K + block_k - 1) // block_k
 
-    As = torch.rand(M, k_tiles, dtype=torch.float32, device=device) * factor_for_scale
     Bs = (
         torch.rand(n_tiles, k_tiles, dtype=torch.float32, device=device)
         * factor_for_scale
     )
 
+    # SM90 CUTLASS requires row-major format for scales
+    if use_cutlass and current_platform.is_device_capability(90):
+        Bs = Bs.T.contiguous()
+
     def run():
-        return w8a8_block_fp8_matmul(A, B, As, Bs, block_size, torch.bfloat16)
+        if use_cutlass:
+            return apply_w8a8_block_fp8_linear(
+                A_ref, B, block_size, Bs, cutlass_block_fp8_supported=True
+            )
+        else:
+            return apply_w8a8_block_fp8_linear(
+                A_ref, B, block_size, Bs, cutlass_block_fp8_supported=False
+            )
 
     return run
 
 
+# Determine available providers
+available_providers = ["torch-bf16", "w8a8-block-fp8-triton"]
+plot_title = "BF16 vs W8A8 Block FP8 GEMMs"
+
+if CUTLASS_BLOCK_FP8_SUPPORTED:
+    available_providers.append("w8a8-block-fp8-cutlass")
+
+
 @vllm_triton.testing.perf_report(
     vllm_triton.testing.Benchmark(
         x_names=["batch_size"],
         x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
         x_log=False,
         line_arg="provider",
-        line_vals=["torch-bf16", "w8a8-block-fp8"],
-        line_names=["torch-bf16", "w8a8-block-fp8"],
+        line_vals=available_providers,
+        line_names=available_providers,
         ylabel="TFLOP/s (larger is better)",
         plot_name="BF16 vs W8A8 Block FP8 GEMMs",
         args={},
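The Bs.T.contiguous() added in this hunk is a physical relayout, not just a view: the Triton path stores weight scales as an (n_tiles, k_tiles) row-major tensor, and, per the comment in the diff, the SM90 CUTLASS kernel wants the transposed (k_tiles, n_tiles) layout contiguous in memory. A small self-contained sketch of that transform (the dimensions are hypothetical):

import torch

N, K = 4096, 7168                       # hypothetical GEMM dims
block_n, block_k = 128, 128             # quantization block size
n_tiles = (N + block_n - 1) // block_n  # ceil-div over the output dim
k_tiles = (K + block_k - 1) // block_k  # ceil-div over the reduction dim

Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32)  # Triton-path layout
assert not Bs.T.is_contiguous()  # a plain transpose is only a strided view

# SM90 CUTLASS path: materialize the transpose so the (k_tiles, n_tiles)
# scale tensor is row-major in memory
Bs_cutlass = Bs.T.contiguous()
assert Bs_cutlass.shape == (k_tiles, n_tiles) and Bs_cutlass.is_contiguous()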
@@ -85,11 +105,22 @@ def benchmark_tflops(batch_size, provider, N, K, block_size=(128, 128)):
         ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph(
             lambda: torch.nn.functional.linear(a, b), quantiles=quantiles
         )
-    else:  # w8a8-block-fp8
-        run_w8a8 = build_w8a8_block_fp8_runner(M, N, K, block_size, device)
+    elif provider == "w8a8-block-fp8-triton":
+        run_w8a8_triton = build_w8a8_block_fp8_runner(
+            M, N, K, block_size, device, use_cutlass=False
+        )
+        ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph(
+            lambda: run_w8a8_triton(), quantiles=quantiles
+        )
+    elif provider == "w8a8-block-fp8-cutlass":
+        run_w8a8_cutlass = build_w8a8_block_fp8_runner(
+            M, N, K, block_size, device, use_cutlass=True
+        )
         ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph(
-            lambda: run_w8a8(), quantiles=quantiles
+            lambda: run_w8a8_cutlass(), quantiles=quantiles
         )
+    else:
+        raise ValueError(f"Unknown provider: {provider}")
 
     to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
     return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)
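For reference, the to_tflops conversion counts two floating-point ops (a multiply and an add) per term of each length-K dot product, so a GEMM costs 2·M·N·K FLOPs; scaling by 1e-12 and dividing by the time in seconds yields TFLOP/s. A worked instance with hypothetical numbers:

M, N, K = 1024, 4096, 7168  # hypothetical problem size
t_ms = 0.85                 # hypothetical measured latency in milliseconds

flops = 2 * M * N * K                   # one multiply + one add per MAC
tflops = flops * 1e-12 / (t_ms * 1e-3)  # ops -> tera, ms -> seconds
print(f"{tflops:.1f} TFLOP/s")          # ~70.7 for these numbers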

csrc/cutlass_extensions/gemm/collective/collective_builder.hpp

Lines changed: 0 additions & 123 deletions
This file was deleted: the custom SM90 collective-builder extension is superseded by the upstream CUTLASS builder that the kernel now uses.
