@@ -73,7 +73,6 @@ def test_initialization(self, dispatcher, config):
class TestTokenDispatcherWithMC2(unittest.TestCase):

    def setUp(self):
-        # Mock get_mc2_group() to return a fixed value
        self.mc2_group = mock.MagicMock()
        self.mc2_group.device_group.return_value._get_backend.return_value.get_hccl_comm_name.return_value = "hccl_123"
        self.mc2_group.rank_in_group = 0
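An aside on the mocking style above: MagicMock auto-creates child mocks on attribute access, which is what lets a single return_value assignment configure the whole device_group/_get_backend/get_hccl_comm_name chain. A self-contained sketch of the mechanism:

from unittest import mock

# MagicMock builds child mocks on attribute access, so one assignment
# configures the entire call chain end to end.
mc2_group = mock.MagicMock()
mc2_group.device_group.return_value._get_backend.return_value \
    .get_hccl_comm_name.return_value = "hccl_123"

backend = mc2_group.device_group()._get_backend()
assert backend.get_hccl_comm_name() == "hccl_123"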
@@ -110,7 +109,6 @@ def setUp(self):
        self.ascend_config_patch.start()

        kwargs = {"with_quant": False, "top_k": 8, "num_experts": 128}
-        # Initialize the TokenDispatcherWithMC2 instance
        self.dispatcher = TokenDispatcherWithMC2(**kwargs)

    def tearDown(self):
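setUp and tearDown follow the standard patcher lifecycle: every patcher started in setUp must be stopped in tearDown so the patched attribute is restored between tests. A minimal sketch of the pattern, using os.getcwd as a stand-in target since the real patch target is project-specific:

import unittest
from unittest import mock

class PatcherLifecycleExample(unittest.TestCase):
    def setUp(self):
        # Stand-in target; the real tests patch project internals.
        self.config_patch = mock.patch("os.getcwd", return_value="/fake")
        self.config_patch.start()

    def tearDown(self):
        # stop() restores the original attribute, preventing leakage
        # into later tests.
        self.config_patch.stop()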
@@ -120,7 +118,6 @@ def tearDown(self):
        self.ascend_config_patch.stop()

    def test_init(self):
-        """Test __init__ initialization behavior"""
        # self.assertEqual(self.dispatcher.moe_all_to_all_group_name, "hccl_123")
        self.assertEqual(self.dispatcher.ep_rank_id, 0)
        self.assertEqual(self.dispatcher.ep_world_size, 8)
@@ -131,7 +128,6 @@ def test_init(self):
        self.assertTrue(self.dispatcher.a3_need_extra_args)

    def test_get_permute_mc2_kwargs_without_quant(self):
-        """Test get_permute_mc2_kwargs (without quantization)"""
        hidden_states = torch.randn(10, 128)
        topk_ids = torch.randint(0, 8, (10, 1))
        topk_weights = torch.randn(10, 1)
@@ -144,7 +140,6 @@ def test_get_permute_mc2_kwargs_without_quant(self):
        self.assertEqual(kwargs["moe_expert_num"], 8)

    def test_token_permutation_dispatch(self):
-        """Test token_permutation (using dispatch)"""
        hidden_states = torch.randn(10, 128)
        topk_weights = torch.randn(10, 1)
        topk_ids = torch.randint(0, 8, (10, 1))
@@ -160,7 +155,6 @@ def test_token_permutation_dispatch(self):
        self.assertEqual(output[0], 1)  # group_list_type == 1

    def test_token_permutation_with_shared_experts_and_quant(self):
-        """Test token_permutation (with shared_experts and with_quant=True)"""
        self.shared_experts = mock.MagicMock()
        self.shared_experts.gate_up_proj.return_value = (torch.randn(10, 128),
                                                         torch.tensor(1.0))
@@ -189,7 +183,6 @@ def test_token_permutation_with_shared_experts_and_quant(self):
            self.topk_weights)

    def test_get_unpermute_mc_kwargs_with_quant(self):
-        """Test get_unpermute_mc_kwargs (with_quant=True)"""
        self.dispatcher.with_quant = True
        hidden_states = torch.randn(10, 128)
        self.dispatcher.topk_ids = torch.randint(0, 8, (10, 1))
@@ -198,6 +191,7 @@ def test_get_unpermute_mc_kwargs_with_quant(self):
        self.dispatcher.ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
        self.dispatcher.need_extra_args = True
        self.dispatcher.enable_dispatch_v2 = True
+        self.dispatcher.output = torch.randint(0, 8, (10, 1))

        kwargs = self.dispatcher.get_unpermute_mc_kwargs(hidden_states)
        self.assertIn("tp_send_counts", kwargs)
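The next hunk patches torch_npu.npu_moe_distribute_combine_v2 by dotted path. The same mock.patch mechanism works for any module-level callable; a tiny sketch using json.dumps as a stand-in target, since torch_npu may not be importable off-device:

import json
from unittest import mock

# Patching by dotted path swaps the callable only inside the with-block.
with mock.patch("json.dumps", return_value="stubbed"):
    assert json.dumps({"a": 1}) == "stubbed"

# The original function is restored on exit.
assert json.dumps({"a": 1}) == '{"a": 1}'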
@@ -215,6 +209,7 @@ def test_token_unpermutation_with_shared_experts(self):
        self.dispatcher.need_extra_args = True
        self.dispatcher.enable_dispatch_v2 = True
        self.dispatcher.swiglu_out_scale = torch.randint(0, 8, (10, 1))
+        self.dispatcher.output = torch.randint(0, 8, (10, 1))
        self.hidden_states = torch.randn(10, 128)

        with mock.patch("torch_npu.npu_moe_distribute_combine_v2",
@@ -270,23 +265,6 @@ def tearDown(self):
        self.patcher_moe_compute_expert_tokens.stop()
        self.patcher_moe_finalize_routing.stop()

-    def test_token_permutation_with_expert_map(self):
-        self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
-        hidden_states = torch.randn(3, 128)
-        topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
-        topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
-
-        group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher.token_permutation(
-            hidden_states, topk_weights, topk_ids, self.dispatcher.expert_map)
-
-        # Verify expert_map logic is used
-        self.assertEqual(group_list_type, 0)
-        self.assertTrue(sorted_hidden_states.shape, (6, 128))
-
-        # Check if sorting and filtering were applied
-        self.assertIsNotNone(self.dispatcher.sorted_token_indices)
-        self.assertIsNotNone(self.dispatcher.sorted_weights)
-
    def test_token_permutation_without_expert_map(self):
        hidden_states = torch.randn(3, 128)
        topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
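The removed test exercised expert-map permutation: each token is replicated once per selected expert, and rows are sorted by expert id so every expert receives a contiguous slab. A framework-only sketch of that idea with the same shapes as the test; this is illustrative, not the dispatcher's actual implementation:

import torch

hidden_states = torch.randn(3, 128)
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])  # 3 tokens, top-2 experts

# Replicate each token once per chosen expert, then sort by expert id.
num_tokens, top_k = topk_ids.shape
flat_expert_ids = topk_ids.reshape(-1)                        # (6,)
token_index = torch.arange(num_tokens).repeat_interleave(top_k)
order = torch.argsort(flat_expert_ids, stable=True)
sorted_hidden_states = hidden_states[token_index[order]]      # (6, 128)

# Per-expert token counts feed the grouped matmul downstream.
expert_tokens = torch.bincount(flat_expert_ids, minlength=4)
assert sorted_hidden_states.shape == (6, 128)
assert expert_tokens.tolist() == [1, 2, 2, 1]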
@@ -341,7 +319,11 @@ def test_token_unpermutation_without_expert_map(self):
        self.dispatcher.with_quant = False
        self.dispatcher.expanded_row_idx = torch.tensor([0, 1, 1, 1, 1, 1])
        self.dispatcher.topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
+        self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1])
+        self.dispatcher.sorted_weights = torch.tensor(
+            [0.5, 0.5, 0.5, 0.5, 0.5, 0.5])
        self.dispatcher.original_shape = (3, 128)
+        self.dispatcher.mask = torch.tensor([0, 1, 1, 0])
        hidden_states = torch.randn(6, 128)

        final_hidden_states = self.dispatcher.token_unpermutation(
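For symmetry with the permutation sketch above, the unpermute step scatters expert outputs back into token order and applies the routing weights. A minimal sketch under the same illustrative assumptions (the variable names mirror the test's attributes, but the semantics here are assumed, not the dispatcher's real code path):

import torch

expert_output = torch.randn(6, 128)        # rows in expert-sorted order
sorted_token_indices = torch.tensor([0, 0, 1, 1, 2, 2])
sorted_weights = torch.tensor([0.7, 0.3, 0.6, 0.4, 0.5, 0.5])

# Weight each expanded row, then accumulate rows belonging to the same
# original token with index_add_.
final_hidden_states = torch.zeros(3, 128)
final_hidden_states.index_add_(0, sorted_token_indices,
                               expert_output * sorted_weights.unsqueeze(1))
assert final_hidden_states.shape == (3, 128)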