
Commit 1840c5c

[BugFix] Make sure to allocate worst case MoE workspace during profile run in the DP + EP case (#27426)
Signed-off-by: Lucas Wilkinson <[email protected]>
Parent: 1bed891
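
Why the profile run can under-allocate, as a hedged illustration (toy sizes, not vLLM internals): with data parallelism plus expert parallelism, tokens are routed to experts at runtime, so the number of tokens a given rank's experts see in the one-off memory-profiling batch need not match the worst case a later batch can produce. The snippet below only shows that variability; the tensor names and sizes are made up.

import torch

# Toy illustration: under top-k routing the per-expert token count varies from
# batch to batch, so a single profiling batch may not exercise the largest
# workspace a rank will ever need.
num_tokens, num_experts, top_k = 8, 4, 2
router_logits = torch.randn(num_tokens, num_experts)
topk_ids = router_logits.topk(top_k, dim=-1).indices
tokens_for_expert_0 = int((topk_ids == 0).sum())  # anywhere between 0 and 8

# The fix instead sizes the profile-run workspaces for the chunk-size upper
# bound (VLLM_FUSED_MOE_CHUNK_SIZE, now 16 * 1024 tokens by default).
worst_case_tokens = 16 * 1024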

2 files changed, 43 insertions(+), 2 deletions(-)

vllm/envs.py

Lines changed: 2 additions & 2 deletions
@@ -55,7 +55,7 @@
     VLLM_CPU_SGL_KERNEL: bool = False
     VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache")
     VLLM_XLA_CHECK_RECOMPILATION: bool = False
-    VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024
+    VLLM_FUSED_MOE_CHUNK_SIZE: int = 16 * 1024
     VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING: bool = True
     VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: Literal["auto", "nccl", "shm"] = "auto"
     VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False
@@ -785,7 +785,7 @@ def get_vllm_port() -> int | None:
     # Enable SPMD mode for TPU backend.
     "VLLM_XLA_USE_SPMD": lambda: bool(int(os.getenv("VLLM_XLA_USE_SPMD", "0"))),
     "VLLM_FUSED_MOE_CHUNK_SIZE": lambda: int(
-        os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768")
+        os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(16 * 1024))
     ),
     # Control whether to use fused MoE activation chunking. Current chunking
     # logic is incompatible with torch.compile and causes IMA. See issue
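
For reference, a minimal sketch of the lookup this hunk changes, using only the standard library: with the variable unset, the chunk size now resolves to 16 * 1024 tokens, the same value as the declared default in the first hunk.

import os

# Sketch of the new default resolution in vllm/envs.py: 16 * 1024 (16384)
# tokens per fused-MoE chunk unless VLLM_FUSED_MOE_CHUNK_SIZE overrides it.
chunk_size = int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(16 * 1024)))
print(chunk_size)  # 16384 when the variable is not set

Anyone relying on the previous 32768-token getter default can export VLLM_FUSED_MOE_CHUNK_SIZE=32768 before launching vLLM to keep the old behaviour.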

vllm/model_executor/layers/fused_moe/modular_kernel.py

Lines changed: 41 additions & 0 deletions
@@ -10,6 +10,9 @@
 import torch
 
 import vllm.envs as envs
+from vllm.config import get_current_vllm_config
+from vllm.forward_context import get_forward_context, is_forward_context_available
+from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.utils import (
     _resize_cache,
@@ -26,6 +29,8 @@
     dbo_yield,
 )
 
+logger = init_logger(__name__)
+
 #
 # This file defines a set of base classes used to make MoE kernels more modular.
 # The goal is to be able to utilize different communication mechanisms with
@@ -798,6 +803,42 @@ def _allocate_buffers(
         buffers = self.shared_buffers[ubatch_idx]
         workspace_dtype = self.fused_experts.workspace_dtype(out_dtype)
 
+        # Force worst-case allocation in profiling run for
+        # "mk.FusedMoEModularKernel.Standard" formats where this is only bounded
+        # by `VLLM_FUSED_MOE_CHUNK_SIZE` and may not be seen during profiling with
+        # DP+EP due to the random token routing.
+        is_profile_run = (
+            is_forward_context_available()
+            and get_forward_context().attn_metadata is None
+        )
+        if is_profile_run and self.fused_experts.supports_chunking():
+            parallel_config = get_current_vllm_config().parallel_config
+            is_dp_ep = (
+                parallel_config.data_parallel_size > 1
+                and parallel_config.enable_expert_parallel
+            )
+            if is_dp_ep:
+                max_workspace_13, max_workspace_2, max_fused_out_shape = (
+                    self.fused_experts.workspace_shapes(
+                        envs.VLLM_FUSED_MOE_CHUNK_SIZE,
+                        N,
+                        K,
+                        top_k,
+                        global_num_experts,
+                        local_num_experts,
+                        expert_tokens_meta,
+                    )
+                )
+                buffers.workspace13.get(
+                    max_workspace_13, device=device, dtype=workspace_dtype
+                )
+                buffers.workspace2.get(
+                    max_workspace_2, device=device, dtype=workspace_dtype
+                )
+                buffers.fused_out.get(
+                    max_fused_out_shape, device=device, dtype=workspace_dtype
+                )
+
         # Get intermediate workspace shapes based off the chunked M size.
         workspace13_shape, workspace2_shape, _ = self.fused_experts.workspace_shapes(
             M_chunk,
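
The added block, condensed into a standalone sketch. Everything here mirrors the hunk above; the free-standing helper name and signature are illustrative only, and workspace_shapes / .get are assumed to behave exactly as they are used in the diff.

# Condensed sketch of the new profile-run path (illustrative helper, not the
# actual vLLM method).  An absent attn_metadata on the forward context marks
# the memory-profiling run; with DP + EP the workspaces are then sized for a
# full VLLM_FUSED_MOE_CHUNK_SIZE chunk instead of the tokens that happened to
# be routed to this rank in the profile batch.
import vllm.envs as envs
from vllm.config import get_current_vllm_config
from vllm.forward_context import get_forward_context, is_forward_context_available


def reserve_worst_case_moe_workspace(
    fused_experts, buffers, N, K, top_k,
    global_num_experts, local_num_experts, expert_tokens_meta,
    device, workspace_dtype,
):
    is_profile_run = (
        is_forward_context_available()
        and get_forward_context().attn_metadata is None
    )
    if not (is_profile_run and fused_experts.supports_chunking()):
        return
    parallel_config = get_current_vllm_config().parallel_config
    is_dp_ep = (
        parallel_config.data_parallel_size > 1
        and parallel_config.enable_expert_parallel
    )
    if is_dp_ep:
        # Worst-case shapes for the largest chunk the kernel can ever see.
        ws13, ws2, fused_out = fused_experts.workspace_shapes(
            envs.VLLM_FUSED_MOE_CHUNK_SIZE, N, K, top_k,
            global_num_experts, local_num_experts, expert_tokens_meta,
        )
        # Touching the shared buffers at these shapes makes the profiling run
        # account for their peak size.
        buffers.workspace13.get(ws13, device=device, dtype=workspace_dtype)
        buffers.workspace2.get(ws2, device=device, dtype=workspace_dtype)
        buffers.fused_out.get(fused_out, device=device, dtype=workspace_dtype)

In the actual change this logic lives inline in the kernel's buffer-allocation path, just before the per-chunk workspace_shapes call, so the regular chunked allocation below it is unchanged.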
