
Commit 19bee6d

Varun Sundar Rabindranath (varun-sundar-rabindranath) authored, with co-author Tyler Michael Smith (tlrmchlsmth)
[Performance][DP/EP] Add silu_mul_per_token_group_quant_fp8_colmajor kernel (#29470)
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>
1 parent dd5d1ef commit 19bee6d
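
The new kernel fuses vLLM's silu_and_mul activation with per-token-group FP8 quantization (group size 128), emitting the group scales in the column-major layout the DeepGEMM path consumes. A rough sketch of the intended math in plain PyTorch (illustration only, not the Triton implementation; the ue8m0 scale-rounding path toggled by use_ue8m0 is omitted):

import torch

def silu_mul_quant_sketch(x: torch.Tensor, group_size: int = 128):
    # x: (T, N) bfloat16; gate/up halves feed silu_and_mul.
    T, N = x.shape
    gate, up = x[:, : N // 2], x[:, N // 2 :]
    y = torch.nn.functional.silu(gate.float()) * up.float()
    # One scale per contiguous group of `group_size` elements per token.
    g = y.view(T, -1, group_size)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scales = g.abs().amax(dim=-1).clamp(min=1e-10) / fp8_max
    q = (g / scales.unsqueeze(-1)).to(torch.float8_e4m3fn).view(T, N // 2)
    # The fused kernel additionally writes `scales` in column-major order.
    return q, scales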

File tree

4 files changed: +496 -81 lines changed
Lines changed: 244 additions & 0 deletions
@@ -0,0 +1,244 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from enum import Enum
from itertools import product
from typing import Any

import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    _per_token_group_quant_fp8_colmajor,
    silu_mul_per_token_group_quant_fp8_colmajor,
)
from vllm.triton_utils import triton
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used

from .utils import ArgPool, Bench, CudaGraphBenchParams

GROUP_SIZE = 128
FLOAT8_T = torch.float8_e4m3fn

def print_timers(timers: list[TMeasurement], cuda_graph_nops: int):
    print(
        f"Note: The timings reported above are for {cuda_graph_nops} "
        "consecutive invocations of the benchmarking functions. "
        f"Please divide by {cuda_graph_nops} for single-invocation "
        "timings."
    )
    compare = TBenchmark.Compare(timers)
    compare.print()

class ImplType(Enum):
    SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR = 1
    REFERENCE = 2

    def get_impl(self):
        if self == ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR:
            return silu_mul_per_token_group_quant_fp8_colmajor
        elif self == ImplType.REFERENCE:
            return reference
        raise ValueError(f"Unrecognized ImplType {self}")

@dataclass
class BenchmarkTensors:
    input: torch.Tensor
    output: torch.Tensor

    # Reference act output tensor
    ref_act_out: torch.Tensor
    ref_quant_out: torch.Tensor

    @staticmethod
    def make(T: int, N: int) -> "BenchmarkTensors":
        assert T % GROUP_SIZE == 0
        assert N % (GROUP_SIZE * 2) == 0

        input = torch.rand((T, N), dtype=torch.bfloat16, device="cuda")

        # silu_mul_per_token_group_quant_fp8_colmajor output.
        output = torch.rand((T, N // 2), dtype=torch.bfloat16, device="cuda").to(
            FLOAT8_T
        )

        # reference output.
        ref_act_out = torch.empty((T, N // 2), dtype=torch.bfloat16, device="cuda")
        ref_quant_out = torch.empty(
            (T, N // 2), dtype=torch.bfloat16, device="cuda"
        ).to(FLOAT8_T)

        return BenchmarkTensors(
            input=input,
            output=output,
            ref_act_out=ref_act_out,
            ref_quant_out=ref_quant_out,
        )

    @property
    def T(self):
        return self.input.size(0)

    @property
    def N(self):
        return self.input.size(1)

    def make_impl_kwargs(self, impl_type: ImplType) -> dict[str, Any]:
        if impl_type == ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR:
            return {
                "input": self.input,
                "output": self.output,
                "use_ue8m0": is_deep_gemm_e8m0_used(),
            }
        elif impl_type == ImplType.REFERENCE:
            return {
                "input": self.input,
                "act_out": self.ref_act_out,
                "quant_out": self.ref_quant_out,
                "use_ue8m0": is_deep_gemm_e8m0_used(),
            }
        raise ValueError(f"Unrecognized impl_type {impl_type}")

def reference_quant(x: torch.Tensor, quant_out: torch.Tensor, use_ue8m0: bool):
    """
    Reference triton quant kernel from
    vllm.model_executor.layers.quantization.utils.fp8_utils
    """
    assert quant_out.size() == x.size()
    # Allocate the scale tensor in column-major format.
    shape = (x.shape[-1] // GROUP_SIZE,) + x.shape[:-1]
    x_q = quant_out
    x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)

    M = x.numel() // GROUP_SIZE
    N = GROUP_SIZE
    BLOCK = triton.next_power_of_2(N)
    # heuristics for number of warps
    num_warps = min(max(BLOCK // 256, 1), 8)
    num_stages = 1

    finfo = torch.finfo(FLOAT8_T)
    fp8_min = finfo.min
    fp8_max = finfo.max

    _per_token_group_quant_fp8_colmajor[(M,)](
        x,
        x_q,
        x_s,
        GROUP_SIZE,
        x.shape[1],
        x.stride(0),
        x_s.stride(1),
        eps=1e-10,
        fp8_min=fp8_min,
        fp8_max=fp8_max,
        use_ue8m0=use_ue8m0,
        BLOCK=BLOCK,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return x_q, x_s


def reference(
    input: torch.Tensor,
    act_out: torch.Tensor,
    quant_out: torch.Tensor,
    use_ue8m0: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    torch.ops._C.silu_and_mul(act_out, input)
    return reference_quant(act_out, quant_out, use_ue8m0)

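# Note on the scale layout above: x_s is allocated with shape
# (x.shape[-1] // GROUP_SIZE, T) and then permuted, giving a (T, num_groups)
# view over column-major storage -- scales for consecutive tokens within one
# group column are contiguous (e.g. for T=4 and 2 groups per token,
# x_s.stride() == (1, 4)). This is the layout the colmajor kernels consume,
# and x_s.stride(1) is the column stride passed to the Triton kernel.
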
def bench_impl(
    bench_tensors: list[BenchmarkTensors], impl_type: ImplType
) -> TMeasurement:
    T = bench_tensors[0].T
    N = bench_tensors[0].N

    arg_pool_size = len(bench_tensors)
    kwargs_list = [bt.make_impl_kwargs(impl_type) for bt in bench_tensors]

    # warmup
    for kwargs in kwargs_list:
        impl_type.get_impl()(**kwargs)
    torch.cuda.synchronize()

    # Merge into a single kwargs dict, wrapping pooled arguments in ArgPool
    kwargs = {k: ArgPool([]) for k in kwargs_list[0]}
    for _kwargs in kwargs_list:
        for k, v in _kwargs.items():
            kwargs[k].values.append(v)

    cuda_graph_params = CudaGraphBenchParams(arg_pool_size)
    timer = None
    with Bench(
        cuda_graph_params,
        "silu-mul-quant",
        f"num_tokens={T}, N={N}",
        impl_type.name,
        impl_type.get_impl(),
        **kwargs,
    ) as bench:
        timer = bench.run()
    return timer

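# Benchmarking note: Bench captures arg_pool_size consecutive invocations of
# the target function in one CUDA graph, with each invocation pulling its
# tensors from the next ArgPool entry so no two replayed calls share buffers.
# This is why print_timers reminds the reader to divide the reported time by
# cuda_graph_nops (= arg_pool_size) for per-invocation numbers.
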
def test_correctness(T: int, N: int):
    print(f"Testing num_tokens={T}, N={N} ...")

    bench_tensor = BenchmarkTensors.make(T, N)

    def output_from_impl(impl: ImplType) -> tuple[torch.Tensor, torch.Tensor]:
        return impl.get_impl()(**bench_tensor.make_impl_kwargs(impl))

    # reference output
    ref_out_q, ref_out_s = output_from_impl(ImplType.REFERENCE)

    # test output
    out_q, out_s = output_from_impl(
        ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR
    )

    torch.testing.assert_close(ref_out_q.to(torch.float32), out_q.to(torch.float32))
    torch.testing.assert_close(ref_out_s, out_s)

214+
def run(Ts: list[int], Ns: list[int], arg_pool_size: int) -> list[TMeasurement]:
215+
timers = []
216+
for N, T in product(Ns, Ts):
217+
test_correctness(T, N)
218+
219+
bench_tensors: list[BenchmarkTensors] = [
220+
BenchmarkTensors.make(T, N) for _ in range(arg_pool_size)
221+
]
222+
223+
silu_mul_quant_timer = bench_impl(
224+
bench_tensors, ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR
225+
)
226+
timers.append(silu_mul_quant_timer)
227+
reference_timer = bench_impl(bench_tensors, ImplType.REFERENCE)
228+
timers.append(reference_timer)
229+
230+
print_timers(
231+
[silu_mul_quant_timer, reference_timer], cuda_graph_nops=arg_pool_size
232+
)
233+
234+
print_timers(timers, cuda_graph_nops=arg_pool_size)
235+
236+
return timers
237+
238+
239+
if __name__ == "__main__":
240+
T = [128 * i for i in range(1, 16)] + [2048 * i for i in range(1, 65)]
241+
N = [2048, 4096, 8192]
242+
243+
print(f"T = {T}, N = {N}")
244+
run(T, N, arg_pool_size=8)
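
For reference, the fused entry point can also be called on its own, outside the benchmark harness. A minimal sketch based on the call in the test below (assumes a CUDA device and a vLLM build that ships the kernel):

import torch
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    silu_mul_per_token_group_quant_fp8_colmajor,
)
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used

x = torch.rand((256, 4096), dtype=torch.bfloat16, device="cuda")
q, s = silu_mul_per_token_group_quant_fp8_colmajor(
    x, use_ue8m0=is_deep_gemm_e8m0_used()
)
# q: (256, 2048) float8_e4m3fn; s: (256, 16) float32 with column-major strides.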
Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    _per_token_group_quant_fp8_colmajor,
    silu_mul_per_token_group_quant_fp8_colmajor,
)
from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used

FLOAT8_DTYPE = torch.float8_e4m3fn
GROUP_SIZE = 128


def reference_quant(x: torch.Tensor, use_ue8m0: bool):
    """
    Reference triton quant kernel from
    vllm.model_executor.layers.quantization.utils.fp8_utils
    """

    x_q = torch.empty_like(x, device=x.device, dtype=FLOAT8_DTYPE)

    # Allocate the scale tensor in column-major format.
    shape = (x.shape[-1] // GROUP_SIZE,) + x.shape[:-1]
    x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)

    M = x.numel() // GROUP_SIZE
    N = GROUP_SIZE
    BLOCK = triton.next_power_of_2(N)
    # heuristics for number of warps
    num_warps = min(max(BLOCK // 256, 1), 8)
    num_stages = 1

    finfo = torch.finfo(FLOAT8_DTYPE)
    fp8_min = finfo.min
    fp8_max = finfo.max

    _per_token_group_quant_fp8_colmajor[(M,)](
        x,
        x_q,
        x_s,
        GROUP_SIZE,
        x.shape[1],
        x.stride(0),
        x_s.stride(1),
        eps=1e-10,
        fp8_min=fp8_min,
        fp8_max=fp8_max,
        use_ue8m0=use_ue8m0,
        BLOCK=BLOCK,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return x_q, x_s


def reference(x: torch.Tensor, use_ue8m0: bool) -> tuple[torch.Tensor, torch.Tensor]:
    T, N = x.size()
    ref_act_out = torch.empty((T, N // 2), dtype=torch.bfloat16, device="cuda")
    torch.ops._C.silu_and_mul(ref_act_out, x)
    return reference_quant(ref_act_out, use_ue8m0)


@pytest.mark.parametrize("T", [128, 256, 512])
@pytest.mark.parametrize("N", [128 * 2, 256 * 2, 768 * 2, 2048 * 2, 7168 * 2])
def test_silu_mul_fp8_quant_deep_gemm(T: int, N: int):
    current_platform.seed_everything(42)

    input = torch.rand((T, N), dtype=torch.bfloat16, device="cuda")

    use_ue8m0 = is_deep_gemm_e8m0_used()

    # Test
    output, output_scales = silu_mul_per_token_group_quant_fp8_colmajor(
        input, use_ue8m0=use_ue8m0
    )

    # Reference
    ref_output, ref_output_scales = reference(input, use_ue8m0)

    torch.testing.assert_close(output.to(torch.float32), ref_output.to(torch.float32))
    torch.testing.assert_close(output_scales, ref_output_scales)
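
The test sweeps T over {128, 256, 512} and N up to 2 * 7168, comparing both the quantized activations and the group scales against the unfused silu_and_mul-plus-quant reference; whether the ue8m0 scale-rounding path is exercised follows is_deep_gemm_e8m0_used() on the host platform.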
