Commit 2e6bc46
[Startup] Make DeepGEMM warmup scale with max-num-batched-tokens (#24693)
Signed-off-by: Lucas Wilkinson <[email protected]>
1 parent fcba05c commit 2e6bc46

File tree

1 file changed: +13 −9 lines changed


vllm/model_executor/warmup/deep_gemm_warmup.py

Lines changed: 13 additions & 9 deletions
@@ -10,6 +10,7 @@
 from tqdm import tqdm

 import vllm.envs as envs
+from vllm.distributed.parallel_state import get_dp_group
 from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
 from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
     compute_aligned_M, deep_gemm_block_shape)
@@ -131,11 +132,9 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor,
 GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE: set[torch.Size] = set()


-def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(w1: torch.Tensor,
-                                                    w2: torch.Tensor,
-                                                    w1_scale: torch.Tensor,
-                                                    w2_scale: torch.Tensor,
-                                                    num_topk: int):
+def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
+        w1: torch.Tensor, w2: torch.Tensor, w1_scale: torch.Tensor,
+        w2_scale: torch.Tensor, num_topk: int, max_tokens: int):
     if (w1.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
             and w2.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE):
         return
@@ -147,9 +146,13 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(w1: torch.Tensor,
     num_experts = w1.size(0)
     device = w1.device

+    # Assumes all ranks have the same max_num_batched_tokens
+    max_tokens_across_dp = get_dp_group().world_size * max_tokens
+    max_tokens = min(max_tokens_across_dp, envs.VLLM_FUSED_MOE_CHUNK_SIZE)
+
     # This is the maximum GroupedGemm M size that we expect to run
     # the grouped_gemm with.
-    MAX_M = compute_aligned_M(envs.VLLM_FUSED_MOE_CHUNK_SIZE,
+    MAX_M = compute_aligned_M(max_tokens,
                               num_topk,
                               num_experts,
                               block_m,
@@ -201,7 +204,8 @@ def deepgemm_fp8_gemm_nt_warmup(model: torch.nn.Module, max_tokens: int):
         _deepgemm_fp8_gemm_nt_warmup(w=w, ws=ws, max_tokens=max_tokens)


-def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model: torch.nn.Module):
+def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model: torch.nn.Module,
+                                                   max_tokens: int):
     dg_modules = [
         m for m in model.modules()
         if _fused_moe_grouped_gemm_may_use_deep_gemm(m)
@@ -211,9 +215,9 @@ def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model: torch.nn.Module):
         w13, w13_scale, w2, w2_scale, num_topk = (
             _extract_data_from_fused_moe_module(dgm))
         _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
-            w13, w2, w13_scale, w2_scale, num_topk)
+            w13, w2, w13_scale, w2_scale, num_topk, max_tokens)


 def deep_gemm_warmup(model: torch.nn.Module, max_tokens: int):
     deepgemm_fp8_gemm_nt_warmup(model, max_tokens)
-    deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model)
+    deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model, max_tokens)
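
In effect, the grouped-GEMM warmup now sizes its largest expected M from the configured max-num-batched-tokens (summed across data-parallel ranks) rather than always from VLLM_FUSED_MOE_CHUNK_SIZE, with the chunk size acting only as an upper cap. Below is a minimal, illustrative Python sketch of that bound using made-up values for the per-rank token budget, the DP world size, and the chunk size; the real code takes these from the engine config, get_dp_group(), and vllm.envs, and the helper name here is hypothetical.

# Illustrative sketch only: mirrors the new bound computed in
# _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup, with hypothetical numbers.
def warmup_token_bound(max_num_batched_tokens: int, dp_world_size: int,
                       fused_moe_chunk_size: int) -> int:
    # Assumes every DP rank runs with the same max_num_batched_tokens,
    # as the comment in the diff notes.
    max_tokens_across_dp = dp_world_size * max_num_batched_tokens
    return min(max_tokens_across_dp, fused_moe_chunk_size)

# Hypothetical values: 4 DP ranks x 8192 tokens, chunk size 16384.
print(warmup_token_bound(8192, dp_world_size=4, fused_moe_chunk_size=16384))  # 16384 (capped)
print(warmup_token_bound(2048, dp_world_size=1, fused_moe_chunk_size=16384))  # 2048 (scales down)

For small per-rank token budgets this should shrink the warmup problem size (and hence startup time), while the chunk-size cap keeps the warmup from exceeding what fused-MoE chunking would ever feed the grouped GEMM.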

0 commit comments
