Commit 0c6ac41

[Qwen-moe] Remove the minor operation arange
Signed-off-by: s30076806 <[email protected]>
1 parent e14f2ef commit 0c6ac41
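
What this change does (summary inferred from the diff below): select_experts now builds the row_idx routing-index tensor once and returns it as a third output, reusing the row_idx that torch_npu.npu_moe_gating_top_k_softmax already produces when that fused path is taken. The fused_experts* entry points accept row_idx as a parameter instead of each rebuilding it with its own torch.arange on the device, which removes a repeated minor op from the MoE forward pass. Callers that do not need row_idx discard the third return value as "_".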

File tree: 9 files changed, +45 -62 lines changed


tests/e2e/singlecard/ops/test_fused_moe.py
Lines changed: 1 addition & 1 deletion

@@ -148,7 +148,7 @@ def test_select_experts(
     mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
         x)
 
-    topk_weights, topk_ids = select_experts(
+    topk_weights, topk_ids, _ = select_experts(
         hidden_states=hidden_states,
         router_logits=router_logits,
         top_k=topk,

tests/ut/ops/test_fused_ops.py
Lines changed: 1 addition & 1 deletion

@@ -400,7 +400,7 @@ def test_select_experts(self, mock_dist_env, mock_moe_env,
 
         x = torch.randn(8, 2)
         router_logits = torch.randn(8, 2)
-        topk_weights, topk_ids = select_experts(
+        topk_weights, topk_ids, _ = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            top_k=2,

tests/ut/quantization/test_w8a8.py
Lines changed: 6 additions & 6 deletions

@@ -719,7 +719,7 @@ def setUp(self):
     def test_softmax_scoring(self):
         """Test softmax scoring function"""
 
-        weights, ids = select_experts(hidden_states=self.hidden_states,
+        weights, ids, _ = select_experts(hidden_states=self.hidden_states,
                                       router_logits=self.router_logits,
                                       top_k=self.top_k,
                                       use_grouped_topk=False,
@@ -732,7 +732,7 @@ def test_softmax_scoring(self):
     def test_sigmoid_scoring(self):
         """Test sigmoid scoring function"""
 
-        weights, ids = select_experts(hidden_states=self.hidden_states,
+        weights, ids, _ = select_experts(hidden_states=self.hidden_states,
                                       router_logits=self.router_logits,
                                       top_k=self.top_k,
                                       use_grouped_topk=False,
@@ -760,7 +760,7 @@ def test_grouped_topk(self, mock_topk):
                              self.top_k,
                              dtype=torch.long))
 
-        weights, ids = select_experts(hidden_states=self.hidden_states,
+        weights, ids, _ = select_experts(hidden_states=self.hidden_states,
                                       router_logits=self.router_logits,
                                       top_k=self.top_k,
                                       use_grouped_topk=True,
@@ -780,7 +780,7 @@ def test_grouped_topk_with_correction_bias(self, mock_grouped_topk):
                                          self.num_experts)
 
         e_score_correction_bias = torch.randn(self.num_experts)
-        weights, ids = select_experts(
+        weights, ids, _ = select_experts(
             hidden_states=self.hidden_states,
             router_logits=self.router_logits,
             top_k=self.top_k,
@@ -803,7 +803,7 @@ def test_custom_routing_function(self):
                                  self.top_k,
                                  dtype=torch.int32))
 
-        weights, ids = select_experts(
+        weights, ids, _ = select_experts(
             hidden_states=self.hidden_states,
             router_logits=self.router_logits,
             top_k=self.top_k,
@@ -844,7 +844,7 @@ def test_output_dtypes(self, mock_topk):
                              self.top_k,
                              dtype=torch.long))
 
-        weights, ids = select_experts(
+        weights, ids, _ = select_experts(
             hidden_states=self.hidden_states,
             router_logits=self.router_logits,
             top_k=self.top_k,

vllm_ascend/ops/common_fused_moe.py
Lines changed: 1 addition & 1 deletion

@@ -68,7 +68,7 @@ def forward_oot(
         logical_to_physical_map: Optional[torch.Tensor] = None,
         logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:
 
-    topk_weights, topk_ids = select_experts(
+    topk_weights, topk_ids, _ = select_experts(
         hidden_states=x,
         router_logits=router_logits,
         top_k=top_k,

vllm_ascend/ops/fused_moe.py
Lines changed: 7 additions & 24 deletions

@@ -391,6 +391,7 @@ def fused_experts_with_all2all(
         w2: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
+        row_idx: torch.Tensor,
         top_k: int,
         expert_map: torch.Tensor = None,
         ep_group: GroupCoordinator = None,
@@ -401,17 +402,10 @@ def fused_experts_with_all2all(
 
     num_tokens, _ = hidden_states.shape
     num_experts = w1.shape[0]
-    device = hidden_states.device
 
     if expert_map is not None:
         global_num_experts = len(expert_map)
         local_num_experts = global_num_experts // ep_group.world_size
-        row_idx_len = num_tokens * top_k
-        row_idx = (torch.arange(0,
-                                row_idx_len,
-                                dtype=torch.int32,
-                                device=device).view(top_k, -1).permute(
-                                    1, 0).contiguous())
         hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
             hidden_states,
             row_idx=row_idx,
@@ -445,12 +439,6 @@ def fused_experts_with_all2all(
 
         hidden_states = hidden_states[sorted_idx]
     else:
-        row_idx_len = num_tokens * top_k
-        row_idx = torch.arange(0,
-                               row_idx_len,
-                               dtype=torch.int32,
-                               device=topk_weights.device).view(
-                                   top_k, -1).permute(1, 0).contiguous()
         hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
             hidden_states,
             row_idx=row_idx,
@@ -524,6 +512,7 @@ def fused_experts_with_all2all_buffer(
         w2: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
+        row_idx: torch.Tensor,
         top_k: int,
         max_model_len: int,
         global_batch_size: int,
@@ -535,14 +524,10 @@ def fused_experts_with_all2all_buffer(
     hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
 
     num_tokens, _ = hidden_states.shape
-    device = hidden_states.device
 
     global_num_experts = len(expert_map)
     local_num_experts = global_num_experts // ep_group.world_size
     row_idx_len = num_tokens * top_k
-    row_idx = (torch.arange(0, row_idx_len, dtype=torch.int32,
-                            device=device).view(top_k,
-                                                -1).permute(1, 0).contiguous())
     hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
         hidden_states,
         row_idx=row_idx,
@@ -755,6 +740,7 @@ def fused_experts(
         w2: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
+        row_idx: torch.Tensor,
         top_k: int,
         expert_map: torch.Tensor = None,
         apply_router_weight_on_input: bool = False,
@@ -846,12 +832,6 @@ def fused_experts(
         # Rearrange hidden_states
         sorted_hidden_states = hidden_states[sorted_token_indices]
     else:
-        row_idx_len = num_tokens * top_k
-        row_idx = (torch.arange(0,
-                                row_idx_len,
-                                dtype=torch.int32,
-                                device=device).view(top_k, -1).permute(
-                                    1, 0).contiguous())
         active_num = max_num_tokens if max_num_tokens is not None else num_tokens
         sorted_hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
             hidden_states,
@@ -975,7 +955,7 @@ def apply(
         **kwargs,
     ) -> torch.Tensor:
 
-        topk_weights, topk_ids = select_experts(
+        topk_weights, topk_ids, row_idx = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            top_k=top_k,
@@ -1019,6 +999,7 @@ def apply(
                 w2=layer.w2_weight,
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
+                row_idx=row_idx,
                 top_k=top_k,
                 expert_map=expert_map)
         elif MOE_ALL2ALL_BUFFER:
@@ -1028,6 +1009,7 @@ def apply(
                 w2=layer.w2_weight,
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
+                row_idx=row_idx,
                 top_k=top_k,
                 max_model_len=self.max_model_len,
                 global_batch_size=self.global_batch_size,
@@ -1049,6 +1031,7 @@ def apply(
                 w2=layer.w2_weight,
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
+                row_idx=row_idx,
                 top_k=top_k,
                 expert_map=expert_map,
                 ep_group=get_ep_group())
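
Note: after this change, apply computes row_idx once via select_experts and passes the same tensor into whichever branch runs (fused_experts_with_all2all, fused_experts_with_all2all_buffer, or fused_experts), so no branch rebuilds it.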

vllm_ascend/ops/layers/experts_selector.py
Lines changed: 19 additions & 6 deletions

@@ -19,6 +19,16 @@
 import torch
 import torch_npu
 
+def return_row_idx(hidden_states, top_k):
+    num_tokens, _ = hidden_states.shape
+    row_idx_len = num_tokens * top_k
+    row_idx = (torch.arange(0,
+                            row_idx_len,
+                            dtype=torch.int32,
+                            device=hidden_states.device).view(
+                                top_k, -1).permute(1, 0).contiguous())
+    return row_idx
+
 
 def select_experts(hidden_states: torch.Tensor,
                    router_logits: torch.Tensor,
@@ -56,7 +66,8 @@ def select_experts(hidden_states: torch.Tensor,
         topk_ids: selected expert IDs of shape (num_tokens, top_k).
     """
 
-    topk_weights, topk_ids = _select_experts_with_fusion_ops(
+    topk_weights, topk_ids, row_idx = _select_experts_with_fusion_ops(
+        hidden_states=hidden_states,
         router_logits=router_logits,
         top_k=top_k,
         use_grouped_topk=use_grouped_topk,
@@ -83,7 +94,8 @@ def select_experts(hidden_states: torch.Tensor,
             e_score_correction_bias=e_score_correction_bias,
             global_num_experts=global_num_experts,
         )
-    return topk_weights, topk_ids
+        row_idx = return_row_idx(hidden_states, top_k)
+    return topk_weights, topk_ids, row_idx
 
 
 def _native_grouped_topk(
@@ -156,6 +168,7 @@ def _select_expert_use_group_topk(
 
 
 def _select_experts_with_fusion_ops(
+        hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
         use_grouped_topk: bool,
@@ -168,7 +181,7 @@ def _select_experts_with_fusion_ops(
         global_num_experts: int = -1,
         is_unquantized: bool = False):
 
-    topk_weights, topk_ids = None, None
+    topk_weights, topk_ids, row_idx = None, None, None
     # NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern
     is_deepseek_v3_r1 = global_num_experts == 256
     if is_deepseek_v3_r1:
@@ -186,14 +199,14 @@ def _select_experts_with_fusion_ops(
             # y2_flag=False, # old api; should the third output be output
             routed_scaling_factor=1,
             eps=float(1e-20))
-
+        row_idx = return_row_idx(hidden_states, top_k)
     if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax" and is_unquantized:
-        topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
+        topk_weights, topk_ids, row_idx = torch_npu.npu_moe_gating_top_k_softmax(
            x=router_logits, finished=None, k=top_k)
         topk_ids = topk_ids.to(torch.int32)
         topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
 
-    return topk_weights, topk_ids
+    return topk_weights, topk_ids, row_idx
 
 
 def _native_select_experts(
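
For intuition, return_row_idx enumerates the expanded (token, expert-slot) rows in the slot-major order that npu_moe_init_routing consumes: row_idx[i, j] == j * num_tokens + i. A minimal CPU-only sketch (plain PyTorch, no NPU required; the function body mirrors the hunk above):

import torch

def return_row_idx(hidden_states: torch.Tensor, top_k: int) -> torch.Tensor:
    # row_idx[i, j] is the flattened position of token i's j-th expert slot
    # when the num_tokens * top_k expanded rows are laid out slot-major.
    num_tokens, _ = hidden_states.shape
    row_idx_len = num_tokens * top_k
    return (torch.arange(0, row_idx_len, dtype=torch.int32,
                         device=hidden_states.device)
            .view(top_k, -1).permute(1, 0).contiguous())

hidden_states = torch.randn(4, 8)  # 4 tokens, hidden size 8
print(return_row_idx(hidden_states, top_k=2))
# tensor([[0, 4],
#         [1, 5],
#         [2, 6],
#         [3, 7]], dtype=torch.int32)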

vllm_ascend/quantization/w4a8_dynamic.py
Lines changed: 2 additions & 1 deletion

@@ -245,7 +245,7 @@ def apply(
             1] == global_num_experts, "Number of global experts mismatch"
 
         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-        topk_weights, topk_ids = select_experts(
+        topk_weights, topk_ids, row_idx = select_experts(
             hidden_states=x,
             router_logits=router_logits,
             top_k=top_k,
@@ -311,6 +311,7 @@ def apply(
             w2_scale_bias=layer.w2_scale_bias,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
+            row_idx=row_idx,
             top_k=top_k,
             expert_map=expert_map,
             ep_group=self.ep_group,

vllm_ascend/quantization/w8a8.py
Lines changed: 1 addition & 1 deletion

@@ -241,7 +241,7 @@ def apply(
         assert router_logits.shape[
             1] == global_num_experts, "Number of global experts mismatch"
 
-        topk_weights, topk_ids = select_experts(
+        topk_weights, topk_ids, _ = select_experts(
             hidden_states=x,
             router_logits=router_logits,
             top_k=top_k,

vllm_ascend/quantization/w8a8_dynamic.py
Lines changed: 7 additions & 21 deletions

@@ -365,14 +365,8 @@ def fused_experts_with_mc2(
     return hidden_states, shared_output
 
 
-def init_routing_quant(hidden_states, top_k, topk_ids, global_num_experts):
+def init_routing_quant(hidden_states, top_k, topk_ids, row_idx, global_num_experts):
     num_tokens, _ = hidden_states.shape
-    row_idx_len = num_tokens * top_k
-    row_idx = (torch.arange(0,
-                            row_idx_len,
-                            dtype=torch.int32,
-                            device=hidden_states.device).view(
-                                top_k, -1).permute(1, 0).contiguous())
     hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
         hidden_states,
         row_idx=row_idx,
@@ -398,6 +392,7 @@ def fused_experts_with_all2all(
         w2_scale: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
+        row_idx: torch.Tensor,
         top_k: int,
         expert_map: torch.Tensor = None,
         ep_group: GroupCoordinator = None,
@@ -431,7 +426,7 @@ def fused_experts_with_all2all(
         )
     else:
         quantized_tokens, expanded_row_idx, global_expert_tokens, token_scales = init_routing_quant(
-            hidden_states, top_k, topk_ids, global_num_experts)
+            hidden_states, top_k, topk_ids, row_idx, global_num_experts)
 
         gather_sizes = global_expert_tokens.new_empty(
             global_expert_tokens.shape[0])
@@ -463,12 +458,6 @@ def fused_experts_with_all2all(
         expert_tokens = expert_tokens.to(torch.int64)
         group_list_type = 1
     else:
-        row_idx_len = num_tokens * top_k
-        row_idx = torch.arange(0,
-                               row_idx_len,
-                               dtype=torch.int32,
-                               device=topk_weights.device).view(
-                                   top_k, -1).permute(1, 0).contiguous()
         hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
             hidden_states,
             row_idx=row_idx,
@@ -627,6 +616,7 @@ def fused_experts(hidden_states: torch.Tensor,
                   w2_scale: torch.Tensor,
                   topk_weights: torch.Tensor,
                   topk_ids: torch.Tensor,
+                  row_idx: torch.Tensor,
                   top_k: int,
                   expert_map: torch.Tensor = None):
     original_shape = hidden_states.shape
@@ -677,12 +667,6 @@ def fused_experts(hidden_states: torch.Tensor,
         hidden_states = hidden_states[sorted_token_indices]
         group_list_type = 1
     else:
-        row_idx_len = num_tokens * top_k
-        row_idx = torch.arange(0,
-                               row_idx_len,
-                               dtype=torch.int32,
-                               device=topk_weights.device).view(
-                                   top_k, -1).permute(1, 0).contiguous()
         hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
             hidden_states,
             row_idx=row_idx,
@@ -903,7 +887,7 @@ def apply(
         assert router_logits.shape[
             1] == global_num_experts, "Number of global experts mismatch"
 
-        topk_weights, topk_ids = select_experts(
+        topk_weights, topk_ids, row_idx = select_experts(
             hidden_states=x,
             router_logits=router_logits,
             top_k=top_k,
@@ -973,6 +957,7 @@ def apply(
             w2_scale=layer.w2_weight_scale,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
+            row_idx=row_idx,
             top_k=top_k,
             expert_map=expert_map)
     else:
@@ -988,6 +973,7 @@ def apply(
             w2_scale=layer.w2_weight_scale,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
+            row_idx=row_idx,
             top_k=top_k,
             expert_map=expert_map,
             ep_group=self.ep_group,
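
The quantized path mirrors the unquantized one: apply obtains row_idx from select_experts and hands it to init_routing_quant and both fused-experts variants, so the per-call arange each of them used to issue is gone.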
