Commit 458ab2d

[BugFix] Fix the bug that qwen3 moe doesn't work with aclgraph (#2183)
What this PR does:
1. Move AscendSparseMoeBlock into the qwen3 model file, since it is only used by the qwen3 model.
2. Disable AscendSparseMoeBlock when aclgraph is enabled; AscendSparseMoeBlock does not currently work with aclgraph.

- vLLM version: v0.10.0
- vLLM main: vllm-project/vllm@cdfd687

Signed-off-by: wangxiyuan <[email protected]>
1 parent 583ad8f commit 458ab2d
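For context, a minimal sketch of how this fallback is exercised from the user side: with aclgraph capture active (piecewise compilation, no enforce_eager), the sparse decoder layers fall back to the upstream Qwen3MoeSparseMoeBlock, while eager mode keeps the Ascend-specific MoE path. The snippet below is illustrative only and not part of the diff; the flag values are assumptions.

```python
# Hedged sketch, not part of this commit: running Qwen3-30B-A3B with eager
# mode forced, which skips aclgraph capture and therefore keeps the custom
# Ascend sparse-MoE block selected by this PR. Flag values are illustrative.
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen3-30B-A3B",
          tensor_parallel_size=2,
          enforce_eager=True)  # eager mode -> aclgraph disabled
print(llm.generate(["Hello, my name is"], SamplingParams(max_tokens=5)))
```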

File tree

3 files changed: +151 additions, -86 deletions

- tests/e2e/multicard/test_qwen3_moe.py
- vllm_ascend/models/qwen3_moe.py
- vllm_ascend/ops/fused_moe.py

tests/e2e/multicard/test_qwen3_moe.py

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2023 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
+#
+"""Compare the short outputs of HF and vLLM when using greedy sampling.
+
+Run `pytest tests/test_offline_inference.py`.
+"""
+
+from tests.e2e.conftest import VllmRunner
+
+
+def test_models_distributed_Qwen3_MOE_TP2():
+    example_prompts = [
+        "Hello, my name is",
+    ]
+    dtype = "half"
+    max_tokens = 5
+    with VllmRunner(
+            "Qwen/Qwen3-30B-A3B",
+            dtype=dtype,
+            tensor_parallel_size=2,
+            distributed_executor_backend="mp",
+    ) as vllm_model:
+        vllm_model.generate_greedy(example_prompts, max_tokens)
+
+
+def test_models_distributed_Qwen3_MOE_TP2_WITH_EP():
+    example_prompts = [
+        "Hello, my name is",
+    ]
+    dtype = "half"
+    max_tokens = 5
+    with VllmRunner(
+            "Qwen/Qwen3-30B-A3B",
+            dtype=dtype,
+            tensor_parallel_size=2,
+            enable_expert_parallel=True,
+            distributed_executor_backend="mp",
+    ) as vllm_model:
+        vllm_model.generate_greedy(example_prompts, max_tokens)
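As a usage note, the new multicard tests can be driven locally with pytest. The sketch below is one hypothetical way to invoke the module from Python, assuming two Ascend NPUs are available and vllm-ascend is installed; it is not part of the commit.

```python
# Hypothetical local runner for the new e2e test module (assumes a working
# vllm-ascend install and at least 2 NPUs visible); equivalent to invoking
# pytest on the file from the command line.
import pytest

if __name__ == "__main__":
    raise SystemExit(
        pytest.main(["-s", "tests/e2e/multicard/test_qwen3_moe.py"]))
```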

vllm_ascend/models/qwen3_moe.py

Lines changed: 96 additions & 7 deletions
@@ -17,25 +17,102 @@
 # This file is a part of the vllm-ascend project.
 from typing import Optional
 
+import torch
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig
+from vllm.config import CacheConfig, CompilationLevel, VllmConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
+                                             get_tp_group)
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention,
                                                   Qwen3MoeDecoderLayer,
                                                   Qwen3MoeForCausalLM,
-                                                  Qwen3MoeMLP, Qwen3MoeModel)
+                                                  Qwen3MoeMLP, Qwen3MoeModel,
+                                                  Qwen3MoeSparseMoeBlock)
 from vllm.model_executor.models.utils import (
     extract_layer_index, make_empty_intermediate_tensors_factory, make_layers,
     maybe_prefix)
 
-from vllm_ascend.ops.fused_moe import AscendSparseMoeBlock
-from vllm_ascend.platform import VllmConfig
+from vllm_ascend.ops.fused_moe import AscendFusedMoE
+
+
+class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        nn.Module.__init__(self)
+        self.tp_size = get_tensor_model_parallel_world_size()
+        if self.tp_size > config.num_experts:
+            raise ValueError(
+                f"Tensor parallel size {self.tp_size} is greater than "
+                f"the number of experts {config.num_experts}.")
+
+        self.gate = ReplicatedLinear(
+            config.hidden_size,
+            config.num_experts,
+            bias=False,
+            quant_config=None,
+            prefix=f"{prefix}.gate",
+        )
+
+        self.experts = AscendFusedMoE(
+            num_experts=config.num_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            reduce_results=False,
+            renormalize=config.norm_topk_prob,
+            quant_config=quant_config,
+            prefix=f"{prefix}.experts",
+        )
+
+        self.top_k = config.num_experts_per_tok
+
+        self.dp_size = get_dp_group().world_size
+
+        self.tp_group = get_tp_group().device_group
+        self.tp_rank = get_tp_group().rank_in_group
+        self.ep_group = get_ep_group()
+
+        self.params_dtype = torch.get_default_dtype()
+
+    def forward(
+        self,
+        hidden_states,
+        attn_metadata=None,
+    ):
+        if attn_metadata is None:
+            attn_metadata = get_forward_context().attn_metadata
+        # when profile runs, force experts to load balanced tokens
+        # to avoid high memory consumption on a single rank.
+        enable_force_load_balance = get_forward_context().in_profile_run
+        is_prefill = get_forward_context().with_prefill
+
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.gate(hidden_states)
+
+        hidden_states = self.experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            is_prefill=is_prefill,
+            top_k=self.top_k,
+            enable_force_load_balance=enable_force_load_balance,
+            shared_experts=None,
+        )
+
+        return hidden_states
 
 
 class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
@@ -45,6 +122,7 @@ def __init__(
         config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        vllm_config: Optional[VllmConfig] = None,
         prefix: str = "",
     ) -> None:
 
@@ -73,12 +151,22 @@ def __init__(
         layer_idx = extract_layer_index(prefix)
         mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
                            config.mlp_only_layers)
+        use_aclgraph = (vllm_config is not None
+                        and vllm_config.compilation_config.level
+                        == CompilationLevel.PIECEWISE
+                        and not vllm_config.model_config.enforce_eager)
         if (layer_idx not in mlp_only_layers) and (
                 config.num_experts > 0 and
                 (layer_idx + 1) % config.decoder_sparse_step == 0):
-            self.mlp = AscendSparseMoeBlock(config=config,
-                                            quant_config=quant_config,
-                                            prefix=f"{prefix}.mlp")
+            if not use_aclgraph:
+                # FIXME: custom sparse moe block doesn't work with aclgraph.
+                self.mlp = CustomSparseMoeBlock(config=config,
+                                                quant_config=quant_config,
+                                                prefix=f"{prefix}.mlp")
+            else:
+                self.mlp = Qwen3MoeSparseMoeBlock(config=config,
+                                                  quant_config=quant_config,
+                                                  prefix=f"{prefix}.mlp")
         else:
             self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
                                    intermediate_size=config.intermediate_size,
@@ -115,6 +203,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                 config=config,
                 cache_config=cache_config,
                 quant_config=quant_config,
+                vllm_config=vllm_config,
                 prefix=prefix),
             prefix=f"{prefix}.layers",
         )
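To make the new branch easier to follow, here is a dependency-free sketch of the selection logic the decoder layer now applies. The stub dataclass and the PIECEWISE constant are stand-ins for vLLM's VllmConfig and CompilationLevel, not the real types.

```python
# Minimal sketch of the MoE-block selection added above. StubConfig and
# PIECEWISE are placeholders for vLLM's VllmConfig / CompilationLevel.
from dataclasses import dataclass
from typing import Optional

PIECEWISE = 3  # stand-in for vllm.config.CompilationLevel.PIECEWISE


@dataclass
class StubConfig:
    compilation_level: int
    enforce_eager: bool


def select_moe_block(cfg: Optional[StubConfig]) -> str:
    """Return the MoE block a sparse decoder layer would instantiate."""
    use_aclgraph = (cfg is not None
                    and cfg.compilation_level == PIECEWISE
                    and not cfg.enforce_eager)
    # aclgraph capture falls back to the upstream block; otherwise the
    # Ascend-optimized CustomSparseMoeBlock is used.
    return "Qwen3MoeSparseMoeBlock" if use_aclgraph else "CustomSparseMoeBlock"


assert select_moe_block(StubConfig(PIECEWISE, False)) == "Qwen3MoeSparseMoeBlock"
assert select_moe_block(StubConfig(PIECEWISE, True)) == "CustomSparseMoeBlock"
assert select_moe_block(None) == "CustomSparseMoeBlock"
```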

vllm_ascend/ops/fused_moe.py

Lines changed: 0 additions & 79 deletions
@@ -22,8 +22,6 @@
 import torch.distributed as dist
 import torch_npu
 from torch import nn
-from transformers import PretrainedConfig
-from vllm.attention import AttentionMetadata
 from vllm.config import get_current_vllm_config
 from vllm.distributed import (GroupCoordinator, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -37,7 +35,6 @@
     FusedMoEParallelConfig  # isort: skip
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
-from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.quantization.base_config import \
     QuantizationConfig
 
@@ -1546,79 +1543,3 @@ def _forward_ms_fused_moe_comp(
     )
 
     return hidden_states
-
-
-class AscendSparseMoeBlock(nn.Module):
-
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ):
-        super().__init__()
-        self.tp_size = get_tensor_model_parallel_world_size()
-        if self.tp_size > config.num_experts:
-            raise ValueError(
-                f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {config.num_experts}.")
-
-        ascend_config = get_ascend_config()
-        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-        self.enable_multistream_moe = (
-            ascend_config.torchair_graph_config.enable_multistream_moe)
-
-        self.gate = ReplicatedLinear(
-            config.hidden_size,
-            config.num_experts,
-            bias=False,
-            quant_config=None,
-            prefix=f"{prefix}.gate",
-        )
-
-        self.experts = AscendFusedMoE(
-            num_experts=config.num_experts,
-            top_k=config.num_experts_per_tok,
-            hidden_size=config.hidden_size,
-            intermediate_size=config.moe_intermediate_size,
-            reduce_results=False,
-            renormalize=config.norm_topk_prob,
-            quant_config=quant_config,
-            prefix=f"{prefix}.experts",
-        )
-
-        self.top_k = config.num_experts_per_tok
-
-        self.dp_size = get_dp_group().world_size
-
-        self.tp_group = get_tp_group().device_group
-        self.tp_rank = get_tp_group().rank_in_group
-        self.ep_group = get_ep_group()
-
-        self.params_dtype = torch.get_default_dtype()
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attn_metadata: Optional[AttentionMetadata] = None,
-    ) -> torch.Tensor:
-        if attn_metadata is None:
-            attn_metadata = get_forward_context().attn_metadata
-        # when profile runs, force experts to load balanced tokens
-        # to avoid high memory consumption on a single rank.
-        enable_force_load_balance = get_forward_context().in_profile_run
-        is_prefill = get_forward_context().with_prefill
-
-        # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
-
-        hidden_states = self.experts(
-            hidden_states=hidden_states,
-            router_logits=router_logits,
-            is_prefill=is_prefill,
-            top_k=self.top_k,
-            enable_force_load_balance=enable_force_load_balance,
-            shared_experts=None,
-        )
-
-        return hidden_states
