|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 3 | +""" |
| 4 | +Warmup deep_gemm kernels. |
| 5 | +DeepGEMM JIT's the kernels. The warmup aims to JIT all the kernels that would |
| 6 | +be used during model execution beforehand. |
| 7 | +""" |
| 8 | + |
| 9 | +import torch |
| 10 | +from tqdm import tqdm |
| 11 | + |
| 12 | +import vllm.envs as envs |
| 13 | +from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts |
| 14 | +from vllm.model_executor.layers.fused_moe.deep_gemm_utils import ( |
| 15 | + compute_aligned_M, deep_gemm_block_shape) |
| 16 | +from vllm.model_executor.layers.fused_moe.layer import FusedMoE |
| 17 | +from vllm.model_executor.layers.fused_moe.modular_kernel import ( |
| 18 | + FusedMoEModularKernel) |
| 19 | +from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( |
| 20 | + TritonOrDeepGemmExperts) |
| 21 | +from vllm.model_executor.layers.linear import LinearBase |
| 22 | +from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod |
| 23 | +from vllm.utils.deep_gemm import fp8_gemm_nt, m_grouped_fp8_gemm_nt_contiguous |
| 24 | + |
| 25 | + |
| 26 | +def _extract_data_from_linear_base_module( |
| 27 | + m: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor, list[int]]: |
| 28 | + """ |
| 29 | + Extract weights, weight scales and quantization block sizes from the given |
| 30 | + LinearBase module. |
| 31 | + """ |
| 32 | + assert isinstance(m, LinearBase) |
| 33 | + assert isinstance(m.quant_method, Fp8LinearMethod) |
| 34 | + assert m.quant_method.block_quant |
| 35 | + assert m.quant_method.quant_config is not None |
| 36 | + |
| 37 | + w = m.weight |
| 38 | + ws = m.weight_scale_inv |
| 39 | + quant_block_size = m.quant_method.quant_config.weight_block_size |
| 40 | + |
| 41 | + assert isinstance(w, torch.Tensor) |
| 42 | + assert isinstance(ws, torch.Tensor) |
| 43 | + assert quant_block_size is not None |
| 44 | + return (w, ws, quant_block_size) |
| 45 | + |
| 46 | + |
| 47 | +def _extract_data_from_fused_moe_module( |
| 48 | + m: torch.nn.Module |
| 49 | +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]: |
| 50 | + """ |
| 51 | + Extract weights, weight scales and num_topk from FusedMoE module. |
| 52 | + """ |
| 53 | + assert isinstance(m, FusedMoE) |
| 54 | + w13 = m.w13_weight |
| 55 | + w13_s = m.w13_weight_scale_inv |
| 56 | + w2 = m.w2_weight |
| 57 | + w2_s = m.w2_weight_scale_inv |
| 58 | + num_topk = m.top_k |
| 59 | + |
| 60 | + assert isinstance(w13, torch.Tensor) |
| 61 | + assert isinstance(w13_s, torch.Tensor) |
| 62 | + assert isinstance(w2, torch.Tensor) |
| 63 | + assert isinstance(w2_s, torch.Tensor) |
| 64 | + return w13, w13_s, w2, w2_s, num_topk |
| 65 | + |
| 66 | + |
| 67 | +def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool: |
| 68 | + """ |
| 69 | + Return True if the input module/layer could be processed with DeepGEMM. |
| 70 | + """ |
| 71 | + block_size = deep_gemm_block_shape()[0] |
| 72 | + if not (isinstance(module, LinearBase) |
| 73 | + and isinstance(module.quant_method, Fp8LinearMethod) |
| 74 | + and module.quant_method.block_quant): |
| 75 | + return False |
| 76 | + |
| 77 | + w, _, block_sizes = _extract_data_from_linear_base_module(module) |
| 78 | + return (block_sizes == deep_gemm_block_shape() and w.ndim == 2 |
| 79 | + and w.shape[0] % block_size == 0 and w.shape[1] % block_size == 0) |
| 80 | + |
| 81 | + |
| 82 | +def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool: |
| 83 | + if not (isinstance(module, FusedMoE) |
| 84 | + and module.moe_config.quant_dtype == torch.float8_e4m3fn |
| 85 | + and module.moe_config.block_shape == deep_gemm_block_shape()): |
| 86 | + return False |
| 87 | + |
| 88 | + if not isinstance(module.quant_method.fused_experts, |
| 89 | + FusedMoEModularKernel): |
| 90 | + # fused_experts could invoke deep_gemm_moe_fp8 |
| 91 | + return True |
| 92 | + |
| 93 | + mk: FusedMoEModularKernel = module.quant_method.fused_experts |
| 94 | + # Further check if the ModularKernel implementation uses the DeepGemmExperts |
| 95 | + return isinstance(mk.fused_experts, |
| 96 | + (DeepGemmExperts, TritonOrDeepGemmExperts)) |
| 97 | + |
| 98 | + |
| 99 | +FP8_GEMM_NT_WARMUP_CACHE: set[torch.Size] = set() |
| 100 | + |
| 101 | + |
| 102 | +def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, |
| 103 | + max_tokens: int): |
| 104 | + if w.size() in FP8_GEMM_NT_WARMUP_CACHE: |
| 105 | + return |
| 106 | + |
| 107 | + n, k = w.size() |
| 108 | + block_m = deep_gemm_block_shape()[0] |
| 109 | + |
| 110 | + device = w.device |
| 111 | + a1q = torch.empty((max_tokens, k), |
| 112 | + device=device, |
| 113 | + dtype=torch.float8_e4m3fn) |
| 114 | + a1q_scales = torch.empty((max_tokens, k // block_m), |
| 115 | + device=device, |
| 116 | + dtype=torch.float32) |
| 117 | + out = torch.empty((max_tokens, n), device=device, dtype=torch.bfloat16) |
| 118 | + |
| 119 | + pbar = tqdm(total=max_tokens, |
| 120 | + desc=f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()})") |
| 121 | + num_tokens = max_tokens |
| 122 | + while num_tokens > 0: |
| 123 | + fp8_gemm_nt((a1q[:num_tokens], a1q_scales[:num_tokens]), (w, ws), |
| 124 | + out[:num_tokens]) |
| 125 | + pbar.update(1) |
| 126 | + num_tokens -= 1 |
| 127 | + |
| 128 | + FP8_GEMM_NT_WARMUP_CACHE.add(w.size()) |
| 129 | + |
| 130 | + |
| 131 | +GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE: set[torch.Size] = set() |
| 132 | + |
| 133 | + |
| 134 | +def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(w1: torch.Tensor, |
| 135 | + w2: torch.Tensor, |
| 136 | + w1_scale: torch.Tensor, |
| 137 | + w2_scale: torch.Tensor, |
| 138 | + num_topk: int): |
| 139 | + if (w1.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE |
| 140 | + and w2.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE): |
| 141 | + return |
| 142 | + |
| 143 | + assert w1.size(0) == w2.size(0), ( |
| 144 | + "w1 and w2 must have the same number of experts") |
| 145 | + |
| 146 | + block_m = deep_gemm_block_shape()[0] |
| 147 | + num_experts = w1.size(0) |
| 148 | + device = w1.device |
| 149 | + |
| 150 | + # This is the maximum GroupedGemm M size that we expect to run |
| 151 | + # the grouped_gemm with. |
| 152 | + MAX_M = compute_aligned_M(envs.VLLM_FUSED_MOE_CHUNK_SIZE, |
| 153 | + num_topk, |
| 154 | + num_experts, |
| 155 | + block_m, |
| 156 | + expert_tokens_meta=None) |
| 157 | + # Distribute expert-ids evenly. |
| 158 | + MAX_BLOCKS = MAX_M // block_m |
| 159 | + expert_ids_block = torch.randint(low=0, |
| 160 | + high=num_experts, |
| 161 | + size=(MAX_BLOCKS, ), |
| 162 | + device=device, |
| 163 | + dtype=torch.int32) |
| 164 | + expert_ids = torch.repeat_interleave(expert_ids_block, block_m, dim=0) |
| 165 | + |
| 166 | + def _warmup(w: torch.Tensor, w_scale: torch.Tensor): |
| 167 | + |
| 168 | + _, n, k = w.size() |
| 169 | + a1q = torch.empty((MAX_M, k), device=device, dtype=torch.float8_e4m3fn) |
| 170 | + a1q_scales = torch.empty((MAX_M, k // block_m), |
| 171 | + device=device, |
| 172 | + dtype=torch.float32) |
| 173 | + out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16) |
| 174 | + |
| 175 | + pbar = tqdm( |
| 176 | + total=MAX_BLOCKS, |
| 177 | + desc= |
| 178 | + f"DeepGemm(m_grouped_fp8_gemm_nt_contiguous) warmup (W={w.size()})" |
| 179 | + ) |
| 180 | + num_tokens = MAX_M |
| 181 | + while num_tokens > 0: |
| 182 | + m_grouped_fp8_gemm_nt_contiguous( |
| 183 | + (a1q[:num_tokens], a1q_scales[:num_tokens]), (w, w_scale), |
| 184 | + out[:num_tokens], expert_ids[:num_tokens]) |
| 185 | + pbar.update(1) |
| 186 | + num_tokens = num_tokens - block_m |
| 187 | + |
| 188 | + for w, ws in [(w1, w1_scale), (w2, w2_scale)]: |
| 189 | + if w.size() not in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE: |
| 190 | + _warmup(w, ws) |
| 191 | + GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE.add(w.size()) |
| 192 | + |
| 193 | + |
| 194 | +def deepgemm_fp8_gemm_nt_warmup(model: torch.nn.Module, max_tokens: int): |
| 195 | + dg_modules = [ |
| 196 | + m for m in model.modules() if _fp8_linear_may_use_deep_gemm(m) |
| 197 | + ] |
| 198 | + |
| 199 | + for dgm in dg_modules: |
| 200 | + w, ws, _ = _extract_data_from_linear_base_module(dgm) |
| 201 | + _deepgemm_fp8_gemm_nt_warmup(w=w, ws=ws, max_tokens=max_tokens) |
| 202 | + |
| 203 | + |
| 204 | +def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model: torch.nn.Module): |
| 205 | + dg_modules = [ |
| 206 | + m for m in model.modules() |
| 207 | + if _fused_moe_grouped_gemm_may_use_deep_gemm(m) |
| 208 | + ] |
| 209 | + |
| 210 | + for dgm in dg_modules: |
| 211 | + w13, w13_scale, w2, w2_scale, num_topk = ( |
| 212 | + _extract_data_from_fused_moe_module(dgm)) |
| 213 | + _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup( |
| 214 | + w13, w2, w13_scale, w2_scale, num_topk) |
| 215 | + |
| 216 | + |
| 217 | +def deep_gemm_warmup(model: torch.nn.Module, max_tokens: int): |
| 218 | + deepgemm_fp8_gemm_nt_warmup(model, max_tokens) |
| 219 | + deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model) |
0 commit comments