Skip to content

Commit dceef08

Browse files
authored
[main] remove torch.cat and replace it with list indexing `[0]` (#2153)
### What this PR does / why we need it? torch_npu.npu_grouped_matmul: https://www.hiascend.com/document/detail/zh/Pytorch/710/apiref/torchnpuCustomsapi/context/torch_npu-npu_grouped_matmul.md According to the documentation, when `split_item` is 2 or 3, `torch_npu.npu_grouped_matmul` returns a list containing a single element. Therefore, the `torch.cat` call after `torch_npu.npu_grouped_matmul` is unnecessary and can be replaced by indexing the first element. ### Does this PR introduce _any_ user-facing change? Not involved. ### How was this patch tested? Covered by unit tests and e2e tests: `tests/ut/ops/test_fused_ops.py`, `tests/e2e/singlecard/ops/test_fused_moe.py` **performance**: (qwen3 30B, 2k->20k) base: Total Token throughput (tok/s): 667.76 remove cat: Total Token throughput (tok/s): 680.82 - vLLM version: v0.10.0 - vLLM main: vllm-project/vllm@fa00c5d Signed-off-by: huangxialu <[email protected]>
1 parent b2598c3 commit dceef08

File tree

2 files changed

+13
-27
lines changed

2 files changed

+13
-27
lines changed

tests/ut/ops/test_fused_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def mock_moe_env(mocker: MockerFixture):
112112
torch.randn(16, 2)
113113
)), \
114114
patch("torch_npu.npu_grouped_matmul", return_value=(
115-
(torch.randn(8, 2), torch.randn(8, 2))
115+
[torch.randn(16, 2)]
116116
)), \
117117
patch("torch_npu.npu_swiglu", return_value=(
118118
torch.randn(16, 2)

vllm_ascend/ops/fused_moe.py

Lines changed: 12 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -205,11 +205,9 @@ def fused_experts_with_mc2(
205205
group_list_type=1,
206206
group_type=0,
207207
group_list=group_list,
208-
)
208+
)[0]
209209

210-
# TODO: Remove this in the future.
211-
gate_up_out = torch.cat(gate_up_out_list, dim=0)
212-
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
210+
gate_up_out = torch_npu.npu_swiglu(gate_up_out_list)
213211

214212
w2 = w2.transpose(1, 2)
215213
down_out_list = torch_npu.npu_grouped_matmul(
@@ -219,9 +217,7 @@ def fused_experts_with_mc2(
219217
group_list_type=1,
220218
group_type=0,
221219
group_list=group_list,
222-
)
223-
224-
down_out_list = torch.cat(down_out_list, dim=0)
220+
)[0]
225221

226222
# moeCombine
227223
kwargs_mc2 = {
@@ -312,9 +308,8 @@ def apply_mlp(
312308
group_list_type=group_list_type,
313309
group_type=0,
314310
group_list=group_list,
315-
)
311+
)[0]
316312

317-
hidden_states = torch.cat(hidden_states, dim=0)
318313
hidden_states = torch_npu.npu_swiglu(hidden_states)
319314

320315
w2 = w2.transpose(1, 2)
@@ -325,9 +320,8 @@ def apply_mlp(
325320
group_list_type=group_list_type,
326321
group_type=0,
327322
group_list=group_list,
328-
)
323+
)[0]
329324

330-
hidden_states = torch.cat(hidden_states, dim=0)
331325
return hidden_states
332326

333327

@@ -417,23 +411,19 @@ def fused_experts_with_all2all(
417411
group_list_type=0,
418412
group_type=0,
419413
group_list=expert_tokens,
420-
)
414+
)[0]
421415

422-
# TODO: Remove this in the future.
423-
hidden_states = torch.cat(gate_up_out_list, dim=0)
424-
hidden_states = torch_npu.npu_swiglu(hidden_states)
416+
hidden_states = torch_npu.npu_swiglu(gate_up_out_list)
425417

426418
w2 = w2.transpose(1, 2)
427-
down_out_list = torch_npu.npu_grouped_matmul(
419+
hidden_states = torch_npu.npu_grouped_matmul(
428420
x=[hidden_states],
429421
weight=[w2],
430422
split_item=2,
431423
group_list_type=0,
432424
group_type=0,
433425
group_list=expert_tokens,
434-
)
435-
436-
hidden_states = torch.cat(down_out_list, dim=0)
426+
)[0]
437427

438428
if expert_map is not None:
439429
resorted_idx = torch.argsort(sorted_idx)
@@ -823,11 +813,9 @@ def fused_experts(
823813
group_list_type=0,
824814
group_type=0,
825815
group_list=expert_tokens,
826-
)
816+
)[0]
827817

828-
# TODO: Remove this in the future.
829-
gate_up_out = torch.cat(gate_up_out_list, dim=0)
830-
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
818+
gate_up_out = torch_npu.npu_swiglu(gate_up_out_list)
831819

832820
w2 = w2.transpose(1, 2)
833821
down_out_list = torch_npu.npu_grouped_matmul(
@@ -837,9 +825,7 @@ def fused_experts(
837825
group_list_type=0,
838826
group_type=0,
839827
group_list=expert_tokens,
840-
)
841-
842-
down_out_list = torch.cat(down_out_list, dim=0)
828+
)[0]
843829

844830
if expert_map is not None:
845831
weighted_down_out = down_out_list * sorted_weights.unsqueeze(1)

0 commit comments

Comments (0)