
Commit 14497b7

florenceCH authored

Remove qwen3 moe MC2 cumsum & cast (#3126)
What this PR does / why we need it?
The Qwen3 MoE MC2 graph currently contains two redundant computational operators: after npu_moe_distribute_dispatch_v2, a cumsum and a cast are inserted. By passing expert_token_nums_type=0 and not converting weight_scale to float32 on the fused path, these two operators can be eliminated, improving inference performance.

Does this PR introduce any user-facing change?
No

How was this patch tested?
No need.

vLLM version: v0.10.2
vLLM main: vllm-project/vllm@f225ea7

Signed-off-by: florenceCH <[email protected]>
Co-authored-by: florenceCH <[email protected]>
1 parent 2930e4a commit 14497b7
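
To make the effect of the change concrete, here is a minimal, illustrative sketch (plain PyTorch, not vLLM-Ascend code) of the two layouts involved: per-expert token counts, which expert_token_nums_type=0 now yields directly and group_list_type == 0 describes, versus their cumulative sums, which the removed cumsum used to produce and group_list_type == 1 describes.

```python
# Illustrative sketch only; tensor values are made up.
import torch

# Per-expert token counts as the dispatch op can return them directly
# when expert_token_nums_type == 0 ("counts" layout, group_list_type == 0).
expert_token_counts = torch.tensor([3, 0, 5, 2])

# Previously an explicit cumsum was appended to the graph, producing the
# cumulative layout that group_list_type == 1 describes.
expert_token_cumsum = torch.cumsum(expert_token_counts, dim=0)  # tensor([ 3,  3,  8, 10])

# Both layouts carry the same information; skipping the cumsum removes one
# operator from the MC2 graph, and downstream code is told which layout it
# received via group_list_type.
assert expert_token_cumsum[-1].item() == expert_token_counts.sum().item()
```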

File tree: 3 files changed, +5 −4 lines changed

tests/ut/ops/test_token_dispatcher.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -98,7 +98,7 @@ def test_token_permutation_dispatch(self):
                                                  self.row_idx, expert_map)
         mock_dispatch.assert_called_once()
         self.assertEqual(output["group_list_type"],
-                         1)  # group_list_type == 1
+                         0)  # group_list_type == 0

     def test_token_dispatch_with_shared_experts_and_quant(self):
         self.shared_experts = MagicMock()
```

vllm_ascend/ops/moe/moe_mlp.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -79,8 +79,6 @@ def quant_apply_mlp(hidden_states: torch.Tensor,

     is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2
     if w1_scale_bias is None and is_mc2:
-        if w1_scale.dtype != torch.float32:
-            w1_scale = w1_scale.to(torch.float32)
         if fusion:
             # gmm1: gate_up_proj & act_fn: swiglu
             hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
@@ -90,6 +88,8 @@
                 weight_scale=w1_scale,
                 x_scale=pertoken_scale)
         else:
+            if w1_scale.dtype != torch.float32:
+                w1_scale = w1_scale.to(torch.float32)
             # gmm1: gate_up_proj
             hidden_states = torch_npu.npu_grouped_matmul(
                 x=[hidden_states],
```
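
The moe_mlp.py change moves the float32 cast of w1_scale out of the common path and into the non-fused branch only, so the fused npu_grouped_matmul_swiglu_quant call no longer incurs an extra Cast operator in the MC2 graph. A minimal sketch of that guarded-cast pattern follows; the function name and the fused flag are hypothetical names for illustration, not the actual module API.

```python
# Sketch of the guarded-cast pattern used by this change; maybe_cast_scale
# and the `fused` flag are hypothetical names, not the actual module API.
import torch

def maybe_cast_scale(w1_scale: torch.Tensor, fused: bool) -> torch.Tensor:
    if fused:
        # The fused grouped-matmul + swiglu quant path consumes the scale
        # as-is, so no Cast node is added to the graph.
        return w1_scale
    # The non-fused grouped-matmul path still expects a float32 scale.
    if w1_scale.dtype != torch.float32:
        w1_scale = w1_scale.to(torch.float32)
    return w1_scale
```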

vllm_ascend/ops/moe/token_dispatcher.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -133,6 +133,7 @@ def get_dispatch_mc2_kwargs(
             "shared_expert_rank_num": 0,
             "moe_expert_num": moe_expert_num,
             "global_bs": 0,
+            "expert_token_nums_type": 0,
         }

         stage1_kwargs = {
@@ -204,7 +205,7 @@ def token_dispatch(self,
         if shared_experts is not None:
             shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
             self.shared_act = shared_experts.act_fn(shared_gate_up)
-        group_list_type = 1
+        group_list_type = 0
         return {
             "group_list_type": group_list_type,
             "hidden_states": expand_x,
```
