# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
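"""
Benchmark silu_mul_per_token_group_quant_fp8_colmajor (a fused SiLU-mul +
per-token-group FP8 quant kernel) against a reference implementation that
runs silu_and_mul followed by the standalone Triton quant kernel.

Each (num_tokens, N) shape is correctness-checked before it is benchmarked.
"""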

from dataclasses import dataclass
from enum import Enum
from itertools import product
from typing import Any

import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    _per_token_group_quant_fp8_colmajor,
    silu_mul_per_token_group_quant_fp8_colmajor,
)
from vllm.triton_utils import triton
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used

from .utils import ArgPool, Bench, CudaGraphBenchParams

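# Per-token group quantization group size and FP8 dtype used throughout.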
GROUP_SIZE = 128
FLOAT8_T = torch.float8_e4m3fn


def print_timers(timers: list[TMeasurement], cuda_graph_nops: int):
    print(
        f"Note: The timings reported below are for {cuda_graph_nops} "
        "consecutive invocations of the benchmarked function. "
        f"Divide by {cuda_graph_nops} to get single-invocation timings."
    )
    compare = TBenchmark.Compare(timers)
    compare.print()


class ImplType(Enum):
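    """The implementations under comparison."""
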
    SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR = 1
    REFERENCE = 2

    def get_impl(self):
        if self == ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR:
            return silu_mul_per_token_group_quant_fp8_colmajor
        elif self == ImplType.REFERENCE:
            return reference
        raise ValueError(f"Unrecognized ImplType {self}")


@dataclass
class BenchmarkTensors:
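    """Shared input plus per-implementation output buffers for one shape."""
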
    input: torch.Tensor
    output: torch.Tensor

    # Reference implementation output buffers.
    ref_act_out: torch.Tensor
    ref_quant_out: torch.Tensor

    @staticmethod
    def make(T: int, N: int) -> "BenchmarkTensors":
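        """Allocate the input and per-implementation output buffers.

        N is the combined gate+up projection width, so the activation
        outputs have N // 2 columns; the asserts keep both T and N
        group-aligned for the quant kernels.
        """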
        assert T % GROUP_SIZE == 0
        assert N % (GROUP_SIZE * 2) == 0

        input = torch.rand((T, N), dtype=torch.bfloat16, device="cuda")

        # silu_mul_per_token_group_quant_fp8_colmajor output.
        output = torch.empty((T, N // 2), dtype=FLOAT8_T, device="cuda")

        # reference outputs.
        ref_act_out = torch.empty((T, N // 2), dtype=torch.bfloat16, device="cuda")
        ref_quant_out = torch.empty((T, N // 2), dtype=FLOAT8_T, device="cuda")

        return BenchmarkTensors(
            input=input,
            output=output,
            ref_act_out=ref_act_out,
            ref_quant_out=ref_quant_out,
        )

    @property
    def T(self):
        return self.input.size(0)

    @property
    def N(self):
        return self.input.size(1)

    def make_impl_kwargs(self, impl_type: ImplType) -> dict[str, Any]:
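        """Build the kwargs expected by the implementation's entry point."""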
        if impl_type == ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR:
            return {
                "input": self.input,
                "output": self.output,
                "use_ue8m0": is_deep_gemm_e8m0_used(),
            }
        elif impl_type == ImplType.REFERENCE:
            return {
                "input": self.input,
                "act_out": self.ref_act_out,
                "quant_out": self.ref_quant_out,
                "use_ue8m0": is_deep_gemm_e8m0_used(),
            }
        raise ValueError(f"Unrecognized impl_type {impl_type}")


def reference_quant(x: torch.Tensor, quant_out: torch.Tensor, use_ue8m0: bool):
    """
    Reference Triton quant kernel from
    vllm.model_executor.layers.quantization.utils.fp8_utils.
    """
    assert quant_out.size() == x.size()
    # Allocate the scale tensor in column-major format.
    shape = (x.shape[-1] // GROUP_SIZE,) + x.shape[:-1]
    x_q = quant_out
    x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)

    M = x.numel() // GROUP_SIZE
    N = GROUP_SIZE
    BLOCK = triton.next_power_of_2(N)
    # heuristics for number of warps
    num_warps = min(max(BLOCK // 256, 1), 8)
    num_stages = 1

    finfo = torch.finfo(FLOAT8_T)
    fp8_min = finfo.min
    fp8_max = finfo.max

    _per_token_group_quant_fp8_colmajor[(M,)](
        x,
        x_q,
        x_s,
        GROUP_SIZE,
        x.shape[1],
        x.stride(0),
        x_s.stride(1),
        eps=1e-10,
        fp8_min=fp8_min,
        fp8_max=fp8_max,
        use_ue8m0=use_ue8m0,
        BLOCK=BLOCK,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return x_q, x_s


def reference(
    input: torch.Tensor,
    act_out: torch.Tensor,
    quant_out: torch.Tensor,
    use_ue8m0: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
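    """Unfused reference: silu_and_mul followed by the standalone quant kernel."""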
    torch.ops._C.silu_and_mul(act_out, input)
    return reference_quant(act_out, quant_out, use_ue8m0)


def bench_impl(
    bench_tensors: list[BenchmarkTensors], impl_type: ImplType
) -> TMeasurement:
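    """Time one implementation with CUDA graphs.

    The arguments are wrapped in ArgPools so that the consecutive invocations
    captured in the CUDA graph each use a different BenchmarkTensors set.
    """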
    T = bench_tensors[0].T
    N = bench_tensors[0].N

    arg_pool_size = len(bench_tensors)
    kwargs_list = [bt.make_impl_kwargs(impl_type) for bt in bench_tensors]

    # warmup
    for kwargs in kwargs_list:
        impl_type.get_impl()(**kwargs)
    torch.cuda.synchronize()

    # Merge into a single kwargs dict, wrapping each argument in an ArgPool.
    kwargs = {k: ArgPool([]) for k in kwargs_list[0]}
    for _kwargs in kwargs_list:
        for k, v in _kwargs.items():
            kwargs[k].values.append(v)

    cuda_graph_params = CudaGraphBenchParams(arg_pool_size)
    with Bench(
        cuda_graph_params,
        "silu-mul-quant",
        f"num_tokens={T}, N={N}",
        impl_type.name,
        impl_type.get_impl(),
        **kwargs,
    ) as bench:
        timer = bench.run()
    return timer


def test_correctness(T: int, N: int):
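    """Check the fused kernel against the reference quant output and scales."""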
    print(f"Testing num_tokens={T}, N={N} ...")

    bench_tensor = BenchmarkTensors.make(T, N)

    def output_from_impl(impl: ImplType) -> tuple[torch.Tensor, torch.Tensor]:
        return impl.get_impl()(**bench_tensor.make_impl_kwargs(impl))

    # reference output
    ref_out_q, ref_out_s = output_from_impl(ImplType.REFERENCE)

    # test output
    out_q, out_s = output_from_impl(
        ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR
    )

    torch.testing.assert_close(ref_out_q.to(torch.float32), out_q.to(torch.float32))
    torch.testing.assert_close(ref_out_s, out_s)


def run(Ts: list[int], Ns: list[int], arg_pool_size: int) -> list[TMeasurement]:
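    """Correctness-check and benchmark every (T, N) combination."""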
    timers = []
    for N, T in product(Ns, Ts):
        test_correctness(T, N)

        bench_tensors: list[BenchmarkTensors] = [
            BenchmarkTensors.make(T, N) for _ in range(arg_pool_size)
        ]

        silu_mul_quant_timer = bench_impl(
            bench_tensors, ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR
        )
        timers.append(silu_mul_quant_timer)
        reference_timer = bench_impl(bench_tensors, ImplType.REFERENCE)
        timers.append(reference_timer)

        # Per-shape comparison.
        print_timers(
            [silu_mul_quant_timer, reference_timer], cuda_graph_nops=arg_pool_size
        )

    # Overall summary across all shapes.
    print_timers(timers, cuda_graph_nops=arg_pool_size)

    return timers


if __name__ == "__main__":
    # Sweep token counts 128..1920 (step 128) and 2048..131072 (step 2048)
    # over gate+up widths of 2048, 4096 and 8192.
    T = [128 * i for i in range(1, 16)] + [2048 * i for i in range(1, 65)]
    N = [2048, 4096, 8192]

    print(f"T = {T}, N = {N}")
    run(T, N, arg_pool_size=8)