Commit 3f5eb78

Author: wangxiaoxin-sherie (committed)
Commit message: xx
1 parent 891282f · commit 3f5eb78

File tree: 1 file changed (+6 −24 lines)


tests/ut/ops/test_token_dispatcher.py

Lines changed: 6 additions & 24 deletions
@@ -73,7 +73,6 @@ def test_initialization(self, dispatcher, config):
 class TestTokenDispatcherWithMC2(unittest.TestCase):
 
     def setUp(self):
-        # Mock get_mc2_group() to return a fixed value
         self.mc2_group = mock.MagicMock()
         self.mc2_group.device_group.return_value._get_backend.return_value.get_hccl_comm_name.return_value = "hccl_123"
         self.mc2_group.rank_in_group = 0
@@ -110,7 +109,6 @@ def setUp(self):
         self.ascend_config_patch.start()
 
         kwargs = {"with_quant": False, "top_k": 8, "num_experts": 128}
-        # Initialize the TokenDispatcherWithMC2 instance
         self.dispatcher = TokenDispatcherWithMC2(**kwargs)
 
     def tearDown(self):
@@ -120,7 +118,6 @@ def tearDown(self):
         self.ascend_config_patch.stop()
 
     def test_init(self):
-        """Test the __init__ initialization behavior"""
         # self.assertEqual(self.dispatcher.moe_all_to_all_group_name, "hccl_123")
         self.assertEqual(self.dispatcher.ep_rank_id, 0)
         self.assertEqual(self.dispatcher.ep_world_size, 8)
@@ -131,7 +128,6 @@ def test_init(self):
         self.assertTrue(self.dispatcher.a3_need_extra_args)
 
     def test_get_permute_mc2_kwargs_without_quant(self):
-        """Test get_permute_mc2_kwargs (without quantization)"""
         hidden_states = torch.randn(10, 128)
         topk_ids = torch.randint(0, 8, (10, 1))
         topk_weights = torch.randn(10, 1)
@@ -144,7 +140,6 @@ def test_get_permute_mc2_kwargs_without_quant(self):
         self.assertEqual(kwargs["moe_expert_num"], 8)
 
     def test_token_permutation_dispatch(self):
-        """Test token_permutation (using dispatch)"""
         hidden_states = torch.randn(10, 128)
         topk_weights = torch.randn(10, 1)
         topk_ids = torch.randint(0, 8, (10, 1))
@@ -160,7 +155,6 @@ def test_token_permutation_dispatch(self):
         self.assertEqual(output[0], 1)  # group_list_type == 1
 
     def test_token_permutation_with_shared_experts_and_quant(self):
-        """Test token_permutation (with shared_experts and with_quant=True)"""
         self.shared_experts = mock.MagicMock()
         self.shared_experts.gate_up_proj.return_value = (torch.randn(10, 128),
                                                          torch.tensor(1.0))
@@ -189,7 +183,6 @@ def test_token_permutation_with_shared_experts_and_quant(self):
                                            self.topk_weights)
 
     def test_get_unpermute_mc_kwargs_with_quant(self):
-        """Test get_unpermute_mc_kwargs (with_quant=True)"""
         self.dispatcher.with_quant = True
         hidden_states = torch.randn(10, 128)
         self.dispatcher.topk_ids = torch.randint(0, 8, (10, 1))
@@ -198,6 +191,7 @@ def test_get_unpermute_mc_kwargs_with_quant(self):
         self.dispatcher.ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
         self.dispatcher.need_extra_args = True
         self.dispatcher.enable_dispatch_v2 = True
+        self.dispatcher.output = torch.randint(0, 8, (10, 1))
 
         kwargs = self.dispatcher.get_unpermute_mc_kwargs(hidden_states)
         self.assertIn("tp_send_counts", kwargs)
@@ -215,6 +209,7 @@ def test_token_unpermutation_with_shared_experts(self):
         self.dispatcher.need_extra_args = True
         self.dispatcher.enable_dispatch_v2 = True
         self.dispatcher.swiglu_out_scale = torch.randint(0, 8, (10, 1))
+        self.dispatcher.output = torch.randint(0, 8, (10, 1))
         self.hidden_states = torch.randn(10, 128)
 
         with mock.patch("torch_npu.npu_moe_distribute_combine_v2",
@@ -270,23 +265,6 @@ def tearDown(self):
         self.patcher_moe_compute_expert_tokens.stop()
         self.patcher_moe_finalize_routing.stop()
 
-    def test_token_permutation_with_expert_map(self):
-        self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
-        hidden_states = torch.randn(3, 128)
-        topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
-        topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
-
-        group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher.token_permutation(
-            hidden_states, topk_weights, topk_ids, self.dispatcher.expert_map)
-
-        # Verify expert_map logic is used
-        self.assertEqual(group_list_type, 0)
-        self.assertTrue(sorted_hidden_states.shape, (6, 128))
-
-        # Check if sorting and filtering were applied
-        self.assertIsNotNone(self.dispatcher.sorted_token_indices)
-        self.assertIsNotNone(self.dispatcher.sorted_weights)
-
     def test_token_permutation_without_expert_map(self):
         hidden_states = torch.randn(3, 128)
         topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
@@ -341,7 +319,11 @@ def test_token_unpermutation_without_expert_map(self):
         self.dispatcher.with_quant = False
         self.dispatcher.expanded_row_idx = torch.tensor([0, 1, 1, 1, 1, 1])
         self.dispatcher.topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
+        self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1])
+        self.dispatcher.sorted_weights = torch.tensor(
+            [0.5, 0.5, 0.5, 0.5, 0.5, 0.5])
         self.dispatcher.original_shape = (3, 128)
+        self.dispatcher.mask = torch.tensor([0, 1, 1, 0])
         hidden_states = torch.randn(6, 128)
 
         final_hidden_states = self.dispatcher.token_unpermutation(
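Side note: the setUp hunk above stubs the HCCL communicator name through a single chained MagicMock assignment. Below is a minimal, self-contained sketch of that stubbing pattern using only the standard-library unittest.mock; the mc2_group variable is illustrative here and not tied to the repository's fixtures.

from unittest import mock

# A MagicMock auto-creates child mocks for attribute access and calls, so one
# assignment at the end of the chain fixes the value the test will observe.
mc2_group = mock.MagicMock()
mc2_group.device_group.return_value._get_backend.return_value \
    .get_hccl_comm_name.return_value = "hccl_123"

# Attribute -> call -> method -> call resolves to the configured string,
# mirroring how the dispatcher would read the communicator name in setUp.
assert mc2_group.device_group()._get_backend().get_hccl_comm_name() == "hccl_123"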
