Skip to content

Commit 8e11b17

Browse files
committed
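Remove row_idx: drop it from the select_experts return value and from the fused_experts / token_dispatch parameters, and delete the return_row_idx helper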
Signed-off-by: CaranLic <[email protected]>
1 parent 688bd57 commit 8e11b17

File tree

12 files changed, with 19 additions and 57 deletions.

tests/e2e/singlecard/ops/test_fused_moe.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,6 @@ def test_token_dispatcher_with_all_gather(
137137
hidden_states=a,
138138
topk_weights=topk_weights,
139139
topk_ids=topk_ids,
140-
row_idx=row_idx,
141140
expert_map=expert_map,
142141
apply_router_weight_on_input=apply_router_weight_on_input)
143142

@@ -220,7 +219,6 @@ def test_token_dispatcher_with_all_gather_quant(
220219
hidden_states=a,
221220
topk_weights=topk_weights,
222221
topk_ids=topk_ids,
223-
row_idx=row_idx,
224222
expert_map=expert_map,
225223
apply_router_weight_on_input=apply_router_weight_on_input,
226224
with_quant=True)
@@ -295,7 +293,7 @@ def test_select_experts(
295293
mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
296294
x)
297295

298-
topk_weights, topk_ids, row_idx = select_experts(
296+
topk_weights, topk_ids = select_experts(
299297
hidden_states=hidden_states,
300298
router_logits=router_logits,
301299
top_k=topk,
@@ -316,7 +314,6 @@ def test_select_experts(
316314
assert topk_weights.shape == (m, topk)
317315
assert topk_ids.shape == (m, topk)
318316
assert topk_ids.dtype == torch.int32
319-
assert row_idx.shape == (m, topk)
320317

321318
gc.collect()
322319
torch.npu.empty_cache()

tests/ut/ops/test_fused_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,7 @@ def test_select_experts(self, mock_dist_env, mock_moe_env,
454454

455455
x = torch.randn(8, 2)
456456
router_logits = torch.randn(8, 2)
457-
topk_weights, topk_ids, _ = select_experts(
457+
topk_weights, topk_ids = select_experts(
458458
hidden_states=x,
459459
router_logits=router_logits,
460460
top_k=2,

tests/ut/ops/test_moe_comm_method.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,6 @@ def test_fused_experts_method(self, mock_unified_apply_mlp,
216216
w2=w2,
217217
topk_weights=topk_weights,
218218
topk_ids=topk_ids,
219-
row_idx=row_idx,
220219
activation="silu")
221220

222221
# Verify result shape

tests/ut/ops/test_token_dispatcher.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ def setUp(self):
5858

5959
kwargs = {"with_quant": False, "top_k": 8, "num_experts": 128}
6060
self.dispatcher = TokenDispatcherWithMC2(**kwargs)
61-
self.row_idx = torch.arange(10, dtype=torch.int32)
6261

6362
def tearDown(self):
6463
self.mc2_group_patch.stop()
@@ -95,7 +94,7 @@ def test_token_permutation_dispatch(self):
9594
return_value=(torch.randn(10, 128), ) * 5) as mock_dispatch:
9695
output = self.dispatcher.token_dispatch(hidden_states,
9796
topk_weights, topk_ids,
98-
self.row_idx, expert_map)
97+
expert_map)
9998
mock_dispatch.assert_called_once()
10099
self.assertEqual(output["group_list_type"],
101100
0) # group_list_type == 0
@@ -116,7 +115,6 @@ def test_token_dispatch_with_shared_experts_and_quant(self):
116115
self.dispatcher.token_dispatch(self.hidden_states,
117116
self.topk_weights,
118117
torch.randint(0, 8, (10, 1)),
119-
self.row_idx,
120118
torch.tensor(
121119
[0, 1, 2, 3, 4, 5, 6, 7]),
122120
shared_experts=self.shared_experts)
@@ -180,7 +178,6 @@ def setUp(self):
180178
torch.tensor([0, 1, 2, 3, 4, 5]), # expanded_row_idx
181179
torch.tensor([0, 1, 0, 1, 0, 1]), # expanded_expert_idx
182180
torch.tensor([0, 1, 0, 1, 0, 1]))
183-
self.row_idx = torch.arange(10, dtype=torch.int32)
184181
self.patcher_npu_moe_token_unpermute = patch(
185182
'torch_npu.npu_moe_token_unpermute')
186183
self.mock_npu_moe_token_unpermute = self.patcher_npu_moe_token_unpermute.start(
@@ -197,7 +194,7 @@ def test_token_dispatch_without_expert_map(self):
197194
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
198195

199196
results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
200-
topk_ids, self.row_idx, None)
197+
topk_ids, None)
201198

202199
# Verify npu_moe_init_routing is called
203200
self.mock_npu_moe_init_routing_v2.assert_called_once()
@@ -212,7 +209,7 @@ def test_token_dispatch_with_expert_map(self):
212209
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
213210

214211
results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
215-
topk_ids, self.row_idx, None)
212+
topk_ids, None)
216213

217214
# Verify npu_moe_init_routing is called
218215
self.mock_npu_moe_init_routing_v2.assert_called_once()
@@ -236,7 +233,7 @@ def test_token_dispatch_without_quant(self):
236233

237234
results = self.dispatcher_quant.token_dispatch(hidden_states,
238235
topk_weights, topk_ids,
239-
self.row_idx, None)
236+
None)
240237

241238
self.assertEqual(results["group_list_type"], 1)
242239

@@ -257,7 +254,6 @@ def test_token_dispatch_with_quant(self):
257254
results = self.dispatcher_quant.token_dispatch(hidden_states,
258255
topk_weights,
259256
topk_ids,
260-
self.row_idx,
261257
None,
262258
with_quant=True)
263259

@@ -399,7 +395,6 @@ def setUp(self):
399395
num_experts=4,
400396
num_local_experts=2,
401397
with_quant=False)
402-
self.row_idx = torch.arange(10, dtype=torch.int32)
403398

404399
def test_token_dispatch(self):
405400
hidden_states = torch.randn(8, 16)
@@ -414,7 +409,6 @@ def test_token_dispatch(self):
414409
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
415410
topk_weights=topk_weights,
416411
topk_ids=topk_ids,
417-
row_idx=self.row_idx,
418412
expert_map=expert_map)
419413

420414
self.assertIsNotNone(result["hidden_states"])
@@ -461,7 +455,6 @@ def test_token_dispatch_with_quant(self):
461455
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
462456
topk_weights=topk_weights,
463457
topk_ids=topk_ids,
464-
row_idx=self.row_idx,
465458
expert_map=expert_map,
466459
with_quant=True)
467460

@@ -490,7 +483,6 @@ def test_token_dispatch_with_quant_no_active_tokens(self):
490483
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
491484
topk_weights=topk_weights,
492485
topk_ids=topk_ids,
493-
row_idx=self.row_idx,
494486
expert_map=expert_map,
495487
with_quant=True)
496488

@@ -513,7 +505,6 @@ def test_token_dispatch_with_log2phy(self):
513505
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
514506
topk_weights=topk_weights,
515507
topk_ids=topk_ids,
516-
row_idx=self.row_idx,
517508
expert_map=expert_map,
518509
log2phy=log2phy)
519510

tests/ut/quantization/test_w8a8.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -730,7 +730,7 @@ def test_softmax_scoring(self, mock_topk):
730730
-1).permute(1,
731731
0).contiguous())
732732

733-
weights, ids, _ = select_experts(hidden_states=self.hidden_states,
733+
weights, ids = select_experts(hidden_states=self.hidden_states,
734734
router_logits=self.router_logits,
735735
top_k=self.top_k,
736736
use_grouped_topk=False,
@@ -743,7 +743,7 @@ def test_softmax_scoring(self, mock_topk):
743743
def test_sigmoid_scoring(self):
744744
"""Test sigmoid scoring function"""
745745

746-
weights, ids, _ = select_experts(hidden_states=self.hidden_states,
746+
weights, ids = select_experts(hidden_states=self.hidden_states,
747747
router_logits=self.router_logits,
748748
top_k=self.top_k,
749749
use_grouped_topk=False,
@@ -771,7 +771,7 @@ def test_grouped_topk(self, mock_topk):
771771
self.top_k,
772772
dtype=torch.long))
773773

774-
weights, ids, _ = select_experts(hidden_states=self.hidden_states,
774+
weights, ids = select_experts(hidden_states=self.hidden_states,
775775
router_logits=self.router_logits,
776776
top_k=self.top_k,
777777
use_grouped_topk=True,
@@ -791,7 +791,7 @@ def test_grouped_topk_with_correction_bias(self, mock_grouped_topk):
791791
self.num_experts)
792792

793793
e_score_correction_bias = torch.randn(self.num_experts)
794-
weights, ids, _ = select_experts(
794+
weights, ids = select_experts(
795795
hidden_states=self.hidden_states,
796796
router_logits=self.router_logits,
797797
top_k=self.top_k,
@@ -814,7 +814,7 @@ def test_custom_routing_function(self):
814814
self.top_k,
815815
dtype=torch.int32))
816816

817-
weights, ids, _ = select_experts(
817+
weights, ids = select_experts(
818818
hidden_states=self.hidden_states,
819819
router_logits=self.router_logits,
820820
top_k=self.top_k,
@@ -841,7 +841,7 @@ def test_renormalize(self, mock_topk):
841841
-1).permute(1,
842842
0).contiguous())
843843

844-
weights, ids, _ = select_experts(
844+
weights, ids = select_experts(
845845
hidden_states=self.hidden_states,
846846
router_logits=self.router_logits,
847847
top_k=self.top_k,
@@ -867,7 +867,7 @@ def test_output_dtypes(self, mock_topk):
867867
-1).permute(1,
868868
0).contiguous())
869869

870-
weights, ids, _ = select_experts(
870+
weights, ids = select_experts(
871871
hidden_states=self.hidden_states,
872872
router_logits=self.router_logits,
873873
top_k=self.top_k,

vllm_ascend/ops/common_fused_moe.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def forward_oot(
8181
logical_to_physical_map: Optional[torch.Tensor] = None,
8282
logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:
8383

84-
topk_weights, topk_ids, row_idx = select_experts(
84+
topk_weights, topk_ids = select_experts(
8585
hidden_states=x,
8686
router_logits=router_logits,
8787
top_k=top_k,
@@ -101,7 +101,6 @@ def forward_oot(
101101
w2=layer.w2_weight,
102102
topk_weights=topk_weights,
103103
topk_ids=topk_ids,
104-
row_idx=row_idx,
105104
global_num_experts=global_num_experts,
106105
expert_map=expert_map)
107106

vllm_ascend/ops/fused_moe.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def apply(
105105
**kwargs,
106106
) -> torch.Tensor:
107107

108-
topk_weights, topk_ids, row_idx = select_experts(
108+
topk_weights, topk_ids = select_experts(
109109
hidden_states=x,
110110
router_logits=router_logits,
111111
top_k=top_k,
@@ -132,7 +132,6 @@ def apply(
132132
w2=layer.w2_weight,
133133
topk_weights=topk_weights,
134134
topk_ids=topk_ids,
135-
row_idx=row_idx,
136135
global_num_experts=global_num_experts,
137136
expert_map=expert_map,
138137
shared_experts=shared_experts,

vllm_ascend/ops/moe/experts_selector.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,6 @@
2020
import torch_npu
2121

2222

23-
def return_row_idx(hidden_states, top_k):
24-
num_tokens = hidden_states.shape[0]
25-
row_idx_len = num_tokens * top_k
26-
row_idx = (torch.arange(0,
27-
row_idx_len,
28-
dtype=torch.int32,
29-
device=hidden_states.device).view(
30-
top_k, -1).permute(1, 0).contiguous())
31-
return row_idx
32-
33-
3423
def select_experts(hidden_states: torch.Tensor,
3524
router_logits: torch.Tensor,
3625
top_k: int,
@@ -66,7 +55,7 @@ def select_experts(hidden_states: torch.Tensor,
6655
topk_ids: selected expert IDs of shape (num_tokens, top_k).
6756
"""
6857

69-
topk_weights, topk_ids, row_idx = None, None, None
58+
topk_weights, topk_ids = None, None
7059
if custom_routing_function is None:
7160
topk_weights, topk_ids = _select_experts_with_fusion_ops(
7261
hidden_states=hidden_states,
@@ -94,9 +83,7 @@ def select_experts(hidden_states: torch.Tensor,
9483
e_score_correction_bias=e_score_correction_bias,
9584
global_num_experts=global_num_experts,
9685
)
97-
if row_idx is None:
98-
row_idx = return_row_idx(hidden_states, top_k)
99-
return topk_weights, topk_ids, row_idx
86+
return topk_weights, topk_ids
10087

10188

10289
def _native_grouped_topk(

vllm_ascend/ops/moe/moe_comm_method.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,6 @@ def fused_experts(
8989
w2: torch.Tensor,
9090
topk_weights: torch.Tensor,
9191
topk_ids: torch.Tensor,
92-
row_idx: torch.Tensor,
9392
activation: str = "silu",
9493
apply_router_weight_on_input: bool = False,
9594
use_int8_w8a8: bool = False,
@@ -123,7 +122,6 @@ def fused_experts(
123122
hidden_states=hidden_states,
124123
topk_weights=topk_weights,
125124
topk_ids=topk_ids,
126-
row_idx=row_idx,
127125
expert_map=expert_map,
128126
log2phy=log2phy,
129127
global_redundant_expert_num=global_redundant_expert_num,

vllm_ascend/ops/moe/token_dispatcher.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ def token_dispatch(self,
6161
hidden_states: torch.Tensor,
6262
topk_weights: torch.Tensor,
6363
topk_ids: torch.Tensor,
64-
row_idx: torch.Tensor,
6564
expert_map: Optional[torch.Tensor] = None,
6665
log2phy: Optional[torch.Tensor] = None,
6766
global_redundant_expert_num: int = 0,
@@ -161,7 +160,6 @@ def token_dispatch(self,
161160
hidden_states: torch.Tensor,
162161
topk_weights: torch.Tensor,
163162
topk_ids: torch.Tensor,
164-
row_idx: torch.Tensor,
165163
expert_map: Optional[torch.Tensor] = None,
166164
log2phy: Optional[torch.Tensor] = None,
167165
global_redundant_expert_num: int = 0,
@@ -318,7 +316,6 @@ def token_dispatch(self,
318316
hidden_states: torch.Tensor,
319317
topk_weights: torch.Tensor,
320318
topk_ids: torch.Tensor,
321-
row_idx: torch.Tensor,
322319
expert_map: Optional[torch.Tensor] = None,
323320
log2phy: Optional[torch.Tensor] = None,
324321
global_redundant_expert_num: int = 0,
@@ -410,7 +407,6 @@ def token_dispatch(self,
410407
hidden_states: torch.Tensor,
411408
topk_weights: torch.Tensor,
412409
topk_ids: torch.Tensor,
413-
row_idx: torch.Tensor,
414410
expert_map: Optional[torch.Tensor] = None,
415411
log2phy: Optional[torch.Tensor] = None,
416412
global_redundant_expert_num: int = 0,
@@ -511,7 +507,6 @@ def token_dispatch(self,
511507
hidden_states: torch.Tensor,
512508
topk_weights: torch.Tensor,
513509
topk_ids: torch.Tensor,
514-
row_idx: torch.Tensor,
515510
expert_map: Optional[torch.Tensor] = None,
516511
log2phy: Optional[torch.Tensor] = None,
517512
global_redundant_expert_num: int = 0,

0 commit comments

Comments
 (0)