Commit 0271c2f

[Test] Add Benchmark and Unit Test for per_token_group_quant (#21860)
Signed-off-by: yewentao256 <[email protected]>
1 parent e91d3c9 commit 0271c2f

2 files changed: +189 −1 lines changed
Lines changed: 159 additions & 0 deletions
@@ -0,0 +1,159 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse
import math
from contextlib import contextmanager
from typing import Callable
from unittest.mock import patch

import torch

from vllm.model_executor.layers.quantization.utils import fp8_utils, int8_utils
from vllm.platforms import current_platform


@contextmanager
def _triton_mode():
    """Temporarily force the Triton fallback path"""
    with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
        yield


def _time_cuda(
    fn: Callable[[], tuple[torch.Tensor, torch.Tensor]],
    warmup_iters: int,
    bench_iters: int,
) -> float:
    # warmup
    for _ in range(warmup_iters):
        fn()
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    for _ in range(bench_iters):
        fn()
    end.record()
    torch.cuda.synchronize()

    return start.elapsed_time(end) / bench_iters  # ms/iter


def _run_single(
    shape: tuple[int, int],
    group_size: int,
    dtype: str,
    *,
    column_major: bool = False,
    scale_ue8m0: bool = False,
    warmup_iters: int,
    bench_iters: int,
) -> None:
    num_tokens, hidden_dim = shape

    device = torch.device("cuda")
    torch.manual_seed(42)
    x = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.bfloat16) * 8

    if dtype == "fp8":

        def cuda_impl():
            return fp8_utils.per_token_group_quant_fp8(
                x,
                group_size,
                column_major_scales=column_major,
                use_ue8m0=scale_ue8m0,
            )

        def triton_impl():
            with _triton_mode():
                return fp8_utils.per_token_group_quant_fp8(
                    x,
                    group_size,
                    column_major_scales=column_major,
                    use_ue8m0=scale_ue8m0,
                )
    elif dtype == "int8":

        def cuda_impl():
            return int8_utils.per_token_group_quant_int8(x, group_size)

        def triton_impl():
            with _triton_mode():
                return int8_utils.per_token_group_quant_int8(x, group_size)
    else:
        raise ValueError("dtype must be 'fp8' or 'int8'")

    cuda_ms = _time_cuda(cuda_impl, warmup_iters, bench_iters)
    triton_ms = _time_cuda(triton_impl, warmup_iters, bench_iters)

    speedup = triton_ms / cuda_ms if cuda_ms else math.inf

    cfg_desc = (
        f"shape={shape} gs={group_size:<3} col_major={column_major:<5} "
        f"ue8m0={scale_ue8m0:<5} dtype={dtype}"
    )
    print(
        f"{cfg_desc:55} | CUDA {cuda_ms:7.3f} ms | Triton {triton_ms:7.3f} ms | "
        f"speed-up ×{speedup:5.2f}"
    )


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--warmup-iters", type=int, default=10)
    parser.add_argument("--bench-iters", type=int, default=100)
    parser.add_argument("--dtype", choices=["fp8", "int8", "both"], default="both")
    return parser.parse_args()


if __name__ == "__main__":
    if not current_platform.is_cuda():
        raise RuntimeError("CUDA device is required to run this benchmark.")

    args = parse_args()
    warmup_iters, bench_iters = args.warmup_iters, args.bench_iters

    shapes = [(32, 128), (64, 256), (16, 512)]
    group_sizes = [64, 128]

    dtypes = ["fp8", "int8"] if args.dtype == "both" else [args.dtype]

    header = (
        "Configuration".ljust(55)
        + " | "
        + "CUDA (ms)".center(12)
        + " | "
        + "Triton (ms)".center(13)
        + " | "
        + "Speed-up"
    )
    print(header)
    print("-" * len(header))

    for dtype in dtypes:
        for shape in shapes:
            for gs in group_sizes:
                if dtype == "fp8":
                    for col_major in (False, True):
                        for ue8m0 in (False, True):
                            _run_single(
                                shape,
                                gs,
                                dtype,
                                column_major=col_major,
                                scale_ue8m0=ue8m0,
                                warmup_iters=warmup_iters,
                                bench_iters=bench_iters,
                            )
                else:  # INT8 has no col-major / ue8m0 switches
                    _run_single(
                        shape,
                        gs,
                        dtype,
                        warmup_iters=warmup_iters,
                        bench_iters=bench_iters,
                    )
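
For orientation: per-token-group quantization splits each token's hidden dimension into contiguous groups of group_size elements and gives each group its own scale; the benchmark above times the fused CUDA kernel against the Triton fallback for exactly this operation. The sketch below is a plain-PyTorch illustration of the FP8 variant, not the vLLM kernel: the 448.0 range (float8_e4m3fn max) and the amax-based scaling are assumptions about the usual formulation, and the column_major_scales / use_ue8m0 options exercised above are omitted here.

# Hypothetical reference (not the vLLM kernel): per-token-group FP8 quantization
# in plain PyTorch, assuming scale = group amax / 448.0 (float8_e4m3fn max).
import torch


def per_token_group_quant_fp8_ref(x: torch.Tensor, group_size: int):
    num_tokens, hidden_dim = x.shape
    assert hidden_dim % group_size == 0
    fp8_max = 448.0
    groups = x.float().reshape(num_tokens, hidden_dim // group_size, group_size)
    # One scale per (token, group), derived from the group's absolute maximum.
    amax = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10)
    scales = amax / fp8_max
    q = (groups / scales).clamp(-fp8_max, fp8_max).reshape(num_tokens, hidden_dim)
    return q.to(torch.float8_e4m3fn), scales.squeeze(-1)

As written, the script sweeps the shapes and group sizes hard-coded in __main__; --warmup-iters, --bench-iters, and --dtype are the only command-line knobs.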

tests/kernels/quantization/test_per_token_group_quant.py

Lines changed: 30 additions & 1 deletion
@@ -5,7 +5,7 @@
 import pytest
 import torch
 
-from vllm.model_executor.layers.quantization.utils import fp8_utils
+from vllm.model_executor.layers.quantization.utils import fp8_utils, int8_utils
 
 
 @pytest.mark.parametrize("shape", [(32, 128), (64, 256), (16, 512)])
@@ -42,3 +42,32 @@ def test_per_token_group_quant_fp8(shape, column_major: bool,
 
     assert torch.allclose(out_q.float(), ref_q.float(), atol=0.15, rtol=0.15)
     assert torch.allclose(scale, ref_s, atol=0.01, rtol=0.01)
+
+
+@pytest.mark.parametrize("shape", [(32, 128), (64, 256), (16, 512)])
+@pytest.mark.parametrize("group_size", [64, 128])
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_per_token_group_quant_int8(shape, group_size: int):
+    device = "cuda"
+
+    torch.manual_seed(42)
+    num_tokens, hidden_dim = shape
+
+    x = (torch.randn(
+        (num_tokens, hidden_dim), device=device, dtype=torch.bfloat16) * 8)
+
+    # cuda path
+    out_q, scale = int8_utils.per_token_group_quant_int8(
+        x,
+        group_size,
+    )
+
+    # triton ref
+    with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
+        ref_q, ref_s = int8_utils.per_token_group_quant_int8(
+            x,
+            group_size,
+        )
+
+    assert torch.allclose(out_q.float(), ref_q.float(), atol=0.15, rtol=0.15)
+    assert torch.allclose(scale, ref_s, atol=0.01, rtol=0.01)
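
The new test mirrors the existing FP8 test: it runs the CUDA kernel, then patches current_platform.is_cuda to force the Triton fallback, and requires the two results to agree within tolerance. As a rough mental model of what both paths compute, a plain-PyTorch INT8 analogue might look like the sketch below; the symmetric amax/127 scaling is an assumption about the usual formulation, not the vLLM implementation's exact rounding or clamping.

# Hypothetical reference (not the vLLM kernel): per-token-group INT8 quantization.
import torch


def per_token_group_quant_int8_ref(x: torch.Tensor, group_size: int):
    num_tokens, hidden_dim = x.shape
    groups = x.float().reshape(num_tokens, hidden_dim // group_size, group_size)
    amax = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10)
    scales = amax / 127.0  # assumed symmetric int8 range
    q = torch.round(groups / scales).clamp(-128, 127).to(torch.int8)
    return q.reshape(num_tokens, hidden_dim), scales.squeeze(-1)

On a CUDA machine the new test can be run directly, for example with: pytest tests/kernels/quantization/test_per_token_group_quant.py -k int8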
