Commit 6aff3e9

Author: wangxiaoxin-sherie

refactor mc2/allgather tokendispatch.

Signed-off-by: wangxiaoxin-sherie <[email protected]>

1 parent: 3e65c40

File tree

2 files changed: +821 -2 lines

tests/ut/ops/test_token_dispatcher.py

Lines changed: 311 additions & 1 deletion
@@ -15,12 +15,17 @@
 # limitations under the License.
 # This file is a part of the vllm-ascend project.

+import unittest
+from unittest import mock
+
 import pytest
+import torch
 from pytest_mock import MockerFixture

 from tests.ut.base import PytestBase
 from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
-    MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
+    AscendSocVersion, MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig,
+    TokenDispatcherWithAllGather, TokenDispatcherWithMC2)
 from vllm_ascend.utils import adapt_patch  # noqa E402


@@ -63,3 +68,308 @@ def test_initialization(self, dispatcher, config):
         assert dispatcher.ep_rank == 0
         assert dispatcher.ep_size == 2
         assert dispatcher.overlap_stream is not None
+
+
+class TestTokenDispatcherWithMC2(unittest.TestCase):
+
+    def setUp(self):
+        # Mock get_mc2_group() to return a fixed value
+        self.mc2_group = mock.MagicMock()
+        self.mc2_group.device_group.return_value._get_backend.return_value.get_hccl_comm_name.return_value = "hccl_123"
+        self.mc2_group.rank_in_group = 0
+        self.mc2_group.world_size = 8
+        self.mc2_group_patch = mock.patch(
+            "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_mc2_group",
+            return_value=self.mc2_group)
+        self.mc2_group_patch.start()
+
+        self.rank_group_patch = mock.patch("torch.distributed.get_rank",
+                                           return_value=0)
+        self.rank_group_patch.start()
+
+        # Mock get_forward_context().mc2_mask
+        self.forward_context = mock.MagicMock()
+        self.forward_context.mc2_mask = torch.tensor([1, 0, 1])
+        self.forward_context_patch = mock.patch(
+            "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_forward_context",
+            return_value=self.forward_context)
+        self.forward_context_patch.start()
+
+        # Mock get_ascend_soc_version()
+        self.ascend_soc_version_patch = mock.patch(
+            "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ascend_soc_version",
+            return_value=AscendSocVersion.A3)
+        self.ascend_soc_version_patch.start()
+
+        # Mock get_ascend_config()
+        self.ascend_config = mock.MagicMock()
+        self.ascend_config.torchair_graph_config.enabled = False
+        self.ascend_config_patch = mock.patch(
+            "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ascend_config",
+            return_value=self.ascend_config)
+        self.ascend_config_patch.start()
+
+        kwargs = {"with_quant": False, "top_k": 8, "num_experts": 128}
+        # Initialize the TokenDispatcherWithMC2 instance
+        self.dispatcher = TokenDispatcherWithMC2(**kwargs)
+
+    def tearDown(self):
+        self.mc2_group_patch.stop()
+        self.rank_group_patch.stop()
+        self.forward_context_patch.stop()
+        self.ascend_soc_version_patch.stop()
+        self.ascend_config_patch.stop()
+
+    def test_init(self):
+        """Test __init__ initialization behavior."""
+        # self.assertEqual(self.dispatcher.moe_all_to_all_group_name, "hccl_123")
+        self.assertEqual(self.dispatcher.ep_rank_id, 0)
+        self.assertEqual(self.dispatcher.ep_world_size, 8)
+        self.assertFalse(self.dispatcher.torchair_graph_enabled)
+        self.assertFalse(self.dispatcher.with_quant)
+        self.assertTrue(self.dispatcher.enable_dispatch_v2)
+        self.assertTrue(self.dispatcher.need_extra_args)
+        self.assertTrue(self.dispatcher.a3_need_extra_args)
+
+    def test_get_permute_mc2_kwargs_without_quant(self):
+        """Test get_permute_mc2_kwargs (without quantization)."""
+        hidden_states = torch.randn(10, 128)
+        topk_ids = torch.randint(0, 8, (10, 1))
+        topk_weights = torch.randn(10, 1)
+        expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
+
+        kwargs = self.dispatcher.get_permute_mc2_kwargs(
+            hidden_states, topk_weights, topk_ids, expert_map)
+        self.assertIn("x", kwargs)
+        self.assertIn("expert_ids", kwargs)
+        self.assertEqual(kwargs["moe_expert_num"], 8)
+
+    def test_token_permutation_dispatch(self):
+        """Test token_permutation (dispatch path)."""
+        hidden_states = torch.randn(10, 128)
+        topk_weights = torch.randn(10, 1)
+        topk_ids = torch.randint(0, 8, (10, 1))
+        expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
+
+        with mock.patch("torch_npu.npu_moe_distribute_dispatch_v2",
+                        return_value=(torch.randn(10, 128), ) *
+                        5) as mock_dispatch:
+            output = self.dispatcher.token_permutation(hidden_states,
+                                                       topk_weights, topk_ids,
+                                                       expert_map)
+            mock_dispatch.assert_called_once()
+            self.assertEqual(output[0], 1)  # group_list_type == 1
+
+    def test_token_permutation_with_shared_experts_and_quant(self):
+        """Test token_permutation (with shared_experts)."""
+        self.shared_experts = mock.MagicMock()
+        self.shared_experts.gate_up_proj.return_value = (torch.randn(10, 128),
+                                                         torch.tensor(1.0))
+        self.shared_experts.act_fn.return_value = torch.randn(10, 128)
+        self.dispatcher.with_quant = False
+        self.dispatcher.shared_act = torch.randn(10, 128)
+        self.dispatcher.swiglu_out_scale = torch.tensor(1.0)
+        self.hidden_states = torch.randn(10, 128)
+        self.topk_weights = torch.randn(10, 1)
+
+        with mock.patch("torch_npu.npu_moe_distribute_dispatch_v2",
+                        return_value=(torch.randn(10, 128), ) * 5):
+            with mock.patch(
+                    "vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_stream_switch",
+                    autospec=True):
+                with mock.patch(
+                        "vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_wait_tensor",
+                        autospec=True) as mock_wait:
+                    self.dispatcher.token_permutation(
+                        self.hidden_states,
+                        self.topk_weights,
+                        torch.randint(0, 8, (10, 1)),
+                        torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
+                        shared_experts=self.shared_experts)
+                    mock_wait.assert_any_call(self.hidden_states,
+                                              self.topk_weights)
+
+    def test_get_unpermute_mc_kwargs_with_quant(self):
+        """Test get_unpermute_mc_kwargs (with_quant=True)."""
+        self.dispatcher.with_quant = True
+        hidden_states = torch.randn(10, 128)
+        self.dispatcher.topk_ids = torch.randint(0, 8, (10, 1))
+        self.dispatcher.topk_weights = torch.randint(0, 8, (10, 1))
+        self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
+        self.dispatcher.ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
+        self.dispatcher.need_extra_args = True
+        self.dispatcher.enable_dispatch_v2 = True
+
+        kwargs = self.dispatcher.get_unpermute_mc_kwargs(hidden_states)
+        self.assertIn("tp_send_counts", kwargs)
+
+    def test_token_unpermutation_with_shared_experts(self):
+        self.dispatcher.shared_experts = mock.MagicMock()
+        self.dispatcher.shared_experts.down_proj.return_value = (torch.randn(
+            10, 128), torch.tensor(1.0))
+        self.dispatcher.shared_act = torch.randn(10, 128)
+        self.dispatcher.with_quant = True
+        self.dispatcher.topk_ids = torch.randint(0, 8, (10, 1))
+        self.dispatcher.topk_weights = torch.randint(0, 8, (10, 1))
+        self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
+        self.dispatcher.ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
+        self.dispatcher.need_extra_args = True
+        self.dispatcher.enable_dispatch_v2 = True
+        self.dispatcher.swiglu_out_scale = torch.randint(0, 8, (10, 1))
+        self.hidden_states = torch.randn(10, 128)
+
+        with mock.patch("torch_npu.npu_moe_distribute_combine_v2",
+                        return_value=torch.randn(10, 128)):
+            with mock.patch(
+                    "vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_stream_switch",
+                    autospec=True):
+                with mock.patch(
+                        "vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_wait_tensor",
+                        autospec=True):
+                    self.dispatcher.token_unpermutation(self.hidden_states)
+
+
+class TestTokenDispatcherWithAllGather(unittest.TestCase):
+
+    def setUp(self):
+        # Mock dependencies
+        kwargs = {
+            "apply_router_weight_on_input": False,
+            "top_k": 2,
+            "max_num_tokens": 100,
+            "ep_size": 2,
+            "num_experts": 128,
+            "with_quant": False,
+        }
+        self.dispatcher = TokenDispatcherWithAllGather(**kwargs)
+
+        # Mock NPU functions
+        self.patcher_moe_init_routing = mock.patch(
+            'torch_npu.npu_moe_init_routing')
+        self.mock_moe_init_routing = self.patcher_moe_init_routing.start()
+        self.mock_moe_init_routing.return_value = (
+            torch.randn(6, 128),  # sorted_hidden_states
+            torch.tensor([0, 1, 2, 3, 4, 5]),  # expanded_row_idx
+            torch.tensor([0, 1, 0, 1, 0, 1])  # expanded_expert_idx
+        )
+
+        self.patcher_moe_compute_expert_tokens = mock.patch(
+            'torch_npu.npu_moe_compute_expert_tokens')
+        self.mock_moe_compute_expert_tokens = self.patcher_moe_compute_expert_tokens.start(
+        )
+        self.mock_moe_compute_expert_tokens.return_value = torch.tensor(
+            [3, 3])  # expert_tokens
+
+        self.patcher_moe_finalize_routing = mock.patch(
+            'torch_npu.npu_moe_finalize_routing')
+        self.mock_moe_finalize_routing = self.patcher_moe_finalize_routing.start(
+        )
+        self.mock_moe_finalize_routing.return_value = torch.randn(3, 128)
+
+    def tearDown(self):
+        self.patcher_moe_init_routing.stop()
+        self.patcher_moe_compute_expert_tokens.stop()
+        self.patcher_moe_finalize_routing.stop()
+
+    def test_token_permutation_with_expert_map(self):
+        self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
+        hidden_states = torch.randn(3, 128)
+        topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
+        topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
+
+        group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher.token_permutation(
+            hidden_states, topk_weights, topk_ids, self.dispatcher.expert_map)
+
+        # Verify expert_map logic is used
+        self.assertEqual(group_list_type, 0)
+        self.assertEqual(sorted_hidden_states.shape, (6, 128))
+
+        # Check if sorting and filtering were applied
+        self.assertIsNotNone(self.dispatcher.sorted_token_indices)
+        self.assertIsNotNone(self.dispatcher.sorted_weights)
+
+    def test_token_permutation_without_expert_map(self):
+        hidden_states = torch.randn(3, 128)
+        topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
+        topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
+
+        group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher.token_permutation(
+            hidden_states, topk_weights, topk_ids, None)
+
+        # Verify npu_moe_init_routing is called
+        self.mock_moe_init_routing.assert_called_once()
+        args, kwargs = self.mock_moe_init_routing.call_args
+
+        self.assertEqual(group_list_type, 0)
+
+    def test_token_permutation_with_quant(self):
+        kwargs = {
+            "apply_router_weight_on_input": False,
+            "top_k": 2,
+            "max_num_tokens": 100,
+            "ep_size": 2,
+            "num_experts": 128,
+            "with_quant": True,
+        }
+        self.dispatcher_quant = TokenDispatcherWithAllGather(**kwargs)
+
+        hidden_states = torch.randn(3, 128)
+        topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
+        topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
+
+        group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher_quant.token_permutation(
+            hidden_states, topk_weights, topk_ids, None)
+
+        # Verify the quant dispatcher also returns group_list_type == 0
+        self.assertEqual(group_list_type, 0)
+
+    def test_token_unpermutation_with_expert_map(self):
+        self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
+        self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1])
+        self.dispatcher.sorted_weights = torch.tensor(
+            [0.5, 0.5, 0.5, 0.5, 0.5, 0.5])
+        self.dispatcher.original_shape = (3, 128)
+        self.dispatcher.mask = torch.tensor([0, 1, 1, 0])
+        hidden_states = torch.randn(6, 128)
+
+        final_hidden_states = self.dispatcher.token_unpermutation(
+            hidden_states)
+
+        # Verify index_add_ is applied correctly
+        self.assertEqual(final_hidden_states.shape, (3, 128))
+
+    def test_token_unpermutation_without_expert_map(self):
+        self.dispatcher.with_quant = False
+        self.dispatcher.expanded_row_idx = torch.tensor([0, 1, 1, 1, 1, 1])
+        self.dispatcher.topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
+        self.dispatcher.original_shape = (3, 128)
+        hidden_states = torch.randn(6, 128)
+
+        final_hidden_states = self.dispatcher.token_unpermutation(
+            hidden_states)
+
+        # Verify npu_moe_finalize_routing is called
+        self.mock_moe_finalize_routing.assert_called_once()
+        args, kwargs = self.mock_moe_finalize_routing.call_args
+
+        self.assertEqual(final_hidden_states.shape, (3, 128))
+
+    def test_token_permutation_with_router_weight(self):
+        self.dispatcher.apply_router_weight_on_input = True
+        hidden_states = torch.randn(3, 128)
+        topk_weights = torch.tensor([[0.7], [0.6], [0.5]])  # topk=1
+        topk_ids = torch.tensor([[0], [1], [2]])
+
+        group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher.token_permutation(
+            hidden_states, topk_weights, topk_ids, None)
+        self.assertEqual(sorted_hidden_states.shape, (6, 128))
+
+    def test_token_permutation_invalid_topk_when_router_weight(self):
+        self.dispatcher.apply_router_weight_on_input = True
+        hidden_states = torch.randn(3, 128)
+        topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
+
+        with self.assertRaises(AssertionError):
+            self.dispatcher.token_permutation(
+                hidden_states, topk_weights,
+                torch.tensor([[0, 1], [1, 2], [2, 3]]), None)
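
For readers exercising this change locally, the sketch below mirrors the AllGather round-trip that the new tests drive. It is illustrative only and not part of the commit: the constructor arguments, call signatures, and tensor shapes are copied from the tests above, and running it for real assumes an Ascend NPU environment with torch_npu, since token_permutation calls npu_moe_init_routing and friends internally (the unit tests mock these out instead).

    # Illustrative sketch (not part of the commit); assumes torch_npu on an
    # Ascend NPU, otherwise mock the npu_moe_* kernels as in setUp() above.
    import torch

    from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
        TokenDispatcherWithAllGather

    dispatcher = TokenDispatcherWithAllGather(
        apply_router_weight_on_input=False,
        top_k=2,
        max_num_tokens=100,
        ep_size=2,
        num_experts=128,
        with_quant=False)

    hidden_states = torch.randn(3, 128)  # 3 tokens, hidden size 128
    topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
    topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])

    # Gather tokens into per-expert order, run the experts on the sorted
    # activations, then scatter the results back to the original order.
    group_list_type, sorted_hidden_states, expert_tokens = \
        dispatcher.token_permutation(hidden_states, topk_weights, topk_ids,
                                     None)
    expert_output = sorted_hidden_states  # stand-in for the real expert MLPs
    final_hidden_states = dispatcher.token_unpermutation(expert_output)
    assert final_hidden_states.shape == (3, 128)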
