@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-# fmt: off
 # ruff: noqa: E501
 import time
 
@@ -20,19 +19,21 @@
 )
 
 
-def benchmark_shape(m: int,
-                    n: int,
-                    k: int,
-                    warmup: int = 100,
-                    repeat: int = 10000,
-                    verbose: bool = False) -> dict:
+def benchmark_shape(
+    m: int,
+    n: int,
+    k: int,
+    warmup: int = 100,
+    repeat: int = 10000,
+    verbose: bool = False,
+) -> dict:
     """Benchmark all implementations for a specific (m, n, k) shape."""
     if verbose:
         print(f"\n=== Benchmarking shape: m={m}, n={n}, k={k} ===")
 
     # Create test tensors
-    A = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
-    B = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
+    A = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
+    B = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
 
     # Reference result in BF16
     torch.cuda.synchronize()
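The unchanged body of `benchmark_shape` between these hunks is collapsed in the diff. Given the `warmup` and `repeat` parameters in the signature, the elided timing code presumably follows the usual CUDA event pattern; a minimal sketch under that assumption (not the script's actual code):

```python
import torch

def time_gpu(fn, warmup: int = 100, repeat: int = 10000) -> float:
    """Average milliseconds per call of `fn`, measured with CUDA events."""
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(repeat):
        fn()
    end.record()
    torch.cuda.synchronize()  # wait until the end event has been recorded
    return start.elapsed_time(end) / repeat
```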
@@ -49,34 +50,39 @@ def benchmark_shape(m: int,
     # Pre-quantize A for all implementations
     A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8(A, block_size[1])
     A_scale_deepgemm = get_col_major_tma_aligned_tensor(A_scale_deepgemm)
-    C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
+    C_deepgemm = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
     A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1])
     A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8(
-        A, block_size[1], column_major_scales=True)
+        A, block_size[1], column_major_scales=True
+    )
 
     # === DeepGEMM Implementation ===
     def deepgemm_gemm():
-        fp8_gemm_nt((A_deepgemm, A_scale_deepgemm),
-                    (B_deepgemm, B_scale_deepgemm),
-                    C_deepgemm)
+        fp8_gemm_nt(
+            (A_deepgemm, A_scale_deepgemm), (B_deepgemm, B_scale_deepgemm), C_deepgemm
+        )
         return C_deepgemm
 
     # === vLLM Triton Implementation ===
     def vllm_triton_gemm():
-        return w8a8_triton_block_scaled_mm(A_vllm,
-                                           B_vllm,
-                                           A_scale_vllm,
-                                           B_scale_vllm,
-                                           block_size,
-                                           output_dtype=torch.bfloat16)
+        return w8a8_triton_block_scaled_mm(
+            A_vllm,
+            B_vllm,
+            A_scale_vllm,
+            B_scale_vllm,
+            block_size,
+            output_dtype=torch.bfloat16,
+        )
 
     # === vLLM CUTLASS Implementation ===
     def vllm_cutlass_gemm():
-        return ops.cutlass_scaled_mm(A_vllm_cutlass,
-                                     B_vllm.T,
-                                     scale_a=A_scale_vllm_cutlass,
-                                     scale_b=B_scale_vllm.T,
-                                     out_dtype=torch.bfloat16)
+        return ops.cutlass_scaled_mm(
+            A_vllm_cutlass,
+            B_vllm.T,
+            scale_a=A_scale_vllm_cutlass,
+            scale_b=B_scale_vllm.T,
+            out_dtype=torch.bfloat16,
+        )
 
     # Run correctness check first
     if verbose:
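All three wrappers above consume the same per-token-group FP8 quantization of `A` (with `B` quantized per block ahead of time). For orientation, here is a minimal sketch of what such a group quantizer computes; the helper name and return shapes are illustrative assumptions, not vLLM's actual `per_token_group_quant_fp8`:

```python
import torch

def per_token_group_quant_fp8_sketch(x: torch.Tensor, group_size: int):
    """Quantize to FP8 e4m3 with one scale per `group_size` elements per row."""
    finfo = torch.finfo(torch.float8_e4m3fn)  # e4m3 max magnitude is 448
    m, k = x.shape
    groups = x.reshape(m, k // group_size, group_size).to(torch.float32)
    amax = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = amax / finfo.max                  # one FP32 scale per group
    q = (groups / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return q.reshape(m, k), scale.squeeze(-1)  # (m, k) FP8 data, (m, k // group_size) scales
```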
@@ -93,26 +99,23 @@ def vllm_cutlass_gemm():
         print(f"DeepGEMM vs Reference difference: {deepgemm_diff:.6f}")
         print(f"vLLM Triton vs Reference difference: {vllm_triton_diff:.6f}")
         print(f"vLLM CUTLASS vs Reference difference: {vllm_cutlass_diff:.6f}")
-        print("vLLM Triton vs DeepGEMM difference: "
-              f"{calc_diff(C_vllm_triton, C_deepgemm):.6f}")
-        print("vLLM CUTLASS vs DeepGEMM difference: "
-              f"{calc_diff(C_vllm_cutlass, C_deepgemm):.6f}")
+        print(
+            "vLLM Triton vs DeepGEMM difference: "
+            f"{calc_diff(C_vllm_triton, C_deepgemm):.6f}"
+        )
+        print(
+            "vLLM CUTLASS vs DeepGEMM difference: "
+            f"{calc_diff(C_vllm_cutlass, C_deepgemm):.6f}"
+        )
 
     # Benchmark implementations
     implementations = {
         "DeepGEMM": deepgemm_gemm,
         "vLLM Triton": vllm_triton_gemm,
-        "vLLM CUTLASS": vllm_cutlass_gemm
+        "vLLM CUTLASS": vllm_cutlass_gemm,
     }
 
-    benchmark_results = {
-        "shape": {
-            "m": m,
-            "n": n,
-            "k": k
-        },
-        "implementations": {}
-    }
+    benchmark_results = {"shape": {"m": m, "n": n, "k": k}, "implementations": {}}
 
     for name, func in implementations.items():
         # Warmup
@@ -140,38 +143,36 @@ def vllm_cutlass_gemm():
             "tflops": tflops,
             "gb_s": gb_s,
             "diff": {
-                "DeepGEMM":
-                0.0 if name == "DeepGEMM" else calc_diff(func(), C_deepgemm),
-                "Reference":
-                deepgemm_diff if name == "DeepGEMM" else
-                (vllm_triton_diff
-                 if name == "vLLM Triton" else vllm_cutlass_diff)
-            }
+                "DeepGEMM": 0.0
+                if name == "DeepGEMM"
+                else calc_diff(func(), C_deepgemm),
+                "Reference": deepgemm_diff
+                if name == "DeepGEMM"
+                else (vllm_triton_diff if name == "vLLM Triton" else vllm_cutlass_diff),
+            },
         }
 
         if verbose:
-            print(
-                f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS, {gb_s:.2f} GB/s"
-            )
+            print(f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS, {gb_s:.2f} GB/s")
 
     # Calculate speedups
     baseline = benchmark_results["implementations"]["DeepGEMM"]["time_ms"]
     for name, data in benchmark_results["implementations"].items():
         if name != "DeepGEMM":
             speedup = baseline / data["time_ms"]
-            benchmark_results["implementations"][name][
-                "speedup_vs_deepgemm"] = speedup
+            benchmark_results["implementations"][name]["speedup_vs_deepgemm"] = speedup
             if verbose:
-                print(f"DeepGEMM is {1 / speedup:.2f}x "
-                      f"{'faster' if 1 / speedup > 1 else 'slower'} than {name}")
+                print(
+                    f"DeepGEMM is {1 / speedup:.2f}x "
+                    f"{'faster' if 1 / speedup > 1 else 'slower'} than {name}"
+                )
 
-    vllm_triton_time = benchmark_results["implementations"]["vLLM Triton"][
-        "time_ms"]
-    vllm_cutlass_time = benchmark_results["implementations"]["vLLM CUTLASS"][
-        "time_ms"]
+    vllm_triton_time = benchmark_results["implementations"]["vLLM Triton"]["time_ms"]
+    vllm_cutlass_time = benchmark_results["implementations"]["vLLM CUTLASS"]["time_ms"]
     cutlass_vs_triton = vllm_triton_time / vllm_cutlass_time
-    benchmark_results["implementations"]["vLLM CUTLASS"][
-        "speedup_vs_triton"] = cutlass_vs_triton
+    benchmark_results["implementations"]["vLLM CUTLASS"]["speedup_vs_triton"] = (
+        cutlass_vs_triton
+    )
     if verbose:
         print(
             f"vLLM CUTLASS is {cutlass_vs_triton:.2f}x "
@@ -183,8 +184,7 @@ def vllm_cutlass_gemm():
 
 def format_table_row(values, widths):
     """Format a row with specified column widths."""
-    return "| " + " | ".join(f"{val:{w}}"
-                             for val, w in zip(values, widths)) + " |"
+    return "| " + " | ".join(f"{val:{w}}" for val, w in zip(values, widths)) + " |"
 
 
 def print_table(headers, rows, title=None):
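For reference, `format_table_row` pads each cell to its column width so rows line up; for example:

```python
print(format_table_row(["DeepGEMM", "1.23"], [12, 8]))
# -> | DeepGEMM     | 1.23     |
```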
@@ -292,67 +292,78 @@ def run_benchmarks(verbose: bool = False):
     for result in all_results:
         shape = result["shape"]
         impl_data = result["implementations"]["DeepGEMM"]
-        deepgemm_rows.append([
-            shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}",
-            f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}"
-        ])
+        deepgemm_rows.append(
+            [
+                shape["m"],
+                shape["n"],
+                shape["k"],
+                f"{impl_data['time_us']:.1f}",
+                f"{impl_data['tflops']:.1f}",
+                f"{impl_data['gb_s']:.1f}",
+            ]
+        )
 
-    print_table(deepgemm_headers,
-                deepgemm_rows,
-                title="DeepGEMM Implementation:")
+    print_table(deepgemm_headers, deepgemm_rows, title="DeepGEMM Implementation:")
 
     # Print vLLM Triton table
-    triton_headers = [
-        "m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM"
-    ]
+    triton_headers = ["m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM"]
     triton_rows = []
     for result in all_results:
         shape = result["shape"]
         impl_data = result["implementations"]["vLLM Triton"]
         speedup = impl_data.get("speedup_vs_deepgemm", 1.0)
-        triton_rows.append([
-            shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}",
-            f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}",
-            format_speedup(speedup)
-        ])
+        triton_rows.append(
+            [
+                shape["m"],
+                shape["n"],
+                shape["k"],
+                f"{impl_data['time_us']:.1f}",
+                f"{impl_data['tflops']:.1f}",
+                f"{impl_data['gb_s']:.1f}",
+                format_speedup(speedup),
+            ]
+        )
 
-    print_table(triton_headers,
-                triton_rows,
-                title="vLLM Triton Implementation:")
+    print_table(triton_headers, triton_rows, title="vLLM Triton Implementation:")
 
     # Print vLLM CUTLASS table
     cutlass_headers = [
-        "m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM",
-        "vs Triton"
+        "m",
+        "n",
+        "k",
+        "Time (μs)",
+        "TFLOPS",
+        "GB/s",
+        "vs DeepGEMM",
+        "vs Triton",
     ]
     cutlass_rows = []
     for result in all_results:
         shape = result["shape"]
         impl_data = result["implementations"]["vLLM CUTLASS"]
         vs_deepgemm = impl_data.get("speedup_vs_deepgemm", 1.0)
         vs_triton = impl_data.get("speedup_vs_triton", 1.0)
-        cutlass_rows.append([
-            shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}",
-            f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}",
-            format_speedup(vs_deepgemm),
-            format_speedup(vs_triton)
-        ])
+        cutlass_rows.append(
+            [
+                shape["m"],
+                shape["n"],
+                shape["k"],
+                f"{impl_data['time_us']:.1f}",
+                f"{impl_data['tflops']:.1f}",
+                f"{impl_data['gb_s']:.1f}",
+                format_speedup(vs_deepgemm),
+                format_speedup(vs_triton),
+            ]
+        )
 
-    print_table(cutlass_headers,
-                cutlass_rows,
-                title="vLLM CUTLASS Implementation:")
+    print_table(cutlass_headers, cutlass_rows, title="vLLM CUTLASS Implementation:")
 
     # Calculate and print averages
     print("\n===== AVERAGE PERFORMANCE =====")
 
     implementations = ["DeepGEMM", "vLLM Triton", "vLLM CUTLASS"]
     avg_metrics = {
-        impl: {
-            "tflops": 0,
-            "gb_s": 0,
-            "time_ms": 0
-        }
-        for impl in implementations
+        impl: {"tflops": 0, "gb_s": 0, "time_ms": 0} for impl in implementations
     }
 
     for result in all_results:
@@ -370,31 +381,29 @@ def run_benchmarks(verbose: bool = False):
         avg_tflops = avg_metrics[impl]["tflops"] / num_shapes
         avg_mem_bw = avg_metrics[impl]["gb_s"] / num_shapes
         avg_time = avg_metrics[impl]["time_ms"] / num_shapes
-        avg_rows.append([
-            impl, f"{avg_tflops:.2f}", f"{avg_mem_bw:.2f}", f"{avg_time:.2f}"
-        ])
+        avg_rows.append(
+            [impl, f"{avg_tflops:.2f}", f"{avg_mem_bw:.2f}", f"{avg_time:.2f}"]
+        )
 
     print_table(avg_headers, avg_rows)
 
     # Calculate average speedups
     avg_speedups = {
         "DeepGEMM vs vLLM Triton": 0,
         "DeepGEMM vs vLLM CUTLASS": 0,
-        "vLLM CUTLASS vs vLLM Triton": 0
+        "vLLM CUTLASS vs vLLM Triton": 0,
    }
 
     for result in all_results:
         deepgemm_time = result["implementations"]["DeepGEMM"]["time_ms"]
         vllm_triton_time = result["implementations"]["vLLM Triton"]["time_ms"]
-        vllm_cutlass_time = result["implementations"]["vLLM CUTLASS"][
-            "time_ms"]
+        vllm_cutlass_time = result["implementations"]["vLLM CUTLASS"]["time_ms"]
 
-        avg_speedups[
-            "DeepGEMM vs vLLM Triton"] += vllm_triton_time / deepgemm_time
-        avg_speedups[
-            "DeepGEMM vs vLLM CUTLASS"] += vllm_cutlass_time / deepgemm_time
-        avg_speedups[
-            "vLLM CUTLASS vs vLLM Triton"] += vllm_triton_time / vllm_cutlass_time
+        avg_speedups["DeepGEMM vs vLLM Triton"] += vllm_triton_time / deepgemm_time
+        avg_speedups["DeepGEMM vs vLLM CUTLASS"] += vllm_cutlass_time / deepgemm_time
+        avg_speedups["vLLM CUTLASS vs vLLM Triton"] += (
+            vllm_triton_time / vllm_cutlass_time
+        )
 
     print("\n===== AVERAGE SPEEDUPS =====")
     speedup_headers = ["Comparison", "Speedup"]
@@ -412,8 +421,7 @@ def run_benchmarks(verbose: bool = False):
 
     for result in all_results:
         for impl in implementations:
-            avg_diff[impl] += result["implementations"][impl]["diff"][
-                "Reference"]
+            avg_diff[impl] += result["implementations"][impl]["diff"]["Reference"]
 
     diff_headers = ["Implementation", "Avg Diff vs Reference"]
     diff_rows = []
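The per-shape `tflops` and `gb_s` columns reported by these tables are computed in the elided timing section; presumably they follow the standard GEMM formulas, sketched here under that assumption:

```python
def gemm_metrics_sketch(m: int, n: int, k: int, avg_time_ms: float):
    """Approximate TFLOPS and GB/s for an FP8 GEMM with BF16 output."""
    seconds = avg_time_ms * 1e-3
    tflops = 2 * m * n * k / seconds / 1e12  # multiply + add per output element
    # FP8 A (m*k) and B (n*k) read at 1 byte/elem; BF16 C (m*n) written at 2.
    gb_s = (m * k + n * k + 2 * m * n) / seconds / 1e9
    return tflops, gb_s
```

For example, m = n = k = 4096 at 0.5 ms averages out to roughly 275 TFLOPS.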