Commit a65f46b

Authored by varun-sundar-rabindranath (Varun Sundar Rabindranath), with a co-author.
[Misc] DeepGemmExperts : Avoid JIT generation in the hot-path (#21955)
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
1 parent 5739371 commit a65f46b

File tree

3 files changed: 92 additions, 1 deletion


vllm/envs.py

Lines changed: 9 additions & 0 deletions
@@ -126,6 +126,7 @@
     VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None
     VLLM_TPU_USING_PATHWAYS: bool = False
     VLLM_USE_DEEP_GEMM: bool = False
+    VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False
     VLLM_USE_FLASHINFER_MOE_FP8: bool = False
     VLLM_USE_FLASHINFER_MOE_FP4: bool = False
     VLLM_XGRAMMAR_CACHE_MB: int = 0
@@ -910,6 +911,14 @@ def get_vllm_port() -> Optional[int]:
     "VLLM_USE_DEEP_GEMM":
     lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))),
 
+    # DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm
+    # JIT all the required kernels before model execution so there is no
+    # JIT'ing in the hot-path. However, this warmup increases the engine
+    # startup time by a couple of minutes.
+    # Set `VLLM_SKIP_DEEP_GEMM_WARMUP` to disable the warmup.
+    "VLLM_SKIP_DEEP_GEMM_WARMUP":
+    lambda: bool(int(os.getenv("VLLM_SKIP_DEEP_GEMM_WARMUP", "0"))),
+
     # Allow use of FlashInfer MoE kernels for fused moe ops.
     "VLLM_USE_FLASHINFER_MOE_FP8":
     lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP8", "0"))),
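For reference, a minimal sketch of how the new flag is parsed and how an operator would opt out of the warmup. Only the parse expression comes from the hunk above; the surrounding `environment_variables` machinery in `vllm/envs.py` is elided, and the variable is assumed to be set before vLLM reads it.

```python
import os

# Opt out of the DeepGemm warmup, trading the JIT-free hot-path for a
# faster engine startup (assumed to be set before vLLM reads the flag).
os.environ["VLLM_SKIP_DEEP_GEMM_WARMUP"] = "1"

# The lambda registered above parses the flag the same way: "0"/unset -> False.
skip_warmup = bool(int(os.getenv("VLLM_SKIP_DEEP_GEMM_WARMUP", "0")))
assert skip_warmup is True
```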

vllm/model_executor/layers/fused_moe/deep_gemm_moe.py

Lines changed: 76 additions & 1 deletion
@@ -4,7 +4,9 @@
 from typing import Any, Optional
 
 import torch
+from tqdm import tqdm
 
+import vllm.envs as env
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
@@ -17,7 +19,7 @@
 from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8)
-from vllm.utils import has_deep_gemm
+from vllm.utils import has_deep_gemm, run_once
 from vllm.utils.deep_gemm import m_grouped_fp8_gemm_nt_contiguous
 
 logger = init_logger(__name__)
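The newly imported `run_once` is what keeps the warmup from repeating on every forward pass. Its implementation is not part of this diff; the sketch below only illustrates the behaviour the call site relies on (first call executes, later calls become no-ops) and is an assumption, not `vllm.utils.run_once` itself.

```python
import functools
from typing import Any, Callable

def run_once(f: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical run-once decorator: execute the wrapped function on the
    first call only; every later call returns immediately."""

    @functools.wraps(f)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        if wrapper._has_run:          # later calls are no-ops
            return None
        wrapper._has_run = True
        return f(*args, **kwargs)

    wrapper._has_run = False
    return wrapper
```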
@@ -82,6 +84,65 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor,
     return True
 
 
+@run_once
+def warmup_deepgemm_gg_contiguous_kernels(w1: torch.Tensor, w2: torch.Tensor,
+                                          w1_scale: torch.Tensor,
+                                          w2_scale: torch.Tensor,
+                                          num_topk: int):
+    """
+    DeepGemm JITs the grouped-gemm kernels. The JIT'ing happens based on the
+    input tensor shapes. In this function, we construct all possible input
+    tensor shapes so all the kernels are JIT'ed and cached.
+    Note that this warmup is expected to happen during the model profile
+    call and not during actual model inference.
+    """
+
+    assert w1.size(0) == w2.size(0), (
+        "w1 and w2 must have the same number of experts")
+
+    block_m = deep_gemm_block_shape()[0]
+    num_experts = w1.size(0)
+    device = w1.device
+
+    # This is the maximum GroupedGemm M size that we expect to run
+    # the grouped_gemm with.
+    MAX_M = compute_aligned_M(env.VLLM_FUSED_MOE_CHUNK_SIZE,
+                              num_topk,
+                              num_experts,
+                              block_m,
+                              expert_tokens_meta=None)
+    # Distribute expert-ids evenly.
+    MAX_BLOCKS = MAX_M // block_m
+    expert_ids_block = torch.randint(low=0,
+                                     high=num_experts,
+                                     size=(MAX_BLOCKS, ),
+                                     device=device,
+                                     dtype=torch.int32)
+    expert_ids = torch.repeat_interleave(expert_ids_block, block_m, dim=0)
+
+    def _warmup(w: torch.Tensor, w_scale: torch.Tensor):
+
+        _, n, k = w.size()
+        a1q = torch.empty((MAX_M, k), device=device).to(torch.float8_e4m3fn)
+        a1q_scales = torch.empty((MAX_M, k // block_m),
+                                 device=device,
+                                 dtype=torch.float32)
+        out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16)
+
+        pbar = tqdm(total=MAX_BLOCKS,
+                    desc=f"DeepGemmExperts GEMM warmup (MAX_M={MAX_M})")
+        num_tokens = MAX_M
+        while num_tokens > 0:
+            m_grouped_fp8_gemm_nt_contiguous(
+                (a1q[:num_tokens], a1q_scales[:num_tokens]), (w, w_scale),
+                out[:num_tokens], expert_ids[:num_tokens])
+            pbar.update(1)
+            num_tokens = num_tokens - block_m
+
+    _warmup(w1, w1_scale)
+    _warmup(w2, w2_scale)
+
+
 class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
 
     def __init__(self):
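The `_warmup` loop above steps `num_tokens` from `MAX_M` down by `block_m` until it reaches zero, so every grouped-gemm M size that can show up at runtime is issued once and JIT-compiled ahead of inference. A small illustration of the M values that sweep visits, using hypothetical numbers (the real `block_m` comes from `deep_gemm_block_shape()` and `MAX_M` from `compute_aligned_M(...)`):

```python
# Hypothetical values standing in for deep_gemm_block_shape()[0] and for
# compute_aligned_M(VLLM_FUSED_MOE_CHUNK_SIZE, num_topk, num_experts, block_m).
block_m = 128
MAX_M = 1024

# The same sequence of num_tokens values the while-loop in _warmup() visits,
# largest first; each call JITs (and caches) the kernel for that shape.
warmed_up_m = list(range(MAX_M, 0, -block_m))
print(warmed_up_m)  # [1024, 896, 768, 640, 512, 384, 256, 128]
```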
@@ -156,6 +217,20 @@ def apply(
     ):
         assert self.block_shape is not None
         assert a1q_scale is not None
+        assert w1_scale is not None
+        assert w2_scale is not None
+
+        if not env.VLLM_SKIP_DEEP_GEMM_WARMUP:
+            # DeepGemm JITs the grouped-gemm kernels. We don't want the JIT'ing
+            # to happen during actual model-inference. The
+            # `warmup_deepgemm_kernels` function is a `run_once` decorated
+            # function that executes during the model profile run. This warmup
+            # should create all the required JITs for the current model.
+            warmup_deepgemm_gg_contiguous_kernels(w1,
+                                                  w2,
+                                                  w1_scale,
+                                                  w2_scale,
+                                                  num_topk=topk_ids.size(1))
 
         a1q = hidden_states
         _, N, K = w1.size()
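The guard added to `apply()` ties the two pieces together: unless `VLLM_SKIP_DEEP_GEMM_WARMUP` is set, the first call into `apply()` (expected to be the profile run, per the comment) pays the full JIT cost and every later call hits already-cached kernels. A self-contained toy demo of that control flow, reusing the hypothetical `run_once` sketch from earlier rather than anything from vLLM itself:

```python
skip_warmup = False           # stand-in for env.VLLM_SKIP_DEEP_GEMM_WARMUP
calls = []

@run_once                     # the hypothetical decorator sketched earlier
def fake_warmup():
    calls.append("JIT all grouped-gemm kernels")

for _ in range(3):            # three "forward passes" reaching apply()
    if not skip_warmup:
        fake_warmup()

# Only the first pass actually ran the warmup body.
assert calls == ["JIT all grouped-gemm kernels"]
```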

vllm/utils/deep_gemm.py

Lines changed: 7 additions & 0 deletions
@@ -8,6 +8,7 @@
 
 import functools
 import importlib
+import os
 from typing import Any, Callable, NoReturn
 
 import torch
@@ -77,6 +78,12 @@ def _lazy_init() -> None:
     if not has_deep_gemm():
         return
 
+    # Set up deep_gemm cache path
+    DEEP_GEMM_JIT_CACHE_ENV_NAME = 'DG_JIT_CACHE_DIR'
+    if not os.environ.get(DEEP_GEMM_JIT_CACHE_ENV_NAME, None):
+        os.environ[DEEP_GEMM_JIT_CACHE_ENV_NAME] = os.path.join(
+            envs.VLLM_CACHE_ROOT, "deep_gemm")
+
     _dg = importlib.import_module("deep_gemm")
 
     _fp8_gemm_nt_impl = _resolve_symbol(_dg, "fp8_gemm_nt",
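The `_lazy_init()` hunk above pins DeepGemm's JIT cache under vLLM's cache root unless the user has already set `DG_JIT_CACHE_DIR`, presumably so kernels compiled once (e.g. by the warmup) can be reused by later runs. A standalone sketch of that default logic; the `~/.cache/vllm` path is only an assumption standing in for `envs.VLLM_CACHE_ROOT`:

```python
import os

DEEP_GEMM_JIT_CACHE_ENV_NAME = "DG_JIT_CACHE_DIR"
# Assumed stand-in for envs.VLLM_CACHE_ROOT; the real value is environment-dependent.
vllm_cache_root = os.path.expanduser("~/.cache/vllm")

# Respect an explicit user setting, otherwise default under the vLLM cache root.
if not os.environ.get(DEEP_GEMM_JIT_CACHE_ENV_NAME):
    os.environ[DEEP_GEMM_JIT_CACHE_ENV_NAME] = os.path.join(
        vllm_cache_root, "deep_gemm")

print(os.environ[DEEP_GEMM_JIT_CACHE_ENV_NAME])
```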
