Skip to content

Commit 74b23cb

Browse files
committed
replace arange and permute with the third output of npu_moe_gating_top_k_softmax
Signed-off-by: huangxialu <[email protected]>
1 parent 3fc31ee commit 74b23cb

File tree

9 files changed

+32
-26
lines changed

9 files changed

+32
-26
lines changed

tests/e2e/singlecard/ops/test_fused_moe.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def test_select_experts(
148148
mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
149149
x)
150150

151-
topk_weights, topk_ids = select_experts(
151+
topk_weights, topk_ids, row_idx = select_experts(
152152
hidden_states=hidden_states,
153153
router_logits=router_logits,
154154
top_k=topk,
@@ -169,6 +169,8 @@ def test_select_experts(
169169
assert topk_weights.shape == (m, topk)
170170
assert topk_ids.shape == (m, topk)
171171
assert topk_ids.dtype == torch.int32
172+
assert row_idx.shape == (m, topk)
173+
assert row_idx.dtype == torch.int32
172174

173175

174176
@pytest.mark.parametrize("device", DEVICE)

tests/ut/ops/test_fused_ops.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,7 @@ def test_select_experts(self, mock_dist_env, mock_moe_env,
400400

401401
x = torch.randn(8, 2)
402402
router_logits = torch.randn(8, 2)
403-
topk_weights, topk_ids = select_experts(
403+
topk_weights, topk_ids, row_idx = select_experts(
404404
hidden_states=x,
405405
router_logits=router_logits,
406406
top_k=2,
@@ -415,3 +415,4 @@ def test_select_experts(self, mock_dist_env, mock_moe_env,
415415

416416
assert topk_weights.shape == (8, 2)
417417
assert topk_ids.shape == (8, 2)
418+
assert row_idx.shape == (8, 2)

tests/ut/quantization/test_w8a8.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -719,7 +719,7 @@ def setUp(self):
719719
def test_softmax_scoring(self):
720720
"""Test softmax scoring function"""
721721

722-
weights, ids = select_experts(hidden_states=self.hidden_states,
722+
weights, ids, _ = select_experts(hidden_states=self.hidden_states,
723723
router_logits=self.router_logits,
724724
top_k=self.top_k,
725725
use_grouped_topk=False,
@@ -732,7 +732,7 @@ def test_softmax_scoring(self):
732732
def test_sigmoid_scoring(self):
733733
"""Test sigmoid scoring function"""
734734

735-
weights, ids = select_experts(hidden_states=self.hidden_states,
735+
weights, ids, _ = select_experts(hidden_states=self.hidden_states,
736736
router_logits=self.router_logits,
737737
top_k=self.top_k,
738738
use_grouped_topk=False,
@@ -760,7 +760,7 @@ def test_grouped_topk(self, mock_topk):
760760
self.top_k,
761761
dtype=torch.long))
762762

763-
weights, ids = select_experts(hidden_states=self.hidden_states,
763+
weights, ids, _ = select_experts(hidden_states=self.hidden_states,
764764
router_logits=self.router_logits,
765765
top_k=self.top_k,
766766
use_grouped_topk=True,
@@ -780,7 +780,7 @@ def test_grouped_topk_with_correction_bias(self, mock_grouped_topk):
780780
self.num_experts)
781781

782782
e_score_correction_bias = torch.randn(self.num_experts)
783-
weights, ids = select_experts(
783+
weights, ids, _ = select_experts(
784784
hidden_states=self.hidden_states,
785785
router_logits=self.router_logits,
786786
top_k=self.top_k,
@@ -803,7 +803,7 @@ def test_custom_routing_function(self):
803803
self.top_k,
804804
dtype=torch.int32))
805805

806-
weights, ids = select_experts(
806+
weights, ids, _ = select_experts(
807807
hidden_states=self.hidden_states,
808808
router_logits=self.router_logits,
809809
top_k=self.top_k,
@@ -824,7 +824,7 @@ def test_renormalize(self, mock_topk):
824824
self.top_k,
825825
dtype=torch.long))
826826

827-
weights, _ = select_experts(
827+
weights, _, _ = select_experts(
828828
hidden_states=self.hidden_states,
829829
router_logits=self.router_logits,
830830
top_k=self.top_k,
@@ -844,7 +844,7 @@ def test_output_dtypes(self, mock_topk):
844844
self.top_k,
845845
dtype=torch.long))
846846

847-
weights, ids = select_experts(
847+
weights, ids, _ = select_experts(
848848
hidden_states=self.hidden_states,
849849
router_logits=self.router_logits,
850850
top_k=self.top_k,

vllm_ascend/ops/common_fused_moe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def forward_oot(
6868
logical_to_physical_map: Optional[torch.Tensor] = None,
6969
logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:
7070

71-
topk_weights, topk_ids = select_experts(
71+
topk_weights, topk_ids, _ = select_experts(
7272
hidden_states=x,
7373
router_logits=router_logits,
7474
top_k=top_k,

vllm_ascend/ops/fused_moe.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -759,6 +759,7 @@ def fused_experts(
759759
expert_map: torch.Tensor = None,
760760
apply_router_weight_on_input: bool = False,
761761
max_num_tokens: Optional[int] = None,
762+
row_idx: torch.Tensor = None,
762763
) -> torch.Tensor:
763764
"""
764765
Fused experts with top-k routing.
@@ -846,12 +847,13 @@ def fused_experts(
846847
# Rearrange hidden_states
847848
sorted_hidden_states = hidden_states[sorted_token_indices]
848849
else:
849-
row_idx_len = num_tokens * top_k
850-
row_idx = (torch.arange(0,
851-
row_idx_len,
852-
dtype=torch.int32,
853-
device=device).view(top_k, -1).permute(
854-
1, 0).contiguous())
850+
if row_idx is None:
851+
row_idx_len = num_tokens * top_k
852+
row_idx = (torch.arange(0,
853+
row_idx_len,
854+
dtype=torch.int32,
855+
device=device).view(top_k, -1).permute(
856+
1, 0).contiguous())
855857
active_num = max_num_tokens if max_num_tokens is not None else num_tokens
856858
sorted_hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
857859
hidden_states,
@@ -975,7 +977,7 @@ def apply(
975977
**kwargs,
976978
) -> torch.Tensor:
977979

978-
topk_weights, topk_ids = select_experts(
980+
topk_weights, topk_ids, row_idx = select_experts(
979981
hidden_states=x,
980982
router_logits=router_logits,
981983
top_k=top_k,
@@ -1020,7 +1022,8 @@ def apply(
10201022
topk_weights=topk_weights,
10211023
topk_ids=topk_ids,
10221024
top_k=top_k,
1023-
expert_map=expert_map)
1025+
expert_map=expert_map,
1026+
row_idx=row_idx)
10241027
elif MOE_ALL2ALL_BUFFER:
10251028
return fused_experts_with_all2all_buffer(
10261029
hidden_states=x,

vllm_ascend/ops/layers/experts_selector.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def select_experts(hidden_states: torch.Tensor,
5656
topk_ids: selected expert IDs of shape (num_tokens, top_k).
5757
"""
5858

59-
topk_weights, topk_ids = _select_experts_with_fusion_ops(
59+
topk_weights, topk_ids, row_idx = _select_experts_with_fusion_ops(
6060
router_logits=router_logits,
6161
top_k=top_k,
6262
use_grouped_topk=use_grouped_topk,
@@ -83,7 +83,7 @@ def select_experts(hidden_states: torch.Tensor,
8383
e_score_correction_bias=e_score_correction_bias,
8484
global_num_experts=global_num_experts,
8585
)
86-
return topk_weights, topk_ids
86+
return topk_weights, topk_ids, row_idx
8787

8888

8989
def _native_grouped_topk(
@@ -168,7 +168,7 @@ def _select_experts_with_fusion_ops(
168168
global_num_experts: int = -1,
169169
is_unquantized: bool = False):
170170

171-
topk_weights, topk_ids = None, None
171+
topk_weights, topk_ids, row_idx = None, None, None
172172
# NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern
173173
is_deepseek_v3_r1 = global_num_experts == 256
174174
if is_deepseek_v3_r1:
@@ -188,12 +188,12 @@ def _select_experts_with_fusion_ops(
188188
eps=float(1e-20))
189189

190190
if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax" and is_unquantized:
191-
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
191+
topk_weights, topk_ids, row_idx = torch_npu.npu_moe_gating_top_k_softmax(
192192
x=router_logits, finished=None, k=top_k)
193193
topk_ids = topk_ids.to(torch.int32)
194194
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
195195

196-
return topk_weights, topk_ids
196+
return topk_weights, topk_ids, row_idx
197197

198198

199199
def _native_select_experts(

vllm_ascend/quantization/w4a8_dynamic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ def apply(
245245
1] == global_num_experts, "Number of global experts mismatch"
246246

247247
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
248-
topk_weights, topk_ids = select_experts(
248+
topk_weights, topk_ids, _ = select_experts(
249249
hidden_states=x,
250250
router_logits=router_logits,
251251
top_k=top_k,

vllm_ascend/quantization/w8a8.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ def apply(
241241
assert router_logits.shape[
242242
1] == global_num_experts, "Number of global experts mismatch"
243243

244-
topk_weights, topk_ids = select_experts(
244+
topk_weights, topk_ids, _ = select_experts(
245245
hidden_states=x,
246246
router_logits=router_logits,
247247
top_k=top_k,

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -903,7 +903,7 @@ def apply(
903903
assert router_logits.shape[
904904
1] == global_num_experts, "Number of global experts mismatch"
905905

906-
topk_weights, topk_ids = select_experts(
906+
topk_weights, topk_ids, _ = select_experts(
907907
hidden_states=x,
908908
router_logits=router_logits,
909909
top_k=top_k,

0 commit comments

Comments (0)