
Commit a3970e8

[0.9.1][Bugfix][Aclgraph] Fix qwen3-moe + aclgraph + tp (#2647)
### What this PR does / why we need it?
Qwen3 MoE + aclgraph only supports the pure TP scenario, because in the v0.9.1-dev branch aclgraph only supports allgather. This PR switches to `AscendSparseMoeBlock` so that aclgraph works with TP.

### Does this PR introduce _any_ user-facing change?
When aclgraph is enabled, users can run Qwen3 MoE only with pure TP.

### How was this patch tested?
CI passed with the newly added test.

---------

Signed-off-by: MengqingCao <[email protected]>
1 parent 234a5a4 commit a3970e8
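For reference, a minimal usage sketch of the user-facing behavior described above, assuming vLLM's standard offline `LLM` entry point; the model name, dtype, and parallel size mirror the test added in this commit and are illustrative, not part of the change itself:

```python
# Minimal sketch (not part of this commit): run Qwen3 MoE with pure TP and
# graph capture (aclgraph) enabled, i.e. without forcing eager mode.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen3-30B-A3B",   # illustrative; any Qwen3 MoE checkpoint
    tensor_parallel_size=4,       # pure TP is the only supported MoE layout here
    enforce_eager=False,          # keep aclgraph / graph capture enabled
    dtype="bfloat16",
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=5))
print(outputs[0].outputs[0].text)
```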

File tree

tests/multicard/test_qwen3_moe.py
vllm_ascend/models/qwen3_moe.py

2 files changed: +8, -17 lines


tests/multicard/test_qwen3_moe.py

Lines changed: 3 additions & 2 deletions

@@ -24,17 +24,18 @@
 from tests.conftest import VllmRunner


-def test_models_distributed_Qwen3_MOE_TP2():
+def test_models_distributed_Qwen3_MOE_Aclgraph_TP2():
     example_prompts = [
         "Hello, my name is",
     ]
-    dtype = "half"
+    dtype = "bfloat16"
     max_tokens = 5
     with VllmRunner(
             "Qwen/Qwen3-30B-A3B",
             dtype=dtype,
             tensor_parallel_size=4,
             distributed_executor_backend="mp",
+            enforce_eager=False,
     ) as vllm_model:
         vllm_model.generate_greedy(example_prompts, max_tokens)

vllm_ascend/models/qwen3_moe.py

Lines changed: 5 additions & 15 deletions

@@ -22,7 +22,7 @@
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, CompilationLevel, VllmConfig
+from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -32,8 +32,7 @@
 from vllm.model_executor.models.interfaces import SupportsPP
 from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention,
                                                   Qwen3MoeForCausalLM,
-                                                  Qwen3MoeMLP, Qwen3MoeModel,
-                                                  Qwen3MoeSparseMoeBlock)
+                                                  Qwen3MoeMLP, Qwen3MoeModel)
 from vllm.model_executor.models.utils import (
     extract_layer_index, make_empty_intermediate_tensors_factory, make_layers,
     maybe_prefix)
@@ -79,21 +78,12 @@ def __init__(
         layer_idx = extract_layer_index(prefix)
         mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
                            config.mlp_only_layers)
-        use_aclgraph = (vllm_config is not None
-                        and vllm_config.compilation_config.level
-                        == CompilationLevel.PIECEWISE
-                        and not vllm_config.model_config.enforce_eager)
         if (layer_idx not in mlp_only_layers) and (
                 config.num_experts > 0 and
                 (layer_idx + 1) % config.decoder_sparse_step == 0):
-            if not use_aclgraph:
-                self.mlp = AscendSparseMoeBlock(config=config,
-                                                quant_config=quant_config,
-                                                prefix=f"{prefix}.mlp")
-            else:
-                self.mlp = Qwen3MoeSparseMoeBlock(config=config,
-                                                  quant_config=quant_config,
-                                                  prefix=f"{prefix}.mlp")
+            self.mlp = AscendSparseMoeBlock(config=config,
+                                            quant_config=quant_config,
+                                            prefix=f"{prefix}.mlp")
         else:
             self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
                                    intermediate_size=config.intermediate_size,
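As a side note, the `if` predicate kept by this change decides which decoder layers get a sparse MoE block at all. A small stand-alone illustration of that predicate follows; the default values used below (`decoder_sparse_step=1`, empty `mlp_only_layers`, 128 experts) are assumptions about the Qwen3 MoE config, not taken from this diff:

```python
# Toy illustration (not from the PR): which decoder layers receive a sparse MoE
# block vs. a dense MLP, following the predicate shown in the diff above.
def uses_moe(layer_idx: int, num_experts: int, decoder_sparse_step: int,
             mlp_only_layers: list[int]) -> bool:
    return (layer_idx not in mlp_only_layers
            and num_experts > 0
            and (layer_idx + 1) % decoder_sparse_step == 0)

# Assuming typical Qwen3 MoE defaults, every layer would use
# AscendSparseMoeBlock after this change:
print([uses_moe(i, num_experts=128, decoder_sparse_step=1, mlp_only_layers=[])
       for i in range(4)])  # -> [True, True, True, True]
```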

0 commit comments
