@@ -58,7 +58,6 @@ def setUp(self):
58
58
59
59
kwargs = {"with_quant" : False , "top_k" : 8 , "num_experts" : 128 }
60
60
self .dispatcher = TokenDispatcherWithMC2 (** kwargs )
61
- self .row_idx = torch .arange (10 , dtype = torch .int32 )
62
61
63
62
def tearDown (self ):
64
63
self .mc2_group_patch .stop ()
@@ -95,7 +94,7 @@ def test_token_permutation_dispatch(self):
95
94
return_value = (torch .randn (10 , 128 ), ) * 5 ) as mock_dispatch :
96
95
output = self .dispatcher .token_dispatch (hidden_states ,
97
96
topk_weights , topk_ids ,
98
- self . row_idx , expert_map )
97
+ expert_map )
99
98
mock_dispatch .assert_called_once ()
100
99
self .assertEqual (output ["group_list_type" ],
101
100
0 ) # group_list_type == 0
@@ -116,7 +115,6 @@ def test_token_dispatch_with_shared_experts_and_quant(self):
116
115
self .dispatcher .token_dispatch (self .hidden_states ,
117
116
self .topk_weights ,
118
117
torch .randint (0 , 8 , (10 , 1 )),
119
- self .row_idx ,
120
118
torch .tensor (
121
119
[0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ]),
122
120
shared_experts = self .shared_experts )
@@ -180,7 +178,6 @@ def setUp(self):
180
178
torch .tensor ([0 , 1 , 2 , 3 , 4 , 5 ]), # expanded_row_idx
181
179
torch .tensor ([0 , 1 , 0 , 1 , 0 , 1 ]), # expanded_expert_idx
182
180
torch .tensor ([0 , 1 , 0 , 1 , 0 , 1 ]))
183
- self .row_idx = torch .arange (10 , dtype = torch .int32 )
184
181
self .patcher_npu_moe_token_unpermute = patch (
185
182
'torch_npu.npu_moe_token_unpermute' )
186
183
self .mock_npu_moe_token_unpermute = self .patcher_npu_moe_token_unpermute .start (
@@ -197,7 +194,7 @@ def test_token_dispatch_without_expert_map(self):
197
194
topk_ids = torch .tensor ([[0 , 1 ], [1 , 2 ], [2 , 3 ]])
198
195
199
196
results = self .dispatcher .token_dispatch (hidden_states , topk_weights ,
200
- topk_ids , self . row_idx , None )
197
+ topk_ids , None )
201
198
202
199
# Verify npu_moe_init_routing is called
203
200
self .mock_npu_moe_init_routing_v2 .assert_called_once ()
@@ -212,7 +209,7 @@ def test_token_dispatch_with_expert_map(self):
212
209
topk_ids = torch .tensor ([[0 , 1 ], [1 , 2 ], [2 , 3 ]])
213
210
214
211
results = self .dispatcher .token_dispatch (hidden_states , topk_weights ,
215
- topk_ids , self . row_idx , None )
212
+ topk_ids , None )
216
213
217
214
# Verify npu_moe_init_routing is called
218
215
self .mock_npu_moe_init_routing_v2 .assert_called_once ()
@@ -236,7 +233,7 @@ def test_token_dispatch_without_quant(self):
236
233
237
234
results = self .dispatcher_quant .token_dispatch (hidden_states ,
238
235
topk_weights , topk_ids ,
239
- self . row_idx , None )
236
+ None )
240
237
241
238
self .assertEqual (results ["group_list_type" ], 1 )
242
239
@@ -257,7 +254,6 @@ def test_token_dispatch_with_quant(self):
257
254
results = self .dispatcher_quant .token_dispatch (hidden_states ,
258
255
topk_weights ,
259
256
topk_ids ,
260
- self .row_idx ,
261
257
None ,
262
258
with_quant = True )
263
259
@@ -399,7 +395,6 @@ def setUp(self):
399
395
num_experts = 4 ,
400
396
num_local_experts = 2 ,
401
397
with_quant = False )
402
- self .row_idx = torch .arange (10 , dtype = torch .int32 )
403
398
404
399
def test_token_dispatch (self ):
405
400
hidden_states = torch .randn (8 , 16 )
@@ -414,7 +409,6 @@ def test_token_dispatch(self):
414
409
result = self .dispatcher .token_dispatch (hidden_states = hidden_states ,
415
410
topk_weights = topk_weights ,
416
411
topk_ids = topk_ids ,
417
- row_idx = self .row_idx ,
418
412
expert_map = expert_map )
419
413
420
414
self .assertIsNotNone (result ["hidden_states" ])
@@ -461,7 +455,6 @@ def test_token_dispatch_with_quant(self):
461
455
result = self .dispatcher .token_dispatch (hidden_states = hidden_states ,
462
456
topk_weights = topk_weights ,
463
457
topk_ids = topk_ids ,
464
- row_idx = self .row_idx ,
465
458
expert_map = expert_map ,
466
459
with_quant = True )
467
460
@@ -490,7 +483,6 @@ def test_token_dispatch_with_quant_no_active_tokens(self):
490
483
result = self .dispatcher .token_dispatch (hidden_states = hidden_states ,
491
484
topk_weights = topk_weights ,
492
485
topk_ids = topk_ids ,
493
- row_idx = self .row_idx ,
494
486
expert_map = expert_map ,
495
487
with_quant = True )
496
488
@@ -513,7 +505,6 @@ def test_token_dispatch_with_log2phy(self):
513
505
result = self .dispatcher .token_dispatch (hidden_states = hidden_states ,
514
506
topk_weights = topk_weights ,
515
507
topk_ids = topk_ids ,
516
- row_idx = self .row_idx ,
517
508
expert_map = expert_map ,
518
509
log2phy = log2phy )
519
510
0 commit comments