@@ -4,7 +4,10 @@
 import torch
 
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    w8a8_block_fp8_matmul,
+    apply_w8a8_block_fp8_linear,
+)
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    CUTLASS_BLOCK_FP8_SUPPORTED,
 )
 from vllm.platforms import current_platform
 from vllm.triton_utils import triton as vllm_triton
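The imports switch from the raw Triton kernel, w8a8_block_fp8_matmul, to the higher-level apply_w8a8_block_fp8_linear, which accepts unquantized bf16 activations, quantizes them on the fly, and dispatches to either the Triton or the CUTLASS block-FP8 kernel, so both backends can be benchmarked through one entry point. For orientation, here is a minimal reference of what a block-quantized W8A8 linear computes. This is an illustrative sketch in plain PyTorch, not vLLM's implementation: the name reference_block_fp8_linear is invented, and the on-the-fly activation quantization step is omitted.

    import torch

    def reference_block_fp8_linear(A, B, Bs, block_size, out_dtype=torch.bfloat16):
        # B: (N, K) float8_e4m3fn weight; Bs: one float32 scale per
        # (block_n x block_k) tile, shaped (n_tiles, k_tiles).
        block_n, block_k = block_size
        scales = Bs.repeat_interleave(block_n, dim=0)[: B.shape[0]]
        scales = scales.repeat_interleave(block_k, dim=1)[:, : B.shape[1]]
        B_deq = B.to(torch.float32) * scales  # dequantize tile by tile
        return (A.to(torch.float32) @ B_deq.T).to(out_dtype)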
@@ -29,45 +32,62 @@
 ]
 
 
-def build_w8a8_block_fp8_runner(M, N, K, block_size, device):
+def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass):
     """Build runner function for w8a8 block fp8 matmul."""
     factor_for_scale = 1e-2
 
     fp8_info = torch.finfo(torch.float8_e4m3fn)
     fp8_max, fp8_min = fp8_info.max, fp8_info.min
 
     # Create random FP8 tensors
-    A_fp32 = (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
-    A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
+    A_ref = (torch.rand(M, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max
 
-    B_fp32 = (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
-    B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
+    B_ref = (torch.rand(N, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max
+    B = B_ref.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
 
     # Create scales
     block_n, block_k = block_size[0], block_size[1]
     n_tiles = (N + block_n - 1) // block_n
     k_tiles = (K + block_k - 1) // block_k
 
-    As = torch.rand(M, k_tiles, dtype=torch.float32, device=device) * factor_for_scale
     Bs = (
         torch.rand(n_tiles, k_tiles, dtype=torch.float32, device=device)
         * factor_for_scale
     )
 
+    # SM90 CUTLASS requires row-major format for scales
+    if use_cutlass and current_platform.is_device_capability(90):
+        Bs = Bs.T.contiguous()
+
     def run():
-        return w8a8_block_fp8_matmul(A, B, As, Bs, block_size, torch.bfloat16)
+        if use_cutlass:
+            return apply_w8a8_block_fp8_linear(
+                A_ref, B, block_size, Bs, cutlass_block_fp8_supported=True
+            )
+        else:
+            return apply_w8a8_block_fp8_linear(
+                A_ref, B, block_size, Bs, cutlass_block_fp8_supported=False
+            )
 
     return run
 
 
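build_w8a8_block_fp8_runner returns a zero-argument closure because do_bench_cudagraph needs a callable it can capture and replay inside a CUDA graph; all tensor construction happens once, outside the timed region. A hypothetical invocation (the shapes are invented, and the quantiles value is assumed, since the script's quantiles variable is defined outside this diff):

    run = build_w8a8_block_fp8_runner(
        M=256, N=4096, K=7168, block_size=(128, 128), device="cuda", use_cutlass=False
    )
    ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph(
        lambda: run(), quantiles=[0.5, 0.2, 0.8]  # assumed quantiles
    )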
+# Determine available providers
+available_providers = ["torch-bf16", "w8a8-block-fp8-triton"]
+plot_title = "BF16 vs W8A8 Block FP8 GEMMs"
+
+if CUTLASS_BLOCK_FP8_SUPPORTED:
+    available_providers.append("w8a8-block-fp8-cutlass")
+
+
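Building the provider list at import time keeps the script runnable on GPUs without CUTLASS block-FP8 support: the CUTLASS line is only benchmarked when vLLM reports CUTLASS_BLOCK_FP8_SUPPORTED, and everywhere else the plot falls back to the BF16 and Triton providers. Note that plot_title does not appear to be used in the hunks shown; plot_name in the decorator below is still the hardcoded string.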
 @vllm_triton.testing.perf_report(
     vllm_triton.testing.Benchmark(
         x_names=["batch_size"],
         x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
         x_log=False,
         line_arg="provider",
-        line_vals=["torch-bf16", "w8a8-block-fp8"],
-        line_names=["torch-bf16", "w8a8-block-fp8"],
+        line_vals=available_providers,
+        line_names=available_providers,
         ylabel="TFLOP/s (larger is better)",
         plot_name="BF16 vs W8A8 Block FP8 GEMMs",
         args={},
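Per the hunk header below, the elided body of benchmark_tflops presumably maps batch_size to M and builds the bf16 tensors a and b used by the torch.nn.functional.linear baseline; this hunk only rewrites the provider dispatch. The decorated function is a triton Mark object, so the script's entry point (outside this diff) would drive it roughly like this, with invented shapes:

    # .run() sweeps x_vals (batch_size) and forwards extra kwargs to the function.
    benchmark_tflops.run(print_data=True, N=4096, K=7168, block_size=(128, 128))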
@@ -85,11 +105,22 @@ def benchmark_tflops(batch_size, provider, N, K, block_size=(128, 128)):
         ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph(
             lambda: torch.nn.functional.linear(a, b), quantiles=quantiles
         )
-    else:  # w8a8-block-fp8
-        run_w8a8 = build_w8a8_block_fp8_runner(M, N, K, block_size, device)
+    elif provider == "w8a8-block-fp8-triton":
+        run_w8a8_triton = build_w8a8_block_fp8_runner(
+            M, N, K, block_size, device, use_cutlass=False
+        )
+        ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph(
+            lambda: run_w8a8_triton(), quantiles=quantiles
+        )
+    elif provider == "w8a8-block-fp8-cutlass":
+        run_w8a8_cutlass = build_w8a8_block_fp8_runner(
+            M, N, K, block_size, device, use_cutlass=True
+        )
         ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph(
-            lambda: run_w8a8(), quantiles=quantiles
+            lambda: run_w8a8_cutlass(), quantiles=quantiles
         )
+    else:
+        raise ValueError(f"Unknown provider: {provider}")
 
     to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
     return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)
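On the returned numbers: a GEMM over (M, K) x (K, N) performs 2 * M * N * K floating-point operations (one multiply and one add per accumulated product), so to_tflops converts a latency in milliseconds to TFLOP/s. Since a smaller time means a higher rate, max_ms maps to the lower TFLOP/s bound and min_ms to the upper one, which is why the last two return values are swapped. A quick sanity check with made-up numbers:

    M = N = K = 4096
    to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
    print(round(to_tflops(1.0), 1))  # 137.4 TFLOP/s for a 1 ms GEMM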