
Commit 6a4ec18

[Qwen-moe] Remove the minor operation arange (#2373)
### What this PR does / why we need it?
Integrate the `arange`-based `row_idx` computation into `select_experts`: reuse the `row_idx` already returned by `npu_moe_gating_top_k_softmax` when available, build it once otherwise, and pass it down to the fused-experts kernels instead of recomputing it there. This reduces the time spent on this minor operation and improves performance.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
- vLLM version: v0.10.1.1
- vLLM main: vllm-project/vllm@56dcf4e

---------

Signed-off-by: s30076806 <[email protected]>
1 parent 358ba68 commit 6a4ec18
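
For context, a small standalone sketch (not part of this commit) of the `row_idx` tensor that the fused-experts kernels previously rebuilt with `torch.arange` on every call and that `select_experts` now produces once; the helper name `build_row_idx` and the CPU device are illustrative only:

```python
import torch

def build_row_idx(num_tokens: int, top_k: int, device: str = "cpu") -> torch.Tensor:
    # Same arange/view/permute pattern used throughout this diff: a
    # (num_tokens, top_k) int32 index tensor that is later passed to
    # torch_npu.npu_moe_init_routing as row_idx.
    row_idx_len = num_tokens * top_k
    return (torch.arange(0, row_idx_len, dtype=torch.int32,
                         device=device).view(top_k, -1).permute(1, 0).contiguous())

print(build_row_idx(num_tokens=3, top_k=2))
# tensor([[0, 3],
#         [1, 4],
#         [2, 5]], dtype=torch.int32)
```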

File tree

9 files changed: +81 −80 lines

tests/e2e/singlecard/ops/test_fused_moe.py

Lines changed: 11 additions & 3 deletions

@@ -92,8 +92,15 @@ def test_fused_experts(
     score = torch.softmax(score, dim=-1, dtype=dtype)
     topk_weights, topk_ids = torch.topk(score, topk)
     topk_ids = topk_ids.to(torch.int32)
-
-    output = fused_experts(a, w1, w2, topk_weights, topk_ids, topk, e_map)
+    row_idx = (torch.arange(
+        0,
+        m * topk,
+        device=device,
+        dtype=torch.int32,
+    ).view(topk, -1).permute(1, 0).contiguous())
+
+    output = fused_experts(a, w1, w2, topk_weights, topk_ids, row_idx, topk,
+                           e_map)
     torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk, e_map)
     # TODO: The native params are: atol=2e-2, rtol=0, maybe related to the nan problem
     torch.testing.assert_close(output, torch_output, atol=4e-2, rtol=1)

@@ -148,7 +155,7 @@ def test_select_experts(
     mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
         x)

-    topk_weights, topk_ids = select_experts(
+    topk_weights, topk_ids, row_idx = select_experts(
         hidden_states=hidden_states,
         router_logits=router_logits,
         top_k=topk,

@@ -169,6 +176,7 @@ def test_select_experts(
     assert topk_weights.shape == (m, topk)
     assert topk_ids.shape == (m, topk)
     assert topk_ids.dtype == torch.int32
+    assert row_idx.shape == (m, topk)


 @pytest.mark.parametrize("device", DEVICE)

tests/ut/ops/test_fused_ops.py

Lines changed: 1 addition & 1 deletion

@@ -405,7 +405,7 @@ def test_select_experts(self, mock_dist_env, mock_moe_env,

         x = torch.randn(8, 2)
         router_logits = torch.randn(8, 2)
-        topk_weights, topk_ids = select_experts(
+        topk_weights, topk_ids, _ = select_experts(
             hidden_states=x,
             router_logits=router_logits,
             top_k=2,

tests/ut/quantization/test_w8a8.py

Lines changed: 23 additions & 23 deletions

@@ -719,25 +719,25 @@ def setUp(self):
     def test_softmax_scoring(self):
         """Test softmax scoring function"""

-        weights, ids = select_experts(hidden_states=self.hidden_states,
-                                      router_logits=self.router_logits,
-                                      top_k=self.top_k,
-                                      use_grouped_topk=False,
-                                      renormalize=False,
-                                      scoring_func="softmax")
+        weights, ids, _ = select_experts(hidden_states=self.hidden_states,
+                                         router_logits=self.router_logits,
+                                         top_k=self.top_k,
+                                         use_grouped_topk=False,
+                                         renormalize=False,
+                                         scoring_func="softmax")

         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
         self.assertEqual(ids.shape, (self.num_tokens, self.top_k))

     def test_sigmoid_scoring(self):
         """Test sigmoid scoring function"""

-        weights, ids = select_experts(hidden_states=self.hidden_states,
-                                      router_logits=self.router_logits,
-                                      top_k=self.top_k,
-                                      use_grouped_topk=False,
-                                      renormalize=False,
-                                      scoring_func="sigmoid")
+        weights, ids, _ = select_experts(hidden_states=self.hidden_states,
+                                         router_logits=self.router_logits,
+                                         top_k=self.top_k,
+                                         use_grouped_topk=False,
+                                         renormalize=False,
+                                         scoring_func="sigmoid")

         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
         self.assertEqual(ids.shape, (self.num_tokens, self.top_k))

@@ -760,13 +760,13 @@ def test_grouped_topk(self, mock_topk):
                                      self.top_k,
                                      dtype=torch.long))

-        weights, ids = select_experts(hidden_states=self.hidden_states,
-                                      router_logits=self.router_logits,
-                                      top_k=self.top_k,
-                                      use_grouped_topk=True,
-                                      renormalize=False,
-                                      topk_group=4,
-                                      num_expert_group=2)
+        weights, ids, _ = select_experts(hidden_states=self.hidden_states,
+                                         router_logits=self.router_logits,
+                                         top_k=self.top_k,
+                                         use_grouped_topk=True,
+                                         renormalize=False,
+                                         topk_group=4,
+                                         num_expert_group=2)

         mock_topk.assert_called()
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))

@@ -780,7 +780,7 @@ def test_grouped_topk_with_correction_bias(self, mock_grouped_topk):
                                            self.num_experts)

         e_score_correction_bias = torch.randn(self.num_experts)
-        weights, ids = select_experts(
+        weights, ids, _ = select_experts(
             hidden_states=self.hidden_states,
             router_logits=self.router_logits,
             top_k=self.top_k,

@@ -803,7 +803,7 @@ def test_custom_routing_function(self):
                                    self.top_k,
                                    dtype=torch.int32))

-        weights, ids = select_experts(
+        weights, ids, _ = select_experts(
             hidden_states=self.hidden_states,
             router_logits=self.router_logits,
             top_k=self.top_k,

@@ -824,7 +824,7 @@ def test_renormalize(self, mock_topk):
                                      self.top_k,
                                      dtype=torch.long))

-        weights, _ = select_experts(
+        weights, ids, _ = select_experts(
             hidden_states=self.hidden_states,
             router_logits=self.router_logits,
             top_k=self.top_k,

@@ -844,7 +844,7 @@ def test_output_dtypes(self, mock_topk):
                                      self.top_k,
                                      dtype=torch.long))

-        weights, ids = select_experts(
+        weights, ids, _ = select_experts(
             hidden_states=self.hidden_states,
             router_logits=self.router_logits,
             top_k=self.top_k,

tests/ut/quantization/test_w8a8_dynamic.py

Lines changed: 7 additions & 0 deletions

@@ -55,6 +55,12 @@ def test_fused_experts_with_all2all(self, mock_moe_init_routing,
             torch.randn(self.num_tokens),
         )
         mock_moe_finalize_routing.return_value = self.placeholder
+        row_idx_len = self.num_tokens * 8
+        row_idx = (torch.arange(
+            0,
+            row_idx_len,
+            dtype=torch.int32,
+        ).view(8, -1).permute(1, 0).contiguous())

         result = fused_experts_with_all2all(
             hidden_states=self.placeholder,

@@ -64,6 +70,7 @@ def test_fused_experts_with_all2all(self, mock_moe_init_routing,
             w2_scale=self.placeholder,
             topk_weights=self.placeholder,
             topk_ids=self.placeholder,
+            row_idx=row_idx,
             top_k=8,
             expert_map=expert_map,
             ep_group=ep_group,

vllm_ascend/ops/common_fused_moe.py

Lines changed: 1 addition & 1 deletion

@@ -130,7 +130,7 @@ def forward_oot(
             logical_to_physical_map: Optional[torch.Tensor] = None,
             logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:

-        topk_weights, topk_ids = select_experts(
+        topk_weights, topk_ids, _ = select_experts(
             hidden_states=x,
             router_logits=router_logits,
             top_k=top_k,

vllm_ascend/ops/fused_moe.py

Lines changed: 7 additions & 24 deletions

@@ -326,6 +326,7 @@ def fused_experts_with_all2all(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    row_idx: torch.Tensor,
     top_k: int,
     expert_map: torch.Tensor = None,
     ep_group: GroupCoordinator = None,

@@ -336,17 +337,10 @@ def fused_experts_with_all2all(

     num_tokens, _ = hidden_states.shape
     num_experts = w1.shape[0]
-    device = hidden_states.device

     if expert_map is not None:
         global_num_experts = len(expert_map)
         local_num_experts = global_num_experts // ep_group.world_size
-        row_idx_len = num_tokens * top_k
-        row_idx = (torch.arange(0,
-                                row_idx_len,
-                                dtype=torch.int32,
-                                device=device).view(top_k, -1).permute(
-                                    1, 0).contiguous())
         hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
             hidden_states,
             row_idx=row_idx,

@@ -380,12 +374,6 @@ def fused_experts_with_all2all(

         hidden_states = hidden_states[sorted_idx]
     else:
-        row_idx_len = num_tokens * top_k
-        row_idx = torch.arange(0,
-                               row_idx_len,
-                               dtype=torch.int32,
-                               device=topk_weights.device).view(
-                                   top_k, -1).permute(1, 0).contiguous()
         hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
             hidden_states,
             row_idx=row_idx,

@@ -459,6 +447,7 @@ def fused_experts_with_all2all_buffer(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    row_idx: torch.Tensor,
     top_k: int,
     max_model_len: int,
     global_batch_size: int,

@@ -470,14 +459,10 @@ def fused_experts_with_all2all_buffer(
     hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

     num_tokens, _ = hidden_states.shape
-    device = hidden_states.device

     global_num_experts = len(expert_map)
     local_num_experts = global_num_experts // ep_group.world_size
     row_idx_len = num_tokens * top_k
-    row_idx = (torch.arange(0, row_idx_len, dtype=torch.int32,
-                            device=device).view(top_k,
-                                                -1).permute(1, 0).contiguous())
     hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
         hidden_states,
         row_idx=row_idx,

@@ -690,6 +675,7 @@ def fused_experts(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    row_idx: torch.Tensor,
     top_k: int,
     expert_map: torch.Tensor = None,
     apply_router_weight_on_input: bool = False,

@@ -781,12 +767,6 @@ def fused_experts(
         # Rearrange hidden_states
         sorted_hidden_states = hidden_states[sorted_token_indices]
     else:
-        row_idx_len = num_tokens * top_k
-        row_idx = (torch.arange(0,
-                                row_idx_len,
-                                dtype=torch.int32,
-                                device=device).view(top_k, -1).permute(
-                                    1, 0).contiguous())
         active_num = max_num_tokens if max_num_tokens is not None else num_tokens
         sorted_hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
             hidden_states,

@@ -908,7 +888,7 @@ def apply(
         **kwargs,
     ) -> torch.Tensor:

-        topk_weights, topk_ids = select_experts(
+        topk_weights, topk_ids, row_idx = select_experts(
             hidden_states=x,
             router_logits=router_logits,
             top_k=top_k,

@@ -952,6 +932,7 @@ def apply(
                 w2=layer.w2_weight,
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
+                row_idx=row_idx,
                 top_k=top_k,
                 expert_map=expert_map)
         elif MOE_ALL2ALL_BUFFER:

@@ -961,6 +942,7 @@ def apply(
                 w2=layer.w2_weight,
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
+                row_idx=row_idx,
                 top_k=top_k,
                 max_model_len=self.max_model_len,
                 global_batch_size=self.global_batch_size,

@@ -982,6 +964,7 @@ def apply(
                 w2=layer.w2_weight,
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
+                row_idx=row_idx,
                 top_k=top_k,
                 expert_map=expert_map,
                 ep_group=get_ep_group())

vllm_ascend/ops/layers/experts_selector.py

Lines changed: 21 additions & 6 deletions

@@ -20,6 +20,17 @@
 import torch_npu


+def return_row_idx(hidden_states, top_k):
+    num_tokens = hidden_states.shape[0]
+    row_idx_len = num_tokens * top_k
+    row_idx = (torch.arange(0,
+                            row_idx_len,
+                            dtype=torch.int32,
+                            device=hidden_states.device).view(
+                                top_k, -1).permute(1, 0).contiguous())
+    return row_idx
+
+
 def select_experts(hidden_states: torch.Tensor,
                    router_logits: torch.Tensor,
                    top_k: int,

@@ -56,7 +67,8 @@ def select_experts(hidden_states: torch.Tensor,
         topk_ids: selected expert IDs of shape (num_tokens, top_k).
     """

-    topk_weights, topk_ids = _select_experts_with_fusion_ops(
+    topk_weights, topk_ids, row_idx = _select_experts_with_fusion_ops(
+        hidden_states=hidden_states,
         router_logits=router_logits,
         top_k=top_k,
         use_grouped_topk=use_grouped_topk,

@@ -83,7 +95,9 @@ def select_experts(hidden_states: torch.Tensor,
            e_score_correction_bias=e_score_correction_bias,
            global_num_experts=global_num_experts,
        )
-    return topk_weights, topk_ids
+    if row_idx is None:
+        row_idx = return_row_idx(hidden_states, top_k)
+    return topk_weights, topk_ids, row_idx


 def _native_grouped_topk(

@@ -156,6 +170,7 @@ def _select_expert_use_group_topk(


 def _select_experts_with_fusion_ops(
+        hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
         use_grouped_topk: bool,

@@ -168,7 +183,7 @@ def _select_experts_with_fusion_ops(
         global_num_experts: int = -1,
         is_unquantized: bool = False):

-    topk_weights, topk_ids = None, None
+    topk_weights, topk_ids, row_idx = None, None, None
     # NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern
     is_deepseek_v3_r1 = global_num_experts == 256
     if is_deepseek_v3_r1:

@@ -186,14 +201,14 @@ def _select_experts_with_fusion_ops(
             # y2_flag=False, # old api; should the third output be output
             routed_scaling_factor=1,
             eps=float(1e-20))
-
+        row_idx = return_row_idx(hidden_states, top_k)
     if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax" and is_unquantized:
-        topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
+        topk_weights, topk_ids, row_idx = torch_npu.npu_moe_gating_top_k_softmax(
             x=router_logits, finished=None, k=top_k)
         topk_ids = topk_ids.to(torch.int32)
         topk_weights = _renormalize_topk_weights(topk_weights, renormalize)

-    return topk_weights, topk_ids
+    return topk_weights, topk_ids, row_idx


 def _native_select_experts(
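
As a usage note, a minimal sketch (not part of this commit) of how a caller sees the new three-value return of `select_experts`; it assumes a vllm-ascend environment where `torch_npu` is importable, and the tensor shapes are illustrative, mirroring the unit tests above:

```python
import torch
from vllm_ascend.ops.layers.experts_selector import select_experts

# Illustrative routing inputs: 8 tokens, hidden size 64, 4 experts.
hidden_states = torch.randn(8, 64)
router_logits = torch.randn(8, 4)

# select_experts now also returns row_idx, so downstream kernels such as
# fused_experts / fused_experts_with_all2all receive it instead of
# rebuilding it with torch.arange.
topk_weights, topk_ids, row_idx = select_experts(hidden_states=hidden_states,
                                                 router_logits=router_logits,
                                                 top_k=2,
                                                 use_grouped_topk=False,
                                                 renormalize=False,
                                                 scoring_func="softmax")

assert row_idx.shape == (8, 2)  # (num_tokens, top_k), dtype torch.int32
```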

vllm_ascend/quantization/w4a8_dynamic.py

Lines changed: 2 additions & 1 deletion

@@ -268,7 +268,7 @@ def apply(
                 1] == global_num_experts, "Number of global experts mismatch"

         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-        topk_weights, topk_ids = select_experts(
+        topk_weights, topk_ids, row_idx = select_experts(
             hidden_states=x,
             router_logits=router_logits,
             top_k=top_k,

@@ -334,6 +334,7 @@ def apply(
             w2_scale_bias=layer.w2_scale_bias,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
+            row_idx=row_idx,
             top_k=top_k,
             expert_map=expert_map,
             ep_group=self.ep_group,
