Commit 557b2e9 (1 parent: 4e256ca)

Remove all cases of fmt: on/off (#26253)

Signed-off-by: Harry Mellor <[email protected]>

5 files changed: +217 −157 lines changed
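
Context for the diff below: "# fmt: off" is a directive honored by both Black and ruff-format. Formatting is suspended from that comment until a matching "# fmt: on"; when no "# fmt: on" follows (as in the file below, where the directive sat near the top), the rest of the file is never reformatted. A minimal sketch of the mechanism, with illustrative names that are not taken from this commit:

    # fmt: off
    identity = [
        1, 0,
        0, 1,
    ]  # hand-aligned layout the formatter must leave alone
    # fmt: on

    # From here on, the formatter normalizes layout again.
    row = [1, 0]

Removing the directive puts the file back under formatter control, which is why the rest of this diff is mechanical reformatting (paren-aligned, yapf-era layout rewritten into ruff-format/Black style) with no behavior change.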

benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py

Lines changed: 117 additions & 109 deletions
@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-# fmt: off
 # ruff: noqa: E501
 import time
 
@@ -20,19 +19,21 @@
 )
 
 
-def benchmark_shape(m: int,
-                    n: int,
-                    k: int,
-                    warmup: int = 100,
-                    repeat: int = 10000,
-                    verbose: bool = False) -> dict:
+def benchmark_shape(
+    m: int,
+    n: int,
+    k: int,
+    warmup: int = 100,
+    repeat: int = 10000,
+    verbose: bool = False,
+) -> dict:
     """Benchmark all implementations for a specific (m, n, k) shape."""
     if verbose:
         print(f"\n=== Benchmarking shape: m={m}, n={n}, k={k} ===")
 
     # Create test tensors
-    A = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
-    B = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
+    A = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
+    B = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
 
     # Reference result in BF16
     torch.cuda.synchronize()
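
Two formatter conventions show up in the hunk above: quote normalization ('cuda' becomes "cuda") and the trailing comma added after verbose: bool = False, the "magic trailing comma" that both Black and ruff-format honor. A small sketch with a hypothetical function, not from this file: with the trailing comma present, the formatter keeps one parameter per line even when the signature would fit on a single line.

    def scale(
        x: float,
        factor: float = 2.0,  # magic trailing comma keeps the expanded form
    ) -> float:
        return x * factor

    # Without that comma, the formatter would collapse the signature to:
    # def scale(x: float, factor: float = 2.0) -> float:
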
@@ -49,34 +50,39 @@ def benchmark_shape(m: int,
     # Pre-quantize A for all implementations
     A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8(A, block_size[1])
     A_scale_deepgemm = get_col_major_tma_aligned_tensor(A_scale_deepgemm)
-    C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
+    C_deepgemm = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
     A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1])
     A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8(
-        A, block_size[1], column_major_scales=True)
+        A, block_size[1], column_major_scales=True
+    )
 
     # === DeepGEMM Implementation ===
     def deepgemm_gemm():
-        fp8_gemm_nt((A_deepgemm, A_scale_deepgemm),
-                    (B_deepgemm, B_scale_deepgemm),
-                    C_deepgemm)
+        fp8_gemm_nt(
+            (A_deepgemm, A_scale_deepgemm), (B_deepgemm, B_scale_deepgemm), C_deepgemm
+        )
         return C_deepgemm
 
     # === vLLM Triton Implementation ===
     def vllm_triton_gemm():
-        return w8a8_triton_block_scaled_mm(A_vllm,
-                                           B_vllm,
-                                           A_scale_vllm,
-                                           B_scale_vllm,
-                                           block_size,
-                                           output_dtype=torch.bfloat16)
+        return w8a8_triton_block_scaled_mm(
+            A_vllm,
+            B_vllm,
+            A_scale_vllm,
+            B_scale_vllm,
+            block_size,
+            output_dtype=torch.bfloat16,
+        )
 
     # === vLLM CUTLASS Implementation ===
     def vllm_cutlass_gemm():
-        return ops.cutlass_scaled_mm(A_vllm_cutlass,
-                                     B_vllm.T,
-                                     scale_a=A_scale_vllm_cutlass,
-                                     scale_b=B_scale_vllm.T,
-                                     out_dtype=torch.bfloat16)
+        return ops.cutlass_scaled_mm(
+            A_vllm_cutlass,
+            B_vllm.T,
+            scale_a=A_scale_vllm_cutlass,
+            scale_b=B_scale_vllm.T,
+            out_dtype=torch.bfloat16,
+        )
 
     # Run correctness check first
     if verbose:
@@ -93,26 +99,23 @@ def vllm_cutlass_gemm():
         print(f"DeepGEMM vs Reference difference: {deepgemm_diff:.6f}")
         print(f"vLLM Triton vs Reference difference: {vllm_triton_diff:.6f}")
         print(f"vLLM CUTLASS vs Reference difference: {vllm_cutlass_diff:.6f}")
-        print("vLLM Triton vs DeepGEMM difference: "
-              f"{calc_diff(C_vllm_triton, C_deepgemm):.6f}")
-        print("vLLM CUTLASS vs DeepGEMM difference: "
-              f"{calc_diff(C_vllm_cutlass, C_deepgemm):.6f}")
+        print(
+            "vLLM Triton vs DeepGEMM difference: "
+            f"{calc_diff(C_vllm_triton, C_deepgemm):.6f}"
+        )
+        print(
+            "vLLM CUTLASS vs DeepGEMM difference: "
+            f"{calc_diff(C_vllm_cutlass, C_deepgemm):.6f}"
+        )
 
     # Benchmark implementations
     implementations = {
         "DeepGEMM": deepgemm_gemm,
         "vLLM Triton": vllm_triton_gemm,
-        "vLLM CUTLASS": vllm_cutlass_gemm
+        "vLLM CUTLASS": vllm_cutlass_gemm,
     }
 
-    benchmark_results = {
-        "shape": {
-            "m": m,
-            "n": n,
-            "k": k
-        },
-        "implementations": {}
-    }
+    benchmark_results = {"shape": {"m": m, "n": n, "k": k}, "implementations": {}}
 
     for name, func in implementations.items():
         # Warmup
@@ -140,38 +143,36 @@ def vllm_cutlass_gemm():
             "tflops": tflops,
             "gb_s": gb_s,
             "diff": {
-                "DeepGEMM":
-                0.0 if name == "DeepGEMM" else calc_diff(func(), C_deepgemm),
-                "Reference":
-                deepgemm_diff if name == "DeepGEMM" else
-                (vllm_triton_diff
-                 if name == "vLLM Triton" else vllm_cutlass_diff)
-            }
+                "DeepGEMM": 0.0
+                if name == "DeepGEMM"
+                else calc_diff(func(), C_deepgemm),
+                "Reference": deepgemm_diff
+                if name == "DeepGEMM"
+                else (vllm_triton_diff if name == "vLLM Triton" else vllm_cutlass_diff),
+            },
         }
 
         if verbose:
-            print(
-                f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS, {gb_s:.2f} GB/s"
-            )
+            print(f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS, {gb_s:.2f} GB/s")
 
     # Calculate speedups
     baseline = benchmark_results["implementations"]["DeepGEMM"]["time_ms"]
     for name, data in benchmark_results["implementations"].items():
         if name != "DeepGEMM":
             speedup = baseline / data["time_ms"]
-            benchmark_results["implementations"][name][
-                "speedup_vs_deepgemm"] = speedup
+            benchmark_results["implementations"][name]["speedup_vs_deepgemm"] = speedup
             if verbose:
-                print(f"DeepGEMM is {1/speedup:.2f}x "
-                      f"{'faster' if 1/speedup > 1 else 'slower'} than {name}")
+                print(
+                    f"DeepGEMM is {1 / speedup:.2f}x "
+                    f"{'faster' if 1 / speedup > 1 else 'slower'} than {name}"
+                )
 
-    vllm_triton_time = benchmark_results["implementations"]["vLLM Triton"][
-        "time_ms"]
-    vllm_cutlass_time = benchmark_results["implementations"]["vLLM CUTLASS"][
-        "time_ms"]
+    vllm_triton_time = benchmark_results["implementations"]["vLLM Triton"]["time_ms"]
+    vllm_cutlass_time = benchmark_results["implementations"]["vLLM CUTLASS"]["time_ms"]
     cutlass_vs_triton = vllm_triton_time / vllm_cutlass_time
-    benchmark_results["implementations"]["vLLM CUTLASS"][
-        "speedup_vs_triton"] = cutlass_vs_triton
+    benchmark_results["implementations"]["vLLM CUTLASS"]["speedup_vs_triton"] = (
+        cutlass_vs_triton
+    )
     if verbose:
         print(
             f"vLLM CUTLASS is {cutlass_vs_triton:.2f}x "
@@ -183,8 +184,7 @@ def vllm_cutlass_gemm():
 
 def format_table_row(values, widths):
     """Format a row with specified column widths."""
-    return "| " + " | ".join(f"{val:{w}}"
-                             for val, w in zip(values, widths)) + " |"
+    return "| " + " | ".join(f"{val:{w}}" for val, w in zip(values, widths)) + " |"
 
 
 def print_table(headers, rows, title=None):
@@ -292,67 +292,78 @@ def run_benchmarks(verbose: bool = False):
     for result in all_results:
         shape = result["shape"]
         impl_data = result["implementations"]["DeepGEMM"]
-        deepgemm_rows.append([
-            shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}",
-            f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}"
-        ])
+        deepgemm_rows.append(
+            [
+                shape["m"],
+                shape["n"],
+                shape["k"],
+                f"{impl_data['time_us']:.1f}",
+                f"{impl_data['tflops']:.1f}",
+                f"{impl_data['gb_s']:.1f}",
+            ]
+        )
 
-    print_table(deepgemm_headers,
-                deepgemm_rows,
-                title="DeepGEMM Implementation:")
+    print_table(deepgemm_headers, deepgemm_rows, title="DeepGEMM Implementation:")
 
     # Print vLLM Triton table
-    triton_headers = [
-        "m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM"
-    ]
+    triton_headers = ["m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM"]
     triton_rows = []
     for result in all_results:
         shape = result["shape"]
         impl_data = result["implementations"]["vLLM Triton"]
         speedup = impl_data.get("speedup_vs_deepgemm", 1.0)
-        triton_rows.append([
-            shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}",
-            f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}",
-            format_speedup(speedup)
-        ])
+        triton_rows.append(
+            [
+                shape["m"],
+                shape["n"],
+                shape["k"],
+                f"{impl_data['time_us']:.1f}",
+                f"{impl_data['tflops']:.1f}",
+                f"{impl_data['gb_s']:.1f}",
+                format_speedup(speedup),
+            ]
+        )
 
-    print_table(triton_headers,
-                triton_rows,
-                title="vLLM Triton Implementation:")
+    print_table(triton_headers, triton_rows, title="vLLM Triton Implementation:")
 
     # Print vLLM CUTLASS table
     cutlass_headers = [
-        "m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM",
-        "vs Triton"
+        "m",
+        "n",
+        "k",
+        "Time (μs)",
+        "TFLOPS",
+        "GB/s",
+        "vs DeepGEMM",
+        "vs Triton",
     ]
     cutlass_rows = []
     for result in all_results:
         shape = result["shape"]
         impl_data = result["implementations"]["vLLM CUTLASS"]
         vs_deepgemm = impl_data.get("speedup_vs_deepgemm", 1.0)
         vs_triton = impl_data.get("speedup_vs_triton", 1.0)
-        cutlass_rows.append([
-            shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}",
-            f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}",
-            format_speedup(vs_deepgemm),
-            format_speedup(vs_triton)
-        ])
+        cutlass_rows.append(
+            [
+                shape["m"],
+                shape["n"],
+                shape["k"],
+                f"{impl_data['time_us']:.1f}",
+                f"{impl_data['tflops']:.1f}",
+                f"{impl_data['gb_s']:.1f}",
+                format_speedup(vs_deepgemm),
+                format_speedup(vs_triton),
+            ]
+        )
 
-    print_table(cutlass_headers,
-                cutlass_rows,
-                title="vLLM CUTLASS Implementation:")
+    print_table(cutlass_headers, cutlass_rows, title="vLLM CUTLASS Implementation:")
 
     # Calculate and print averages
     print("\n===== AVERAGE PERFORMANCE =====")
 
     implementations = ["DeepGEMM", "vLLM Triton", "vLLM CUTLASS"]
     avg_metrics = {
-        impl: {
-            "tflops": 0,
-            "gb_s": 0,
-            "time_ms": 0
-        }
-        for impl in implementations
+        impl: {"tflops": 0, "gb_s": 0, "time_ms": 0} for impl in implementations
     }
 
     for result in all_results:
@@ -370,31 +381,29 @@ def run_benchmarks(verbose: bool = False):
         avg_tflops = avg_metrics[impl]["tflops"] / num_shapes
         avg_mem_bw = avg_metrics[impl]["gb_s"] / num_shapes
         avg_time = avg_metrics[impl]["time_ms"] / num_shapes
-        avg_rows.append([
-            impl, f"{avg_tflops:.2f}", f"{avg_mem_bw:.2f}", f"{avg_time:.2f}"
-        ])
+        avg_rows.append(
+            [impl, f"{avg_tflops:.2f}", f"{avg_mem_bw:.2f}", f"{avg_time:.2f}"]
+        )
 
     print_table(avg_headers, avg_rows)
 
     # Calculate average speedups
     avg_speedups = {
         "DeepGEMM vs vLLM Triton": 0,
         "DeepGEMM vs vLLM CUTLASS": 0,
-        "vLLM CUTLASS vs vLLM Triton": 0
+        "vLLM CUTLASS vs vLLM Triton": 0,
     }
 
     for result in all_results:
         deepgemm_time = result["implementations"]["DeepGEMM"]["time_ms"]
         vllm_triton_time = result["implementations"]["vLLM Triton"]["time_ms"]
-        vllm_cutlass_time = result["implementations"]["vLLM CUTLASS"][
-            "time_ms"]
+        vllm_cutlass_time = result["implementations"]["vLLM CUTLASS"]["time_ms"]
 
-        avg_speedups[
-            "DeepGEMM vs vLLM Triton"] += vllm_triton_time / deepgemm_time
-        avg_speedups[
-            "DeepGEMM vs vLLM CUTLASS"] += vllm_cutlass_time / deepgemm_time
-        avg_speedups[
-            "vLLM CUTLASS vs vLLM Triton"] += vllm_triton_time / vllm_cutlass_time
+        avg_speedups["DeepGEMM vs vLLM Triton"] += vllm_triton_time / deepgemm_time
+        avg_speedups["DeepGEMM vs vLLM CUTLASS"] += vllm_cutlass_time / deepgemm_time
+        avg_speedups["vLLM CUTLASS vs vLLM Triton"] += (
+            vllm_triton_time / vllm_cutlass_time
+        )
 
     print("\n===== AVERAGE SPEEDUPS =====")
     speedup_headers = ["Comparison", "Speedup"]
@@ -412,8 +421,7 @@ def run_benchmarks(verbose: bool = False):
 
     for result in all_results:
         for impl in implementations:
-            avg_diff[impl] += result["implementations"][impl]["diff"][
-                "Reference"]
+            avg_diff[impl] += result["implementations"][impl]["diff"]["Reference"]
 
     diff_headers = ["Implementation", "Avg Diff vs Reference"]
     diff_rows = []
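
A quick way to confirm that a commit like this is formatting-only is to compare the parsed AST of each file before and after; since ast.parse discards comments, deleting "# fmt: off" itself also leaves the AST unchanged. A hypothetical spot-check, not part of the commit:

    import ast

    def same_ast(old_src: str, new_src: str) -> bool:
        """True when both sources parse to the identical AST,
        i.e. the change affected layout but not behavior."""
        return ast.dump(ast.parse(old_src)) == ast.dump(ast.parse(new_src))

    # e.g. read the file at parent 4e256ca and at 557b2e9,
    # then assert same_ast(old, new)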
