
Commit f703b92

Varun Sundar Rabindranath (varun-sundar-rabindranath) authored
[Misc] DeepGEMM : Avoid JIT generation in the hot-path (#22215)
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
1 parent cd9b9de commit f703b92

File tree

5 files changed: +274 −37 lines


vllm/model_executor/layers/fused_moe/deep_gemm_moe.py

Lines changed: 0 additions & 12 deletions
@@ -237,18 +237,6 @@ def apply(
         assert w1_scale is not None
         assert w2_scale is not None

-        if not env.VLLM_SKIP_DEEP_GEMM_WARMUP:
-            # DeepGemm JITs the grouped-gemm kernels. We don't want the JIT'ing
-            # to happen during actual model-inference. The
-            # `warmup_deepgemm_kernels` function is a `run_once` decorated
-            # function that executes during the model profile run. This warmup
-            # should create all the required JITs for the current model.
-            warmup_deepgemm_gg_contiguous_kernels(w1,
-                                                  w2,
-                                                  w1_scale,
-                                                  w2_scale,
-                                                  num_topk=topk_ids.size(1))
-
         a1q = hidden_states
         _, N, K = w1.size()
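The deleted block was the old hot-path warmup: it JIT-compiled the grouped-gemm kernels from inside `apply()` on the first eligible call. The commit moves that work into a dedicated warmup module (added below) that keys on weight shapes so compilation happens once, at startup. As a minimal illustration of the shape-keyed, run-once pattern (the names here are illustrative, not the commit's):

import torch

# Weight shapes whose kernels have already been JIT-compiled.
_WARMED_UP_SHAPES: set[torch.Size] = set()


def warmup_once(w: torch.Tensor, run_kernel) -> None:
    """Run `run_kernel` once per weight shape, outside the inference hot path."""
    if w.size() in _WARMED_UP_SHAPES:
        return  # kernel for this shape is already compiled
    run_kernel(w)  # first call triggers the JIT compilation
    _WARMED_UP_SHAPES.add(w.size())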
vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 30 additions & 25 deletions
@@ -4,6 +4,9 @@
 import functools
 import json
 import os
+# torch.compile needs typing.List. It will fail torch.library.infer_schema
+# otherwise
+from typing import List  # noqa: UP035
 from typing import Any, Callable, Optional

 import torch
@@ -998,29 +1001,30 @@ def get_config_dtype_str(
     return None


-def inplace_fused_experts(hidden_states: torch.Tensor,
-                          w1: torch.Tensor,
-                          w2: torch.Tensor,
-                          topk_weights: torch.Tensor,
-                          topk_ids: torch.Tensor,
-                          activation: str = "silu",
-                          is_act_and_mul: bool = True,
-                          apply_router_weight_on_input: bool = False,
-                          use_fp8_w8a8: bool = False,
-                          use_int8_w8a8: bool = False,
-                          use_int8_w8a16: bool = False,
-                          use_int4_w4a16: bool = False,
-                          use_mxfp4_w4a4: bool = False,
-                          per_channel_quant: bool = False,
-                          global_num_experts: int = -1,
-                          expert_map: Optional[torch.Tensor] = None,
-                          w1_scale: Optional[torch.Tensor] = None,
-                          w2_scale: Optional[torch.Tensor] = None,
-                          w1_zp: Optional[torch.Tensor] = None,
-                          w2_zp: Optional[torch.Tensor] = None,
-                          a1_scale: Optional[torch.Tensor] = None,
-                          a2_scale: Optional[torch.Tensor] = None,
-                          block_shape: Optional[list[int]] = None) -> None:
+def inplace_fused_experts(
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str = "silu",
+        is_act_and_mul: bool = True,
+        apply_router_weight_on_input: bool = False,
+        use_fp8_w8a8: bool = False,
+        use_int8_w8a8: bool = False,
+        use_int8_w8a16: bool = False,
+        use_int4_w4a16: bool = False,
+        use_mxfp4_w4a4: bool = False,
+        per_channel_quant: bool = False,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        w1_scale: Optional[torch.Tensor] = None,
+        w2_scale: Optional[torch.Tensor] = None,
+        w1_zp: Optional[torch.Tensor] = None,
+        w2_zp: Optional[torch.Tensor] = None,
+        a1_scale: Optional[torch.Tensor] = None,
+        a2_scale: Optional[torch.Tensor] = None,
+        block_shape: Optional[List[int]] = None) -> None:  #noqa: UP006
     fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
                        activation, is_act_and_mul,
                        apply_router_weight_on_input, use_fp8_w8a8,
@@ -1082,7 +1086,7 @@ def flashinfer_fused_moe_blockscale_fp8(
         intermediate_size: int,
         expert_offset: int,
         local_num_experts: int,
-        block_shape: list[int],
+        block_shape: List[int],  #noqa: UP006
         routed_scaling: float = 1.0) -> torch.Tensor:
     from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe
     assert top_k <= global_num_experts
@@ -1264,7 +1268,8 @@ def outplace_fused_experts(
         w2_zp: Optional[torch.Tensor] = None,
         a1_scale: Optional[torch.Tensor] = None,
         a2_scale: Optional[torch.Tensor] = None,
-        block_shape: Optional[list[int]] = None) -> torch.Tensor:
+        block_shape: Optional[List[int]] = None,  #noqa: UP006
+) -> torch.Tensor:
     return fused_experts_impl(
         hidden_states, w1, w2, topk_weights, topk_ids, False, activation,
         is_act_and_mul, apply_router_weight_on_input, use_fp8_w8a8,
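The `typing.List` change works around a schema-inference limitation: when these functions are registered as custom ops for `torch.compile`, `torch.library.infer_schema` derives the op schema from the Python annotations, and on the torch versions targeted here it accepts `typing.List[int]` but rejects the builtin `list[int]` generic. A standalone sketch of the mechanism (not the commit's code; the op name is made up, and the printed schema is only an example of the expected shape):

from typing import List, Optional  # noqa: UP035

import torch


def _toy_moe_op(x: torch.Tensor,
                block_shape: Optional[List[int]] = None) -> torch.Tensor:
    # Placeholder body; only the annotated signature matters for schema inference.
    return x.clone()


# infer_schema turns the annotated Python signature into a torch op schema
# string, something like "(Tensor x, int[]? block_shape=None) -> Tensor".
print(torch.library.infer_schema(_toy_moe_op, mutates_args=()))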
vllm/model_executor/warmup/deep_gemm_warmup.py

Lines changed: 219 additions & 0 deletions

@@ -0,0 +1,219 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Warmup deep_gemm kernels.
+DeepGEMM JIT's the kernels. The warmup aims to JIT all the kernels that would
+be used during model execution beforehand.
+"""
+
+import torch
+from tqdm import tqdm
+
+import vllm.envs as envs
+from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
+from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
+    compute_aligned_M, deep_gemm_block_shape)
+from vllm.model_executor.layers.fused_moe.layer import FusedMoE
+from vllm.model_executor.layers.fused_moe.modular_kernel import (
+    FusedMoEModularKernel)
+from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
+    TritonOrDeepGemmExperts)
+from vllm.model_executor.layers.linear import LinearBase
+from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
+from vllm.utils.deep_gemm import fp8_gemm_nt, m_grouped_fp8_gemm_nt_contiguous
+
+
+def _extract_data_from_linear_base_module(
+        m: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor, list[int]]:
+    """
+    Extract weights, weight scales and quantization block sizes from the given
+    LinearBase module.
+    """
+    assert isinstance(m, LinearBase)
+    assert isinstance(m.quant_method, Fp8LinearMethod)
+    assert m.quant_method.block_quant
+    assert m.quant_method.quant_config is not None
+
+    w = m.weight
+    ws = m.weight_scale_inv
+    quant_block_size = m.quant_method.quant_config.weight_block_size
+
+    assert isinstance(w, torch.Tensor)
+    assert isinstance(ws, torch.Tensor)
+    assert quant_block_size is not None
+    return (w, ws, quant_block_size)
+
+
+def _extract_data_from_fused_moe_module(
+    m: torch.nn.Module
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]:
+    """
+    Extract weights, weight scales and num_topk from FusedMoE module.
+    """
+    assert isinstance(m, FusedMoE)
+    w13 = m.w13_weight
+    w13_s = m.w13_weight_scale_inv
+    w2 = m.w2_weight
+    w2_s = m.w2_weight_scale_inv
+    num_topk = m.top_k
+
+    assert isinstance(w13, torch.Tensor)
+    assert isinstance(w13_s, torch.Tensor)
+    assert isinstance(w2, torch.Tensor)
+    assert isinstance(w2_s, torch.Tensor)
+    return w13, w13_s, w2, w2_s, num_topk
+
+
+def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool:
+    """
+    Return True if the input module/layer could be processed with DeepGEMM.
+    """
+    block_size = deep_gemm_block_shape()[0]
+    if not (isinstance(module, LinearBase)
+            and isinstance(module.quant_method, Fp8LinearMethod)
+            and module.quant_method.block_quant):
+        return False
+
+    w, _, block_sizes = _extract_data_from_linear_base_module(module)
+    return (block_sizes == deep_gemm_block_shape() and w.ndim == 2
+            and w.shape[0] % block_size == 0 and w.shape[1] % block_size == 0)
+
+
+def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
+    if not (isinstance(module, FusedMoE)
+            and module.moe_config.quant_dtype == torch.float8_e4m3fn
+            and module.moe_config.block_shape == deep_gemm_block_shape()):
+        return False
+
+    if not isinstance(module.quant_method.fused_experts,
+                      FusedMoEModularKernel):
+        # fused_experts could invoke deep_gemm_moe_fp8
+        return True
+
+    mk: FusedMoEModularKernel = module.quant_method.fused_experts
+    # Further check if the ModularKernel implementation uses the DeepGemmExperts
+    return isinstance(mk.fused_experts,
+                      (DeepGemmExperts, TritonOrDeepGemmExperts))
+
+
+FP8_GEMM_NT_WARMUP_CACHE: set[torch.Size] = set()
+
+
+def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor,
+                                 max_tokens: int):
+    if w.size() in FP8_GEMM_NT_WARMUP_CACHE:
+        return
+
+    n, k = w.size()
+    block_m = deep_gemm_block_shape()[0]
+
+    device = w.device
+    a1q = torch.empty((max_tokens, k),
+                      device=device,
+                      dtype=torch.float8_e4m3fn)
+    a1q_scales = torch.empty((max_tokens, k // block_m),
+                             device=device,
+                             dtype=torch.float32)
+    out = torch.empty((max_tokens, n), device=device, dtype=torch.bfloat16)
+
+    pbar = tqdm(total=max_tokens,
+                desc=f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()})")
+    num_tokens = max_tokens
+    while num_tokens > 0:
+        fp8_gemm_nt((a1q[:num_tokens], a1q_scales[:num_tokens]), (w, ws),
+                    out[:num_tokens])
+        pbar.update(1)
+        num_tokens -= 1
+
+    FP8_GEMM_NT_WARMUP_CACHE.add(w.size())
+
+
+GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE: set[torch.Size] = set()
+
+
+def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(w1: torch.Tensor,
+                                                    w2: torch.Tensor,
+                                                    w1_scale: torch.Tensor,
+                                                    w2_scale: torch.Tensor,
+                                                    num_topk: int):
+    if (w1.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
+            and w2.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE):
+        return
+
+    assert w1.size(0) == w2.size(0), (
+        "w1 and w2 must have the same number of experts")
+
+    block_m = deep_gemm_block_shape()[0]
+    num_experts = w1.size(0)
+    device = w1.device
+
+    # This is the maximum GroupedGemm M size that we expect to run
+    # the grouped_gemm with.
+    MAX_M = compute_aligned_M(envs.VLLM_FUSED_MOE_CHUNK_SIZE,
+                              num_topk,
+                              num_experts,
+                              block_m,
+                              expert_tokens_meta=None)
+    # Distribute expert-ids evenly.
+    MAX_BLOCKS = MAX_M // block_m
+    expert_ids_block = torch.randint(low=0,
+                                     high=num_experts,
+                                     size=(MAX_BLOCKS, ),
+                                     device=device,
+                                     dtype=torch.int32)
+    expert_ids = torch.repeat_interleave(expert_ids_block, block_m, dim=0)
+
+    def _warmup(w: torch.Tensor, w_scale: torch.Tensor):
+
+        _, n, k = w.size()
+        a1q = torch.empty((MAX_M, k), device=device, dtype=torch.float8_e4m3fn)
+        a1q_scales = torch.empty((MAX_M, k // block_m),
+                                 device=device,
+                                 dtype=torch.float32)
+        out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16)
+
+        pbar = tqdm(
+            total=MAX_BLOCKS,
+            desc=
+            f"DeepGemm(m_grouped_fp8_gemm_nt_contiguous) warmup (W={w.size()})"
+        )
+        num_tokens = MAX_M
+        while num_tokens > 0:
+            m_grouped_fp8_gemm_nt_contiguous(
+                (a1q[:num_tokens], a1q_scales[:num_tokens]), (w, w_scale),
+                out[:num_tokens], expert_ids[:num_tokens])
+            pbar.update(1)
+            num_tokens = num_tokens - block_m
+
+    for w, ws in [(w1, w1_scale), (w2, w2_scale)]:
+        if w.size() not in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE:
+            _warmup(w, ws)
+            GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE.add(w.size())
+
+
+def deepgemm_fp8_gemm_nt_warmup(model: torch.nn.Module, max_tokens: int):
+    dg_modules = [
+        m for m in model.modules() if _fp8_linear_may_use_deep_gemm(m)
+    ]
+
+    for dgm in dg_modules:
+        w, ws, _ = _extract_data_from_linear_base_module(dgm)
+        _deepgemm_fp8_gemm_nt_warmup(w=w, ws=ws, max_tokens=max_tokens)
+
+
+def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model: torch.nn.Module):
+    dg_modules = [
+        m for m in model.modules()
+        if _fused_moe_grouped_gemm_may_use_deep_gemm(m)
+    ]
+
+    for dgm in dg_modules:
+        w13, w13_scale, w2, w2_scale, num_topk = (
+            _extract_data_from_fused_moe_module(dgm))
+        _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
+            w13, w2, w13_scale, w2_scale, num_topk)
+
+
+def deep_gemm_warmup(model: torch.nn.Module, max_tokens: int):
+    deepgemm_fp8_gemm_nt_warmup(model, max_tokens)
+    deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model)
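For reference, a brief usage sketch (the wrapper function and its argument names are illustrative; only the imported `deep_gemm_warmup` comes from the file above): the warmup walks the module tree itself, so the caller just passes the loaded model and an upper bound on batched tokens.

import torch

from vllm.model_executor.warmup.deep_gemm_warmup import deep_gemm_warmup


def warm_up_deep_gemm_for_model(model: torch.nn.Module,
                                max_num_batched_tokens: int) -> None:
    # deep_gemm_warmup scans model.modules() for eligible fp8 block-quantized
    # Linear and FusedMoE layers and JIT-compiles the matching DeepGEMM
    # kernels once per weight shape, so repeat calls are cheap no-ops.
    deep_gemm_warmup(model, max_tokens=max_num_batched_tokens)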
vllm/model_executor/warmup/kernel_warmup.py

Lines changed: 20 additions & 0 deletions

@@ -0,0 +1,20 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Warmup kernels used during model execution.
+This is useful specifically for JIT'ed kernels as we don't want JIT'ing to
+happen during model execution.
+"""
+import torch
+
+import vllm.envs as envs
+from vllm.model_executor.warmup.deep_gemm_warmup import deep_gemm_warmup
+from vllm.utils.deep_gemm import is_deep_gemm_supported
+
+
+def kernel_warmup(model: torch.nn.Module, max_tokens: int):
+    do_deep_gemm_warmup = (envs.VLLM_USE_DEEP_GEMM
+                           and is_deep_gemm_supported()
+                           and not envs.VLLM_SKIP_DEEP_GEMM_WARMUP)
+    if do_deep_gemm_warmup:
+        deep_gemm_warmup(model, max_tokens)
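Operationally, the warmup is gated on the environment flags referenced above. A hedged opt-out example for short-lived runs (assuming the usual "1"-style boolean parsing of vLLM env vars; set it before the engine reads vllm.envs):

import os

# Skip the (potentially slow) DeepGEMM JIT warmup, e.g. for quick debug runs.
os.environ["VLLM_SKIP_DEEP_GEMM_WARMUP"] = "1"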

vllm/v1/worker/gpu_worker.py

Lines changed: 5 additions & 0 deletions
@@ -21,6 +21,7 @@
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
+from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.tasks import SupportedTask
@@ -338,6 +339,10 @@ def compile_or_warm_up_model(self) -> None:
             self.model_runner._dummy_sampler_run(
                 hidden_states=last_hidden_states)

+        # Warmup kernels used during model execution
+        kernel_warmup(self.get_model(),
+                      max_tokens=self.scheduler_config.max_num_batched_tokens)
+
         # Reset the seed to ensure that the random state is not affected by
         # the model initialization and profiling.
         set_random_seed(self.model_config.seed)
