Commit c466592

[Qwen-moe] Remove the minor operation arange
Signed-off-by: s30076806 <[email protected]>
1 parent e14f2ef commit c466592

9 files changed: +62 additions, -78 deletions

tests/e2e/singlecard/ops/test_fused_moe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def test_select_experts(
148148
mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
149149
x)
150150

151-
topk_weights, topk_ids = select_experts(
151+
topk_weights, topk_ids, _ = select_experts(
152152
hidden_states=hidden_states,
153153
router_logits=router_logits,
154154
top_k=topk,
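Note: select_experts now returns three values; the added third value is row_idx, which the downstream fused-experts kernels consume. A minimal sketch of the updated calling convention (shapes are illustrative, and actually executing it assumes an Ascend environment where vllm_ascend and torch_npu import cleanly):

import torch
from vllm_ascend.ops.layers.experts_selector import select_experts

hidden_states = torch.randn(8, 16)  # (num_tokens, hidden_size)
router_logits = torch.randn(8, 4)   # (num_tokens, num_experts)

# Callers that do not feed a fused-experts kernel discard row_idx with `_`,
# exactly as the updated tests in this commit do.
topk_weights, topk_ids, _ = select_experts(
    hidden_states=hidden_states,
    router_logits=router_logits,
    top_k=2,
    use_grouped_topk=False,
    renormalize=False,
)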

tests/ut/ops/test_fused_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,7 @@ def test_select_experts(self, mock_dist_env, mock_moe_env,
400400

401401
x = torch.randn(8, 2)
402402
router_logits = torch.randn(8, 2)
403-
topk_weights, topk_ids = select_experts(
403+
topk_weights, topk_ids, _ = select_experts(
404404
hidden_states=x,
405405
router_logits=router_logits,
406406
top_k=2,

tests/ut/quantization/test_w8a8.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -719,25 +719,25 @@ def setUp(self):
719719
def test_softmax_scoring(self):
720720
"""Test softmax scoring function"""
721721

722-
weights, ids = select_experts(hidden_states=self.hidden_states,
723-
router_logits=self.router_logits,
724-
top_k=self.top_k,
725-
use_grouped_topk=False,
726-
renormalize=False,
727-
scoring_func="softmax")
722+
weights, ids, _ = select_experts(hidden_states=self.hidden_states,
723+
router_logits=self.router_logits,
724+
top_k=self.top_k,
725+
use_grouped_topk=False,
726+
renormalize=False,
727+
scoring_func="softmax")
728728

729729
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
730730
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
731731

732732
def test_sigmoid_scoring(self):
733733
"""Test sigmoid scoring function"""
734734

735-
weights, ids = select_experts(hidden_states=self.hidden_states,
736-
router_logits=self.router_logits,
737-
top_k=self.top_k,
738-
use_grouped_topk=False,
739-
renormalize=False,
740-
scoring_func="sigmoid")
735+
weights, ids, _ = select_experts(hidden_states=self.hidden_states,
736+
router_logits=self.router_logits,
737+
top_k=self.top_k,
738+
use_grouped_topk=False,
739+
renormalize=False,
740+
scoring_func="sigmoid")
741741

742742
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
743743
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
@@ -760,13 +760,13 @@ def test_grouped_topk(self, mock_topk):
760760
self.top_k,
761761
dtype=torch.long))
762762

763-
weights, ids = select_experts(hidden_states=self.hidden_states,
764-
router_logits=self.router_logits,
765-
top_k=self.top_k,
766-
use_grouped_topk=True,
767-
renormalize=False,
768-
topk_group=4,
769-
num_expert_group=2)
763+
weights, ids, _ = select_experts(hidden_states=self.hidden_states,
764+
router_logits=self.router_logits,
765+
top_k=self.top_k,
766+
use_grouped_topk=True,
767+
renormalize=False,
768+
topk_group=4,
769+
num_expert_group=2)
770770

771771
mock_topk.assert_called()
772772
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
@@ -780,7 +780,7 @@ def test_grouped_topk_with_correction_bias(self, mock_grouped_topk):
780780
self.num_experts)
781781

782782
e_score_correction_bias = torch.randn(self.num_experts)
783-
weights, ids = select_experts(
783+
weights, ids, _ = select_experts(
784784
hidden_states=self.hidden_states,
785785
router_logits=self.router_logits,
786786
top_k=self.top_k,
@@ -803,7 +803,7 @@ def test_custom_routing_function(self):
803803
self.top_k,
804804
dtype=torch.int32))
805805

806-
weights, ids = select_experts(
806+
weights, ids, _ = select_experts(
807807
hidden_states=self.hidden_states,
808808
router_logits=self.router_logits,
809809
top_k=self.top_k,
@@ -844,7 +844,7 @@ def test_output_dtypes(self, mock_topk):
844844
self.top_k,
845845
dtype=torch.long))
846846

847-
weights, ids = select_experts(
847+
weights, ids, _ = select_experts(
848848
hidden_states=self.hidden_states,
849849
router_logits=self.router_logits,
850850
top_k=self.top_k,

vllm_ascend/ops/common_fused_moe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def forward_oot(
6868
logical_to_physical_map: Optional[torch.Tensor] = None,
6969
logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:
7070

71-
topk_weights, topk_ids = select_experts(
71+
topk_weights, topk_ids, _ = select_experts(
7272
hidden_states=x,
7373
router_logits=router_logits,
7474
top_k=top_k,

vllm_ascend/ops/fused_moe.py

Lines changed: 7 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,7 @@ def fused_experts_with_all2all(
391391
w2: torch.Tensor,
392392
topk_weights: torch.Tensor,
393393
topk_ids: torch.Tensor,
394+
row_idx: torch.Tensor,
394395
top_k: int,
395396
expert_map: torch.Tensor = None,
396397
ep_group: GroupCoordinator = None,
@@ -401,17 +402,10 @@ def fused_experts_with_all2all(
401402

402403
num_tokens, _ = hidden_states.shape
403404
num_experts = w1.shape[0]
404-
device = hidden_states.device
405405

406406
if expert_map is not None:
407407
global_num_experts = len(expert_map)
408408
local_num_experts = global_num_experts // ep_group.world_size
409-
row_idx_len = num_tokens * top_k
410-
row_idx = (torch.arange(0,
411-
row_idx_len,
412-
dtype=torch.int32,
413-
device=device).view(top_k, -1).permute(
414-
1, 0).contiguous())
415409
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
416410
hidden_states,
417411
row_idx=row_idx,
@@ -445,12 +439,6 @@ def fused_experts_with_all2all(
445439

446440
hidden_states = hidden_states[sorted_idx]
447441
else:
448-
row_idx_len = num_tokens * top_k
449-
row_idx = torch.arange(0,
450-
row_idx_len,
451-
dtype=torch.int32,
452-
device=topk_weights.device).view(
453-
top_k, -1).permute(1, 0).contiguous()
454442
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
455443
hidden_states,
456444
row_idx=row_idx,
@@ -524,6 +512,7 @@ def fused_experts_with_all2all_buffer(
524512
w2: torch.Tensor,
525513
topk_weights: torch.Tensor,
526514
topk_ids: torch.Tensor,
515+
row_idx: torch.Tensor,
527516
top_k: int,
528517
max_model_len: int,
529518
global_batch_size: int,
@@ -535,14 +524,10 @@ def fused_experts_with_all2all_buffer(
535524
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
536525

537526
num_tokens, _ = hidden_states.shape
538-
device = hidden_states.device
539527

540528
global_num_experts = len(expert_map)
541529
local_num_experts = global_num_experts // ep_group.world_size
542530
row_idx_len = num_tokens * top_k
543-
row_idx = (torch.arange(0, row_idx_len, dtype=torch.int32,
544-
device=device).view(top_k,
545-
-1).permute(1, 0).contiguous())
546531
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
547532
hidden_states,
548533
row_idx=row_idx,
@@ -755,6 +740,7 @@ def fused_experts(
755740
w2: torch.Tensor,
756741
topk_weights: torch.Tensor,
757742
topk_ids: torch.Tensor,
743+
row_idx: torch.Tensor,
758744
top_k: int,
759745
expert_map: torch.Tensor = None,
760746
apply_router_weight_on_input: bool = False,
@@ -846,12 +832,6 @@ def fused_experts(
846832
# Rearrange hidden_states
847833
sorted_hidden_states = hidden_states[sorted_token_indices]
848834
else:
849-
row_idx_len = num_tokens * top_k
850-
row_idx = (torch.arange(0,
851-
row_idx_len,
852-
dtype=torch.int32,
853-
device=device).view(top_k, -1).permute(
854-
1, 0).contiguous())
855835
active_num = max_num_tokens if max_num_tokens is not None else num_tokens
856836
sorted_hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
857837
hidden_states,
@@ -975,7 +955,7 @@ def apply(
975955
**kwargs,
976956
) -> torch.Tensor:
977957

978-
topk_weights, topk_ids = select_experts(
958+
topk_weights, topk_ids, row_idx = select_experts(
979959
hidden_states=x,
980960
router_logits=router_logits,
981961
top_k=top_k,
@@ -1019,6 +999,7 @@ def apply(
1019999
w2=layer.w2_weight,
10201000
topk_weights=topk_weights,
10211001
topk_ids=topk_ids,
1002+
row_idx=row_idx,
10221003
top_k=top_k,
10231004
expert_map=expert_map)
10241005
elif MOE_ALL2ALL_BUFFER:
@@ -1028,6 +1009,7 @@ def apply(
10281009
w2=layer.w2_weight,
10291010
topk_weights=topk_weights,
10301011
topk_ids=topk_ids,
1012+
row_idx=row_idx,
10311013
top_k=top_k,
10321014
max_model_len=self.max_model_len,
10331015
global_batch_size=self.global_batch_size,
@@ -1049,6 +1031,7 @@ def apply(
10491031
w2=layer.w2_weight,
10501032
topk_weights=topk_weights,
10511033
topk_ids=topk_ids,
1034+
row_idx=row_idx,
10521035
top_k=top_k,
10531036
expert_map=expert_map,
10541037
ep_group=get_ep_group())
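The shape of this refactor: every fused_experts* variant previously rebuilt row_idx locally with torch.arange, and now takes it as a parameter produced once by select_experts. The two constructions are element-for-element identical, which this self-contained plain-torch sketch checks on CPU (row_idx_inline is reconstructed here from the deleted lines purely for comparison):

import torch

def row_idx_inline(num_tokens: int, top_k: int) -> torch.Tensor:
    # The expression deleted from each fused_experts* branch above.
    row_idx_len = num_tokens * top_k
    return (torch.arange(0, row_idx_len,
                         dtype=torch.int32).view(top_k,
                                                 -1).permute(1, 0).contiguous())

def return_row_idx(hidden_states: torch.Tensor, top_k: int) -> torch.Tensor:
    # The helper this commit adds in experts_selector.py; on CPU tensors
    # hidden_states.device is simply the CPU device.
    num_tokens, _ = hidden_states.shape
    row_idx_len = num_tokens * top_k
    return (torch.arange(0,
                         row_idx_len,
                         dtype=torch.int32,
                         device=hidden_states.device).view(
                             top_k, -1).permute(1, 0).contiguous())

x = torch.randn(4, 8)  # 4 tokens, hidden size 8
assert torch.equal(row_idx_inline(4, 2), return_row_idx(x, 2))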

vllm_ascend/ops/layers/experts_selector.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,17 @@
2020
import torch_npu
2121

2222

23+
def return_row_idx(hidden_states, top_k):
24+
num_tokens, _ = hidden_states.shape
25+
row_idx_len = num_tokens * top_k
26+
row_idx = (torch.arange(0,
27+
row_idx_len,
28+
dtype=torch.int32,
29+
device=hidden_states.device).view(
30+
top_k, -1).permute(1, 0).contiguous())
31+
return row_idx
32+
33+
2334
def select_experts(hidden_states: torch.Tensor,
2435
router_logits: torch.Tensor,
2536
top_k: int,
@@ -56,7 +67,8 @@ def select_experts(hidden_states: torch.Tensor,
5667
topk_ids: selected expert IDs of shape (num_tokens, top_k).
5768
"""
5869

59-
topk_weights, topk_ids = _select_experts_with_fusion_ops(
70+
topk_weights, topk_ids, row_idx = _select_experts_with_fusion_ops(
71+
hidden_states=hidden_states,
6072
router_logits=router_logits,
6173
top_k=top_k,
6274
use_grouped_topk=use_grouped_topk,
@@ -83,7 +95,7 @@ def select_experts(hidden_states: torch.Tensor,
8395
e_score_correction_bias=e_score_correction_bias,
8496
global_num_experts=global_num_experts,
8597
)
86-
return topk_weights, topk_ids
98+
return topk_weights, topk_ids, row_idx
8799

88100

89101
def _native_grouped_topk(
@@ -156,6 +168,7 @@ def _select_expert_use_group_topk(
156168

157169

158170
def _select_experts_with_fusion_ops(
171+
hidden_states: torch.Tensor,
159172
router_logits: torch.Tensor,
160173
top_k: int,
161174
use_grouped_topk: bool,
@@ -168,7 +181,7 @@ def _select_experts_with_fusion_ops(
168181
global_num_experts: int = -1,
169182
is_unquantized: bool = False):
170183

171-
topk_weights, topk_ids = None, None
184+
topk_weights, topk_ids, row_idx = None, None, None
172185
# NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern
173186
is_deepseek_v3_r1 = global_num_experts == 256
174187
if is_deepseek_v3_r1:
@@ -186,14 +199,14 @@ def _select_experts_with_fusion_ops(
186199
# y2_flag=False, # old api; should the third output be output
187200
routed_scaling_factor=1,
188201
eps=float(1e-20))
189-
202+
row_idx = return_row_idx(hidden_states, top_k)
190203
if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax" and is_unquantized:
191-
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
204+
topk_weights, topk_ids, row_idx = torch_npu.npu_moe_gating_top_k_softmax(
192205
x=router_logits, finished=None, k=top_k)
193206
topk_ids = topk_ids.to(torch.int32)
194207
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
195208

196-
return topk_weights, topk_ids
209+
return topk_weights, topk_ids, row_idx
197210

198211

199212
def _native_select_experts(
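For intuition about what return_row_idx produces: arange(0, num_tokens * top_k) is viewed as (top_k, num_tokens) and transposed, so slot j of token i holds the flat index j * num_tokens + i. A plain-torch illustration with num_tokens = 3 and top_k = 2:

import torch

row_idx = (torch.arange(0, 6, dtype=torch.int32)
           .view(2, -1).permute(1, 0).contiguous())
print(row_idx)
# tensor([[0, 3],
#         [1, 4],
#         [2, 5]], dtype=torch.int32)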

vllm_ascend/quantization/w4a8_dynamic.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ def apply(
245245
1] == global_num_experts, "Number of global experts mismatch"
246246

247247
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
248-
topk_weights, topk_ids = select_experts(
248+
topk_weights, topk_ids, row_idx = select_experts(
249249
hidden_states=x,
250250
router_logits=router_logits,
251251
top_k=top_k,
@@ -311,6 +311,7 @@ def apply(
311311
w2_scale_bias=layer.w2_scale_bias,
312312
topk_weights=topk_weights,
313313
topk_ids=topk_ids,
314+
row_idx=row_idx,
314315
top_k=top_k,
315316
expert_map=expert_map,
316317
ep_group=self.ep_group,

vllm_ascend/quantization/w8a8.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ def apply(
241241
assert router_logits.shape[
242242
1] == global_num_experts, "Number of global experts mismatch"
243243

244-
topk_weights, topk_ids = select_experts(
244+
topk_weights, topk_ids, _ = select_experts(
245245
hidden_states=x,
246246
router_logits=router_logits,
247247
top_k=top_k,
