Commit 27924d0
[Feature][310P] support shared experts path in fused MoE for qwen3.5 (#7674)
### What this PR does / why we need it?

The 310P path originally supported only the Qwen3 series. Recent adaptation work for Qwen3.5 introduced the new shared-experts structure, which the 310P path did not yet handle; this change adds that support and aligns the 310P execution flow with the A2/A3 implementation path.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Local e2e test.

- vLLM version: v0.18.0
- vLLM main: vllm-project/vllm@35141a7

Signed-off-by: Tflowers-0129 <2906339855@qq.com>
1 parent 988c2aa commit 27924d0
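For context, here is a minimal sketch of the shared-experts pattern this change brings to the 310P path. It is illustrative only, not the committed implementation: the class and attribute names (`SharedMoESketch`, `shared_experts`, `routed_experts`) are assumptions, and only the tuple-return behavior is taken from the PR's unit tests below.

```python
import torch


class SharedMoESketch(torch.nn.Module):
    """Illustrative sketch of a fused MoE layer with a shared-experts branch."""

    def __init__(self, shared_experts: torch.nn.Module, routed_experts: torch.nn.Module):
        super().__init__()
        self.shared_experts = shared_experts  # dense MLP applied to every token
        self.routed_experts = routed_experts  # top-k routed expert MLPs

    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
        # The shared branch runs unconditionally; the routed branch follows
        # the router's top-k selection. Returning both as a tuple matches the
        # behavior exercised by the new unit tests in this commit.
        shared_out = self.shared_experts(hidden_states)
        routed_out = self.routed_experts(hidden_states, router_logits)
        return shared_out, routed_out
```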

File tree

5 files changed: +151 −5 lines

tests/e2e/310p/multicard/test_moe_model_multicard.py

Lines changed: 15 additions & 0 deletions

```diff
@@ -63,3 +63,18 @@ def test_qwen3_moe_tp2_w8a8():
         max_model_len=16384,
     ) as vllm_model:
         vllm_model.generate_greedy(example_prompts, max_tokens)
+
+
+def test_qwen3_5_moe_tp4_fp16():
+    example_prompts = [
+        "Hello, my name is",
+    ]
+    max_tokens = 5
+    with VllmRunner(
+        "Qwen/Qwen3.5-35B-A3B",
+        tensor_parallel_size=4,
+        enforce_eager=True,
+        dtype="float16",
+        max_model_len=16384,
+    ) as vllm_model:
+        vllm_model.generate_greedy(example_prompts, max_tokens)
```

tests/e2e/310p/singlecard/test_dense_model_singlecard.py

Lines changed: 15 additions & 0 deletions

```diff
@@ -47,3 +47,18 @@ def test_qwen3_dense_tp1_w8a8():
         max_model_len=16384,
     ) as vllm_model:
         vllm_model.generate_greedy(example_prompts, max_tokens)
+
+
+def test_qwen3_5_dense_tp1_fp16():
+    example_prompts = [
+        "Hello, my name is",
+    ]
+    max_tokens = 5
+    with VllmRunner(
+        "Qwen/Qwen3.5-4B",
+        tensor_parallel_size=1,
+        enforce_eager=True,
+        dtype="float16",
+        max_model_len=16384,
+    ) as vllm_model:
+        vllm_model.generate_greedy(example_prompts, max_tokens)
```
Lines changed: 110 additions & 0 deletions

```diff
@@ -0,0 +1,110 @@
+#
+# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest.mock import patch
+
+import torch
+import torch.nn.functional as F
+
+from vllm_ascend._310p.fused_moe.fused_moe import (
+    AscendFusedMoE310,
+    AscendSharedFusedMoE310,
+)
+
+
+class _DummyGate(torch.nn.Module):
+    def forward(self, hidden_states: torch.Tensor):
+        # Keep gate output deterministic: sigmoid(0)=0.5.
+        return torch.zeros(
+            hidden_states.shape[0],
+            1,
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        ), None
+
+
+class _DummySharedExperts(torch.nn.Module):
+    def __init__(self, with_gate: bool):
+        super().__init__()
+        self.expert_gate = _DummyGate() if with_gate else None
+
+    def forward(self, hidden_states: torch.Tensor):
+        out = hidden_states * 2.0 + 1.0
+        if self.expert_gate is not None:
+            gate_out, _ = self.expert_gate(hidden_states)
+            out = F.sigmoid(gate_out) * out
+        return out
+
+
+def _build_layer(shared_experts: torch.nn.Module | None) -> AscendSharedFusedMoE310:
+    layer = AscendSharedFusedMoE310.__new__(AscendSharedFusedMoE310)
+    # The test bypasses full layer init with __new__, so we must initialize
+    # nn.Module internals before assigning child modules.
+    torch.nn.Module.__init__(layer)
+    layer._shared_experts = shared_experts
+    return layer
+
+
+def test_forward_shared_experts_without_gate_310():
+    layer = _build_layer(_DummySharedExperts(with_gate=False))
+    hidden_states = torch.randn(4, 8)
+    output = layer._forward_shared_experts(hidden_states)
+    expected = hidden_states * 2.0 + 1.0
+    torch.testing.assert_close(output, expected)
+
+
+def test_forward_shared_experts_with_gate_310():
+    layer = _build_layer(_DummySharedExperts(with_gate=True))
+    hidden_states = torch.randn(4, 8)
+    output = layer._forward_shared_experts(hidden_states)
+    expected = 0.5 * (hidden_states * 2.0 + 1.0)
+    torch.testing.assert_close(output, expected)
+
+
+def test_forward_impl_with_shared_experts_returns_tuple_310():
+    layer = _build_layer(_DummySharedExperts(with_gate=True))
+    hidden_states = torch.randn(3, 8)
+    router_logits = torch.randn(3, 8)
+    routed_out = torch.randn(3, 8)
+
+    with patch.object(AscendFusedMoE310, "forward_impl", return_value=routed_out):
+        shared_out, routed = layer.forward_impl(hidden_states, router_logits)
+
+    expected_shared = 0.5 * (hidden_states * 2.0 + 1.0)
+    torch.testing.assert_close(shared_out, expected_shared)
+    torch.testing.assert_close(routed, routed_out)
+
+
+def test_forward_impl_without_shared_experts_integration_310():
+    layer = _build_layer(None)
+    hidden_states = torch.randn(3, 8)
+    assert layer._forward_shared_experts(hidden_states) is None
+
+
+def test_forward_impl_without_shared_experts_returns_routed_only_310():
+    layer = _build_layer(None)
+    hidden_states = torch.randn(3, 8)
+    router_logits = torch.randn(3, 8)
+    routed_out = torch.randn(3, 8)
+
+    with patch.object(AscendFusedMoE310, "forward_impl", return_value=routed_out):
+        output = layer.forward_impl(hidden_states, router_logits)
+
+    torch.testing.assert_close(output, routed_out)
+
+
+def test_is_internal_router_is_false_310():
+    layer = _build_layer(_DummySharedExperts(with_gate=True))
+    assert layer.is_internal_router is False
```
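A note on the test design: `_build_layer` constructs the layer via `__new__` plus an explicit `torch.nn.Module.__init__`, which skips the full MoE initialization (expert weights, distributed setup). This lets `_forward_shared_experts` and the tuple-returning `forward_impl` be exercised in isolation on CPU, with `AscendFusedMoE310.forward_impl` patched to stand in for the routed-experts computation.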

vllm_ascend/_310p/fused_moe/fused_moe.py

Lines changed: 9 additions & 3 deletions

```diff
@@ -274,6 +274,14 @@ def __init__(
         self.use_overlapped = use_overlapped
         self.shared_expert_stream = None
         self._gate = gate
+        # Recreate runner after shared_experts/gate are set so custom op dispatch
+        # goes through moe_forward_shared.
+        self.runner = self._init_runner()
+
+    @property
+    def is_internal_router(self) -> bool:
+        # 310P Ascend path expects router logits from the model forward path.
+        return False
 
     def forward(
         self,
@@ -298,9 +306,7 @@ def forward(
     def _forward_shared_experts(self, hidden_states: torch.Tensor):
         if self._shared_experts is None:
             return None
-        part1_out = self._shared_experts_part1(hidden_states)
-        shared_out = self._shared_experts_part2(hidden_states, part1_out)
-        return shared_out
+        return self._shared_experts(hidden_states)
 
     def forward_impl(  # type: ignore[override]
         self, hidden_states: torch.Tensor, router_logits: torch.Tensor
```
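Taken together with the unit tests above, the contract after this change is: `forward_impl` returns `(shared_out, routed_out)` when shared experts are configured, and only the routed output otherwise. A hypothetical caller-side sketch follows; the function name and the additive combination are assumptions, not code from this PR — only the tuple-vs-tensor return shape is taken from the tests.

```python
import torch


def apply_moe_layer(layer, hidden_states: torch.Tensor, router_logits: torch.Tensor) -> torch.Tensor:
    """Hypothetical consumer of AscendSharedFusedMoE310.forward_impl."""
    result = layer.forward_impl(hidden_states, router_logits)
    if isinstance(result, tuple):  # shared experts configured
        shared_out, routed_out = result
        return shared_out + routed_out  # assumed additive combination of the two branches
    return result  # routed-only layer returns a single tensor
```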

vllm_ascend/patch/worker/patch_qwen3_5_310.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -28,7 +28,7 @@
 from vllm_ascend._310p.ops.fla.fused_gdn_gating import fused_gdn_gating_pytorch
 from vllm_ascend._310p.ops.fla.fused_recurrent_gated_delta_rule import fused_recurrent_gated_delta_rule_pytorch
 from vllm_ascend.attention.utils import maybe_save_kv_layer_to_connector
-from vllm_ascend.utils import enable_sp
+from vllm_ascend.utils import enable_sp, vllm_version_is
 
 
 class Ascend310Qwen3_5GatedDeltaNet(Qwen3_5GatedDeltaNet):
@@ -65,7 +65,7 @@ def _forward_core(
         non_spec_token_indx = attn_metadata.non_spec_token_indx
         spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor  # noqa: E501
         non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor  # noqa: E501
-        self_kv_cache = self.kv_cache[forward_context.virtual_engine]
+        self_kv_cache = self.kv_cache[forward_context.virtual_engine if vllm_version_is("0.18.0") else 0]
         conv_state = self_kv_cache[0].transpose(-1, -2)
         ssm_state = self_kv_cache[1]
         num_actual_tokens = attn_metadata.num_actual_tokens
```
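The guarded lookup above reads more clearly when split in two. This rewrite is purely illustrative and behaviorally identical, under the reading the diff itself suggests: on vLLM v0.18.0 the KV cache is indexed by virtual engine, while on newer vLLM main index 0 is used unconditionally.

```python
# Illustrative two-step form of the committed one-liner, as it would appear
# inside _forward_core (not part of this PR):
cache_index = forward_context.virtual_engine if vllm_version_is("0.18.0") else 0
self_kv_cache = self.kv_cache[cache_index]
```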
