From 53ef6fe6fd6b3c566b98c709d4a11b32bff93ceb Mon Sep 17 00:00:00 2001 From: Tflowers-0129 <2906339855@qq.com> Date: Thu, 26 Mar 2026 20:50:03 +0800 Subject: [PATCH 1/5] [310p] support shared experts path in fused MoE for qwen3.5 Signed-off-by: Tflowers-0129 <2906339855@qq.com> --- .../multicard/test_moe_model_multicard.py | 27 ++++- .../singlecard/test_dense_model_singlecard.py | 27 ++++- .../fused_moe/test_shared_fused_moe_310.py | 107 ++++++++++++++++++ vllm_ascend/_310p/fused_moe/fused_moe.py | 12 +- 4 files changed, 158 insertions(+), 15 deletions(-) create mode 100644 tests/ut/_310p/fused_moe/test_shared_fused_moe_310.py diff --git a/tests/e2e/310p/multicard/test_moe_model_multicard.py b/tests/e2e/310p/multicard/test_moe_model_multicard.py index f761acee58c..dc2f580df2e 100644 --- a/tests/e2e/310p/multicard/test_moe_model_multicard.py +++ b/tests/e2e/310p/multicard/test_moe_model_multicard.py @@ -49,10 +49,10 @@ def test_qwen3_moe_ep4_fp16(): vllm_model.generate_greedy(example_prompts, max_tokens) -def test_qwen3_moe_tp2_w8a8(): - example_prompts = [ - "Hello, my name is", - ] +def test_qwen3_moe_tp2_w8a8(): + example_prompts = [ + "Hello, my name is", + ] max_tokens = 5 with VllmRunner( "vllm-ascend/Qwen3-30B-A3B-W8A8", @@ -61,5 +61,20 @@ def test_qwen3_moe_tp2_w8a8(): dtype="float16", quantization="ascend", max_model_len=16384, - ) as vllm_model: - vllm_model.generate_greedy(example_prompts, max_tokens) + ) as vllm_model: + vllm_model.generate_greedy(example_prompts, max_tokens) + + +def test_qwen3_5_moe_tp4_fp16(): + example_prompts = [ + "Hello, my name is", + ] + max_tokens = 5 + with VllmRunner( + "Qwen/Qwen3.5-35B-A3B", + tensor_parallel_size=4, + enforce_eager=True, + dtype="float16", + max_model_len=16384, + ) as vllm_model: + vllm_model.generate_greedy(example_prompts, max_tokens) diff --git a/tests/e2e/310p/singlecard/test_dense_model_singlecard.py b/tests/e2e/310p/singlecard/test_dense_model_singlecard.py index 26cc05a4ea2..32b1c454d92 100644 --- a/tests/e2e/310p/singlecard/test_dense_model_singlecard.py +++ b/tests/e2e/310p/singlecard/test_dense_model_singlecard.py @@ -33,10 +33,10 @@ def test_qwen3_dense_tp1_fp16(): vllm_model.generate_greedy(example_prompts, max_tokens) -def test_qwen3_dense_tp1_w8a8(): - example_prompts = [ - "Hello, my name is", - ] +def test_qwen3_dense_tp1_w8a8(): + example_prompts = [ + "Hello, my name is", + ] max_tokens = 5 with VllmRunner( "vllm-ascend/Qwen3-8B-W8A8", @@ -45,5 +45,20 @@ def test_qwen3_dense_tp1_w8a8(): dtype="float16", quantization="ascend", max_model_len=16384, - ) as vllm_model: - vllm_model.generate_greedy(example_prompts, max_tokens) + ) as vllm_model: + vllm_model.generate_greedy(example_prompts, max_tokens) + + +def test_qwen3_5_dense_tp1_fp16(): + example_prompts = [ + "Hello, my name is", + ] + max_tokens = 5 + with VllmRunner( + "Qwen/Qwen3.5-4B", + tensor_parallel_size=1, + enforce_eager=True, + dtype="float16", + max_model_len=16384, + ) as vllm_model: + vllm_model.generate_greedy(example_prompts, max_tokens) diff --git a/tests/ut/_310p/fused_moe/test_shared_fused_moe_310.py b/tests/ut/_310p/fused_moe/test_shared_fused_moe_310.py new file mode 100644 index 00000000000..acef761e4b7 --- /dev/null +++ b/tests/ut/_310p/fused_moe/test_shared_fused_moe_310.py @@ -0,0 +1,107 @@ +# +# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import patch

import torch
import torch.nn.functional as F

from vllm_ascend._310p.fused_moe.fused_moe import (
    AscendFusedMoE310,
    AscendSharedFusedMoE310,
)


class _DummyGate(torch.nn.Module):
    def forward(self, hidden_states: torch.Tensor):
        # Keep gate output deterministic: sigmoid(0)=0.5.
        return torch.zeros(
            hidden_states.shape[0],
            1,
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        ), None


class _DummySharedExperts(torch.nn.Module):
    def __init__(self, with_gate: bool):
        super().__init__()
        self.expert_gate = _DummyGate() if with_gate else None

    def forward(self, hidden_states: torch.Tensor):
        out = hidden_states * 2.0 + 1.0
        if self.expert_gate is not None:
            gate_out, _ = self.expert_gate(hidden_states)
            out = F.sigmoid(gate_out) * out
        return out


def _build_layer(shared_experts: torch.nn.Module | None) -> AscendSharedFusedMoE310:
    layer = AscendSharedFusedMoE310.__new__(AscendSharedFusedMoE310)
    layer._shared_experts = shared_experts
    return layer


def test_forward_shared_experts_without_gate_310():
    layer = _build_layer(_DummySharedExperts(with_gate=False))
    hidden_states = torch.randn(4, 8)
    output = layer._forward_shared_experts(hidden_states)
    expected = hidden_states * 2.0 + 1.0
    torch.testing.assert_close(output, expected)


def test_forward_shared_experts_with_gate_310():
    layer = _build_layer(_DummySharedExperts(with_gate=True))
    hidden_states = torch.randn(4, 8)
    output = layer._forward_shared_experts(hidden_states)
    expected = 0.5 * (hidden_states * 2.0 + 1.0)
    torch.testing.assert_close(output, expected)


def test_forward_impl_with_shared_experts_returns_tuple_310():
    layer = _build_layer(_DummySharedExperts(with_gate=True))
    hidden_states = torch.randn(3, 8)
    router_logits = torch.randn(3, 8)
    routed_out = torch.randn(3, 8)

    with patch.object(AscendFusedMoE310, "forward_impl", return_value=routed_out):
        shared_out, routed = layer.forward_impl(hidden_states, router_logits)

    expected_shared = 0.5 * (hidden_states * 2.0 + 1.0)
    torch.testing.assert_close(shared_out, expected_shared)
    torch.testing.assert_close(routed, routed_out)


def test_forward_impl_without_shared_experts_integration_310():
    layer = _build_layer(None)
    hidden_states = torch.randn(3, 8)
    assert layer._forward_shared_experts(hidden_states) is None


def test_forward_impl_without_shared_experts_returns_routed_only_310():
    layer = _build_layer(None)
    hidden_states = torch.randn(3, 8)
    router_logits = torch.randn(3, 8)
    routed_out = torch.randn(3, 8)

    with patch.object(AscendFusedMoE310, "forward_impl", return_value=routed_out):
        output = layer.forward_impl(hidden_states, router_logits)

    torch.testing.assert_close(output, routed_out)


def test_is_internal_router_is_false_310():
    layer = _build_layer(_DummySharedExperts(with_gate=True))
    assert layer.is_internal_router is False
diff --git a/vllm_ascend/_310p/fused_moe/fused_moe.py b/vllm_ascend/_310p/fused_moe/fused_moe.py
index 9de035cd469..6ff3fb15be2 100644
--- a/vllm_ascend/_310p/fused_moe/fused_moe.py
+++ b/vllm_ascend/_310p/fused_moe/fused_moe.py
@@ -274,6 +274,14 @@ def __init__(
         self.use_overlapped = use_overlapped
         self.shared_expert_stream = None
         self._gate = gate
+        # Recreate runner after shared_experts/gate are set so custom op dispatch
+        # goes through moe_forward_shared.
+        self.runner = self._init_runner()
+
+    @property
+    def is_internal_router(self) -> bool:
+        # 310P Ascend path expects router logits from the model forward path.
+        return False
 
     def forward(
         self,
@@ -298,9 +306,7 @@ def forward(
     def _forward_shared_experts(self, hidden_states: torch.Tensor):
         if self._shared_experts is None:
             return None
-        part1_out = self._shared_experts_part1(hidden_states)
-        shared_out = self._shared_experts_part2(hidden_states, part1_out)
-        return shared_out
+        return self._shared_experts(hidden_states)
 
     def forward_impl(  # type: ignore[override]
         self, hidden_states: torch.Tensor, router_logits: torch.Tensor

From f6cee39bb35642e2a991ed659a84a43f83a0c59f Mon Sep 17 00:00:00 2001
From: Tflowers-0129 <2906339855@qq.com>
Date: Fri, 27 Mar 2026 16:05:19 +0800
Subject: [PATCH 2/5] cleancode

Signed-off-by: Tflowers-0129 <2906339855@qq.com>
---
 .../multicard/test_moe_model_multicard.py     | 42 +++++++++----------
 .../singlecard/test_dense_model_singlecard.py | 42 +++++++++----------
 2 files changed, 42 insertions(+), 42 deletions(-)

diff --git a/tests/e2e/310p/multicard/test_moe_model_multicard.py b/tests/e2e/310p/multicard/test_moe_model_multicard.py
index dc2f580df2e..eb23ae61fd1 100644
--- a/tests/e2e/310p/multicard/test_moe_model_multicard.py
+++ b/tests/e2e/310p/multicard/test_moe_model_multicard.py
@@ -49,10 +49,10 @@ def test_qwen3_moe_ep4_fp16():
     vllm_model.generate_greedy(example_prompts, max_tokens)
 
 
-def test_qwen3_moe_tp2_w8a8():
-    example_prompts = [
-        "Hello, my name is",
-    ]
+def test_qwen3_moe_tp2_w8a8():
+    example_prompts = [
+        "Hello, my name is",
+    ]
     max_tokens = 5
     with VllmRunner(
         "vllm-ascend/Qwen3-30B-A3B-W8A8",
         tensor_parallel_size=2,
         enforce_eager=True,
         dtype="float16",
         quantization="ascend",
         max_model_len=16384,
-    ) as vllm_model:
-        vllm_model.generate_greedy(example_prompts, max_tokens)
+    ) as vllm_model:
+        vllm_model.generate_greedy(example_prompts, max_tokens)
+
+
+def test_qwen3_5_moe_tp4_fp16():
+    example_prompts = [
+        "Hello, my name is",
+    ]
+    max_tokens = 5
+    with VllmRunner(
+        "Qwen/Qwen3.5-35B-A3B",
+        tensor_parallel_size=4,
+        enforce_eager=True,
+        dtype="float16",
+        max_model_len=16384,
+    ) as vllm_model:
+        vllm_model.generate_greedy(example_prompts, max_tokens)

diff --git a/tests/e2e/310p/singlecard/test_dense_model_singlecard.py b/tests/e2e/310p/singlecard/test_dense_model_singlecard.py
index 32b1c454d92..cc13c3ecaf7 100644
--- a/tests/e2e/310p/singlecard/test_dense_model_singlecard.py
+++ b/tests/e2e/310p/singlecard/test_dense_model_singlecard.py
@@ -33,10 +33,10 @@ def test_qwen3_dense_tp1_fp16():
    vllm_model.generate_greedy(example_prompts, max_tokens)
 
 
-def test_qwen3_dense_tp1_w8a8():
-    example_prompts = [
-        "Hello, my name is",
-    ]
+def test_qwen3_dense_tp1_w8a8():
+    example_prompts = [
+        "Hello, my name is",
+    ]
     max_tokens = 5
     with VllmRunner(
         "vllm-ascend/Qwen3-8B-W8A8",
         tensor_parallel_size=1,
         enforce_eager=True,
         dtype="float16",
         quantization="ascend",
         max_model_len=16384,
-    ) as vllm_model:
-        vllm_model.generate_greedy(example_prompts, max_tokens)
+    ) as vllm_model:
+        vllm_model.generate_greedy(example_prompts, max_tokens)
+
+
+def test_qwen3_5_dense_tp1_fp16():
+    example_prompts = [
+        "Hello, my name is",
+    ]
+    max_tokens = 5
+    with VllmRunner(
+        "Qwen/Qwen3.5-4B",
+        tensor_parallel_size=1,
+        enforce_eager=True,
+        dtype="float16",
+        max_model_len=16384,
+    ) as vllm_model:
+        vllm_model.generate_greedy(example_prompts, max_tokens)

From 23943ea775aea3d9e856b7d052e28062ea9e8bd6 Mon Sep 17 00:00:00 2001
From: Tflowers-0129 <2906339855@qq.com>
Date: Fri, 27 Mar 2026 16:35:30 +0800
Subject: [PATCH 3/5] ut pass

Signed-off-by: Tflowers-0129 <2906339855@qq.com>
---
 tests/ut/_310p/fused_moe/test_shared_fused_moe_310.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/tests/ut/_310p/fused_moe/test_shared_fused_moe_310.py b/tests/ut/_310p/fused_moe/test_shared_fused_moe_310.py
index acef761e4b7..545dfe0d65d 100644
--- a/tests/ut/_310p/fused_moe/test_shared_fused_moe_310.py
+++ b/tests/ut/_310p/fused_moe/test_shared_fused_moe_310.py
@@ -50,6 +50,9 @@ def _build_layer(shared_experts: torch.nn.Module | None) -> AscendSharedFusedMoE310:
     layer = AscendSharedFusedMoE310.__new__(AscendSharedFusedMoE310)
+    # The test bypasses full layer init with __new__, so we must initialize
+    # nn.Module internals before assigning child modules.
+    torch.nn.Module.__init__(layer)
     layer._shared_experts = shared_experts
     return layer

From a24559a17b8a16b472c81cd172bc88824520d128 Mon Sep 17 00:00:00 2001
From: Tflowers-0129 <2906339855@qq.com>
Date: Sun, 29 Mar 2026 21:10:01 +0800
Subject: [PATCH 4/5] trigger

Signed-off-by: Tflowers-0129 <2906339855@qq.com>

From 0d544ea9e4895bd039e42ac492a427a2f2b81f0a Mon Sep 17 00:00:00 2001
From: Tflowers-0129 <2906339855@qq.com>
Date: Mon, 30 Mar 2026 14:48:51 +0800
Subject: [PATCH 5/5] [bugfix] align ascend patch

Signed-off-by: Tflowers-0129 <2906339855@qq.com>
---
 vllm_ascend/patch/worker/patch_qwen3_5_310.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/vllm_ascend/patch/worker/patch_qwen3_5_310.py b/vllm_ascend/patch/worker/patch_qwen3_5_310.py
index 3326bff6a3b..d034fb5c1df 100644
--- a/vllm_ascend/patch/worker/patch_qwen3_5_310.py
+++ b/vllm_ascend/patch/worker/patch_qwen3_5_310.py
@@ -28,7 +28,7 @@
 from vllm_ascend._310p.ops.fla.fused_gdn_gating import fused_gdn_gating_pytorch
 from vllm_ascend._310p.ops.fla.fused_recurrent_gated_delta_rule import fused_recurrent_gated_delta_rule_pytorch
 from vllm_ascend.attention.utils import maybe_save_kv_layer_to_connector
-from vllm_ascend.utils import enable_sp
+from vllm_ascend.utils import enable_sp, vllm_version_is
 
 
 class Ascend310Qwen3_5GatedDeltaNet(Qwen3_5GatedDeltaNet):
@@ -65,7 +65,7 @@ def _forward_core(
         non_spec_token_indx = attn_metadata.non_spec_token_indx
         spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor  # noqa: E501
         non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor  # noqa: E501
-        self_kv_cache = self.kv_cache[forward_context.virtual_engine]
+        self_kv_cache = self.kv_cache[forward_context.virtual_engine if vllm_version_is("0.18.0") else 0]
         conv_state = self_kv_cache[0].transpose(-1, -2)
         ssm_state = self_kv_cache[1]
         num_actual_tokens = attn_metadata.num_actual_tokens
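
Note for reviewers: a minimal sketch of the control flow this series pins down,
for reading convenience only. It assumes the class layout from PATCH 1;
`routed_forward_impl` below is a placeholder standing in for the base
AscendFusedMoE310.forward_impl, not a real vllm-ascend API, and how the caller
combines the two returned streams (typically a sum in the model code) is
outside the scope of these patches.

import torch

def shared_fused_moe_forward_sketch(
    shared_experts,        # nn.Module or None, mirrors layer._shared_experts
    routed_forward_impl,   # placeholder for AscendFusedMoE310.forward_impl
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
):
    # Router logits arrive from the model side; is_internal_router is False
    # on the 310P path (PATCH 1), so the layer never computes them itself.
    routed_out = routed_forward_impl(hidden_states, router_logits)
    if shared_experts is None:
        # No shared experts configured: identical to the routed-only layer.
        return routed_out
    # Shared experts see every token. After PATCH 1 the shared module is
    # invoked as a whole (no part1/part2 split), so any expert gate inside
    # it scales its own output, e.g. sigmoid(gate(x)) * mlp(x).
    shared_out = shared_experts(hidden_states)
    return shared_out, routed_out

Against the _DummySharedExperts(with_gate=True) stub from the unit tests,
shared_out is sigmoid(0) * (2 * x + 1) = 0.5 * (2 * x + 1), which is exactly
what test_forward_impl_with_shared_experts_returns_tuple_310 asserts.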