4 | 4 | from typing import Any, Optional
5 | 5 |
6 | 6 | import torch
| 7 | +from tqdm import tqdm
7 | 8 |
| 9 | +import vllm.envs as env
8 | 10 | import vllm.model_executor.layers.fused_moe.modular_kernel as mk
9 | 11 | from vllm.logger import init_logger
10 | 12 | from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig

17 | 19 | from vllm.model_executor.layers.fused_moe.utils import _resize_cache
18 | 20 | from vllm.model_executor.layers.quantization.utils.fp8_utils import (
19 | 21 |     per_token_group_quant_fp8)
20 |    | -from vllm.utils import has_deep_gemm
| 22 | +from vllm.utils import has_deep_gemm, run_once
21 | 23 | from vllm.utils.deep_gemm import m_grouped_fp8_gemm_nt_contiguous
22 | 24 |
23 | 25 | logger = init_logger(__name__)
@@ -82,6 +84,65 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor,
82 | 84 |     return True
83 | 85 |
84 | 86 |
| 87 | +@run_once
| 88 | +def warmup_deepgemm_gg_contiguous_kernels(w1: torch.Tensor, w2: torch.Tensor,
| 89 | +                                          w1_scale: torch.Tensor,
| 90 | +                                          w2_scale: torch.Tensor,
| 91 | +                                          num_topk: int):
| 92 | +    """
| 93 | +    DeepGemm JITs the grouped-gemm kernels. The JIT'ing happens based on the
| 94 | +    input tensor shapes. In this function, we construct all possible input
| 95 | +    tensor shapes so all the kernels are JIT'ed and cached.
| 96 | +    Note that this warmup is expected to happen during the model profile
| 97 | +    call and not during actual model inference.
| 98 | +    """
| 99 | +
| 100 | +    assert w1.size(0) == w2.size(0), (
| 101 | +        "w1 and w2 must have the same number of experts")
| 102 | +
| 103 | +    block_m = deep_gemm_block_shape()[0]
| 104 | +    num_experts = w1.size(0)
| 105 | +    device = w1.device
| 106 | +
| 107 | +    # This is the maximum GroupedGemm M size that we expect to run
| 108 | +    # the grouped_gemm with.
| 109 | +    MAX_M = compute_aligned_M(env.VLLM_FUSED_MOE_CHUNK_SIZE,
| 110 | +                              num_topk,
| 111 | +                              num_experts,
| 112 | +                              block_m,
| 113 | +                              expert_tokens_meta=None)
| 114 | +    # Assign one (random) expert id to each block_m-sized block of rows.
| 115 | +    MAX_BLOCKS = MAX_M // block_m
| 116 | +    expert_ids_block = torch.randint(low=0,
| 117 | +                                     high=num_experts,
| 118 | +                                     size=(MAX_BLOCKS, ),
| 119 | +                                     device=device,
| 120 | +                                     dtype=torch.int32)
| 121 | +    expert_ids = torch.repeat_interleave(expert_ids_block, block_m, dim=0)
| 122 | +
| 123 | +    def _warmup(w: torch.Tensor, w_scale: torch.Tensor):
| 124 | +
| 125 | +        _, n, k = w.size()
| 126 | +        a1q = torch.empty((MAX_M, k), device=device).to(torch.float8_e4m3fn)
| 127 | +        a1q_scales = torch.empty((MAX_M, k // block_m),
| 128 | +                                 device=device,
| 129 | +                                 dtype=torch.float32)
| 130 | +        out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16)
| 131 | +
| 132 | +        pbar = tqdm(total=MAX_BLOCKS,
| 133 | +                    desc=f"DeepGemmExperts GEMM warmup (MAX_M={MAX_M})")
| 134 | +        num_tokens = MAX_M
| 135 | +        while num_tokens > 0:
| 136 | +            m_grouped_fp8_gemm_nt_contiguous(
| 137 | +                (a1q[:num_tokens], a1q_scales[:num_tokens]), (w, w_scale),
| 138 | +                out[:num_tokens], expert_ids[:num_tokens])
| 139 | +            pbar.update(1)
| 140 | +            num_tokens = num_tokens - block_m
| 141 | +
| 142 | +    _warmup(w1, w1_scale)
| 143 | +    _warmup(w2, w2_scale)
| 144 | +
| 145 | +
85 | 146 | class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
86 | 147 |
87 | 148 |     def __init__(self):

@@ -156,6 +217,20 @@ def apply(
156 | 217 |     ):
157 | 218 |         assert self.block_shape is not None
158 | 219 |         assert a1q_scale is not None
| 220 | +        assert w1_scale is not None
| 221 | +        assert w2_scale is not None
| 222 | +
| 223 | +        if not env.VLLM_SKIP_DEEP_GEMM_WARMUP:
| 224 | +            # DeepGemm JITs the grouped-gemm kernels. We don't want the
| 225 | +            # JIT'ing to happen during actual model inference.
| 226 | +            # `warmup_deepgemm_gg_contiguous_kernels` is decorated with
| 227 | +            # `run_once` and executes during the model profile run. This
| 228 | +            # warmup should create all the required JITs for the current model.
| 229 | +            warmup_deepgemm_gg_contiguous_kernels(w1,
| 230 | +                                                  w2,
| 231 | +                                                  w1_scale,
| 232 | +                                                  w2_scale,
| 233 | +                                                  num_topk=topk_ids.size(1))
159 | 234 |
160 | 235 |         a1q = hidden_states
161 | 236 |         _, N, K = w1.size()
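A note on what the warmup in this diff actually covers: `_warmup` sweeps the grouped-gemm M dimension from `MAX_M` down to `block_m` in `block_m` steps, so every aligned problem size DeepGemm could JIT for this layer is compiled ahead of time. Below is a minimal standalone sketch of that sweep; the constants are illustrative stand-ins, not vLLM defaults, and the real bound comes from `compute_aligned_M`.

```python
# Illustrative sketch of the M sizes the warmup loop touches.
BLOCK_M = 128          # stand-in for deep_gemm_block_shape()[0]
MAX_M = 64 * BLOCK_M   # stand-in for the compute_aligned_M(...) result

# One grouped gemm is issued per entry, mirroring the `while num_tokens > 0`
# loop in `_warmup`, so MAX_M // BLOCK_M distinct problem sizes get JIT'ed.
warmed_m_sizes = list(range(MAX_M, 0, -BLOCK_M))
assert len(warmed_m_sizes) == MAX_M // BLOCK_M
print(f"warming {len(warmed_m_sizes)} sizes, M from {MAX_M} down to {BLOCK_M}")
```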
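Two pieces the `apply` guard relies on live outside this diff: the `run_once` decorator imported from `vllm.utils` and the `VLLM_SKIP_DEEP_GEMM_WARMUP` flag read through `vllm.envs`. The sketch below is a rough stand-in for the behaviour the decorator is being relied on for, not the actual `vllm.utils` implementation: the wrapped warmup body runs on the first call only, so even though `apply` invokes it on every forward pass, the JIT cost is paid once, ideally during the profile run.

```python
import functools


def run_once(fn):
    """Illustrative stand-in: execute the wrapped function at most once."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # Later calls become no-ops, so calling the warmup from apply() on
        # every forward pass only costs a flag check after the first time.
        if not wrapper._has_run:
            wrapper._has_run = True
            return fn(*args, **kwargs)
        return None

    wrapper._has_run = False
    return wrapper
```

When `VLLM_SKIP_DEEP_GEMM_WARMUP` is set to a truthy value, the `if not env.VLLM_SKIP_DEEP_GEMM_WARMUP:` guard skips the warmup call entirely.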