
Commit 1870c8e

Author: wangxiaoxin-sherie (committed)
refactor mc2/allgather token dispatcher.
Signed-off-by: wangxiaoxin-sherie <[email protected]>
1 parent 1327f9b commit 1870c8e

File tree

2 files changed: +802 -2 lines changed


tests/ut/ops/test_token_dispatcher.py

Lines changed: 292 additions & 1 deletion
@@ -15,12 +15,17 @@
 # limitations under the License.
 # This file is a part of the vllm-ascend project.
 
+import unittest
+from unittest import mock
+
 import pytest
+import torch
 from pytest_mock import MockerFixture
 
 from tests.ut.base import PytestBase
 from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
-    MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
+    AscendSocVersion, MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig,
+    TokenDispatcherWithAllGather, TokenDispatcherWithMC2)
 from vllm_ascend.utils import adapt_patch  # noqa E402
 
 
@@ -63,3 +68,289 @@ def test_initialization(self, dispatcher, config):
         assert dispatcher.ep_rank == 0
         assert dispatcher.ep_size == 2
         assert dispatcher.overlap_stream is not None
+
+
+class TestTokenDispatcherWithMC2(unittest.TestCase):
+
+    def setUp(self):
+        self.mc2_group = mock.MagicMock()
+        self.mc2_group.device_group.return_value._get_backend.return_value.get_hccl_comm_name.return_value = "hccl_123"
+        self.mc2_group.rank_in_group = 0
+        self.mc2_group.world_size = 8
+        self.mc2_group_patch = mock.patch(
+            "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_mc2_group",
+            return_value=self.mc2_group)
+        self.mc2_group_patch.start()
+
+        self.rank_group_patch = mock.patch("torch.distributed.get_rank",
+                                           return_value=0)
+        self.rank_group_patch.start()
+
+        # Mock get_forward_context().mc2_mask
+        self.forward_context = mock.MagicMock()
+        self.forward_context.mc2_mask = torch.tensor([1, 0, 1])
+        self.forward_context_patch = mock.patch(
+            "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_forward_context",
+            return_value=self.forward_context)
+        self.forward_context_patch.start()
+
+        # Mock get_ascend_soc_version()
+        self.ascend_soc_version_patch = mock.patch(
+            "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ascend_soc_version",
+            return_value=AscendSocVersion.A3)
+        self.ascend_soc_version_patch.start()
+
+        # Mock get_ascend_config()
+        self.ascend_config = mock.MagicMock()
+        self.ascend_config.torchair_graph_config.enabled = False
+        self.ascend_config_patch = mock.patch(
+            "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ascend_config",
+            return_value=self.ascend_config)
+        self.ascend_config_patch.start()
+
+        kwargs = {"with_quant": False, "top_k": 8, "num_experts": 128}
+        self.dispatcher = TokenDispatcherWithMC2(**kwargs)
+
+    def tearDown(self):
+        self.mc2_group_patch.stop()
+        self.forward_context_patch.stop()
+        self.ascend_soc_version_patch.stop()
+        self.ascend_config_patch.stop()
+
+    def test_init(self):
+        # self.assertEqual(self.dispatcher.moe_all_to_all_group_name, "hccl_123")
+        self.assertEqual(self.dispatcher.ep_rank_id, 0)
+        self.assertEqual(self.dispatcher.ep_world_size, 8)
+        self.assertFalse(self.dispatcher.torchair_graph_enabled)
+        self.assertFalse(self.dispatcher.with_quant)
+        self.assertTrue(self.dispatcher.enable_dispatch_v2)
+        self.assertTrue(self.dispatcher.need_extra_args)
+        self.assertTrue(self.dispatcher.a3_need_extra_args)
+
+    def test_get_permute_mc2_kwargs_without_quant(self):
+        hidden_states = torch.randn(10, 128)
+        topk_ids = torch.randint(0, 8, (10, 1))
+        topk_weights = torch.randn(10, 1)
+        expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
+
+        kwargs = self.dispatcher.get_permute_mc2_kwargs(
+            hidden_states, topk_weights, topk_ids, expert_map)
+        self.assertIn("x", kwargs)
+        self.assertIn("expert_ids", kwargs)
+        self.assertEqual(kwargs["moe_expert_num"], 8)
+
+    def test_token_permutation_dispatch(self):
+        hidden_states = torch.randn(10, 128)
+        topk_weights = torch.randn(10, 1)
+        topk_ids = torch.randint(0, 8, (10, 1))
+        expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
+
+        with mock.patch("torch_npu.npu_moe_distribute_dispatch_v2",
+                        return_value=(torch.randn(10, 128), ) *
+                        5) as mock_dispatch:
+            output = self.dispatcher.token_permutation(hidden_states,
+                                                       topk_weights, topk_ids,
+                                                       expert_map)
+            mock_dispatch.assert_called_once()
+            self.assertEqual(output[0], 1)  # group_list_type == 1
+
+    def test_token_permutation_with_shared_experts_and_quant(self):
+        self.shared_experts = mock.MagicMock()
+        self.shared_experts.gate_up_proj.return_value = (torch.randn(10, 128),
+                                                         torch.tensor(1.0))
+        self.shared_experts.act_fn.return_value = torch.randn(10, 128)
+        self.dispatcher.with_quant = False
+        self.dispatcher.shared_act = torch.randn(10, 128)
+        self.dispatcher.swiglu_out_scale = torch.tensor(1.0)
+        self.hidden_states = torch.randn(10, 128)
+        self.topk_weights = torch.randn(10, 1)
+
+        with mock.patch("torch_npu.npu_moe_distribute_dispatch_v2",
+                        return_value=(torch.randn(10, 128), ) * 5):
+            with mock.patch(
+                    "vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_stream_switch",
+                    autospec=True):
+                with mock.patch(
+                        "vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_wait_tensor",
+                        autospec=True) as mock_wait:
+                    self.dispatcher.token_permutation(
+                        self.hidden_states,
+                        self.topk_weights,
+                        torch.randint(0, 8, (10, 1)),
+                        torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
+                        shared_experts=self.shared_experts)
+                    mock_wait.assert_any_call(self.hidden_states,
+                                              self.topk_weights)
+
+    def test_get_unpermute_mc_kwargs_with_quant(self):
+        self.dispatcher.with_quant = True
+        hidden_states = torch.randn(10, 128)
+        self.dispatcher.topk_ids = torch.randint(0, 8, (10, 1))
+        self.dispatcher.topk_weights = torch.randint(0, 8, (10, 1))
+        self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
+        self.dispatcher.ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
+        self.dispatcher.need_extra_args = True
+        self.dispatcher.enable_dispatch_v2 = True
+        self.dispatcher.output = torch.randint(0, 8, (10, 1))
+
+        kwargs = self.dispatcher.get_unpermute_mc_kwargs(hidden_states)
+        self.assertIn("tp_send_counts", kwargs)
+
+    def test_token_unpermutation_with_shared_experts(self):
+        self.dispatcher.shared_experts = mock.MagicMock()
+        self.dispatcher.shared_experts.down_proj.return_value = (torch.randn(
+            10, 128), torch.tensor(1.0))
+        self.dispatcher.shared_act = torch.randn(10, 128)
+        self.dispatcher.with_quant = True
+        self.dispatcher.topk_ids = torch.randint(0, 8, (10, 1))
+        self.dispatcher.topk_weights = torch.randint(0, 8, (10, 1))
+        self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
+        self.dispatcher.ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
+        self.dispatcher.need_extra_args = True
+        self.dispatcher.enable_dispatch_v2 = True
+        self.dispatcher.swiglu_out_scale = torch.randint(0, 8, (10, 1))
+        self.dispatcher.output = torch.randint(0, 8, (10, 1))
+        self.hidden_states = torch.randn(10, 128)
+
+        with mock.patch("torch_npu.npu_moe_distribute_combine_v2",
+                        return_value=torch.randn(10, 128)):
+            with mock.patch(
+                    "vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_stream_switch",
+                    autospec=True):
+                with mock.patch(
+                        "vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_wait_tensor",
+                        autospec=True):
+                    self.dispatcher.token_unpermutation(self.hidden_states)
+
+
+class TestTokenDispatcherWithAllGather(unittest.TestCase):
+
+    def setUp(self):
+        # Mock dependencies
+        kwargs = {
+            "apply_router_weight_on_input": False,
+            "top_k": 2,
+            "max_num_tokens": 100,
+            "ep_size": 2,
+            "num_experts": 128,
+            "with_quant": False,
+        }
+        self.dispatcher = TokenDispatcherWithAllGather(**kwargs)
+
+        # Mock NPU functions
+        self.patcher_moe_init_routing = mock.patch(
+            'torch_npu.npu_moe_init_routing')
+        self.mock_moe_init_routing = self.patcher_moe_init_routing.start()
+        self.mock_moe_init_routing.return_value = (
+            torch.randn(6, 128),  # sorted_hidden_states
+            torch.tensor([0, 1, 2, 3, 4, 5]),  # expanded_row_idx
+            torch.tensor([0, 1, 0, 1, 0, 1])  # expanded_expert_idx
+        )
+
+        self.patcher_moe_compute_expert_tokens = mock.patch(
+            'torch_npu.npu_moe_compute_expert_tokens')
+        self.mock_moe_compute_expert_tokens = self.patcher_moe_compute_expert_tokens.start(
+        )
+        self.mock_moe_compute_expert_tokens.return_value = torch.tensor(
+            [3, 3])  # expert_tokens
+
+        self.patcher_moe_finalize_routing = mock.patch(
+            'torch_npu.npu_moe_finalize_routing')
+        self.mock_moe_finalize_routing = self.patcher_moe_finalize_routing.start(
+        )
+        self.mock_moe_finalize_routing.return_value = torch.randn(3, 128)
+
+    def tearDown(self):
+        self.patcher_moe_init_routing.stop()
+        self.patcher_moe_compute_expert_tokens.stop()
+        self.patcher_moe_finalize_routing.stop()
+
+    def test_token_permutation_without_expert_map(self):
+        hidden_states = torch.randn(3, 128)
+        topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
+        topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
+
+        group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher.token_permutation(
+            hidden_states, topk_weights, topk_ids, None)
+
+        # Verify npu_moe_init_routing is called
+        self.mock_moe_init_routing.assert_called_once()
+        args, kwargs = self.mock_moe_init_routing.call_args
+
+        self.assertEqual(group_list_type, 0)
+
+    def test_token_permutation_with_quant(self):
+        kwargs = {
+            "apply_router_weight_on_input": False,
+            "top_k": 2,
+            "max_num_tokens": 100,
+            "ep_size": 2,
+            "num_experts": 128,
+            "with_quant": True,
+        }
+        self.dispatcher_quant = TokenDispatcherWithAllGather(**kwargs)
+
+        hidden_states = torch.randn(3, 128)
+        topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
+        topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
+
+        group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher_quant.token_permutation(
+            hidden_states, topk_weights, topk_ids, None)
+
+        # Verify the quant path here returns group_list_type=0
+        self.assertEqual(group_list_type, 0)
+
+    def test_token_unpermutation_with_expert_map(self):
+        self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
+        self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1])
+        self.dispatcher.sorted_weights = torch.tensor(
+            [0.5, 0.5, 0.5, 0.5, 0.5, 0.5])
+        self.dispatcher.original_shape = (3, 128)
+        self.dispatcher.mask = torch.tensor([0, 1, 1, 0])
+        hidden_states = torch.randn(6, 128)
+
+        final_hidden_states = self.dispatcher.token_unpermutation(
+            hidden_states)
+
+        # Verify index_add_ is applied correctly
+        self.assertEqual(final_hidden_states.shape, (3, 128))
+
+    def test_token_unpermutation_without_expert_map(self):
+        self.dispatcher.with_quant = False
+        self.dispatcher.expanded_row_idx = torch.tensor([0, 1, 1, 1, 1, 1])
+        self.dispatcher.topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
+        self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1])
+        self.dispatcher.sorted_weights = torch.tensor(
+            [0.5, 0.5, 0.5, 0.5, 0.5, 0.5])
+        self.dispatcher.original_shape = (3, 128)
+        self.dispatcher.mask = torch.tensor([0, 1, 1, 0])
+        hidden_states = torch.randn(6, 128)
+
+        final_hidden_states = self.dispatcher.token_unpermutation(
+            hidden_states)
+
+        # Verify npu_moe_finalize_routing is called
+        self.mock_moe_finalize_routing.assert_called_once()
+        args, kwargs = self.mock_moe_finalize_routing.call_args
+
+        self.assertEqual(final_hidden_states.shape, (3, 128))
+
+    def test_token_permutation_with_router_weight(self):
+        self.dispatcher.apply_router_weight_on_input = True
+        hidden_states = torch.randn(3, 128)
+        topk_weights = torch.tensor([[0.7], [0.6], [0.5]])  # topk=1
+        topk_ids = torch.tensor([[0], [1], [2]])
+
+        group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher.token_permutation(
+            hidden_states, topk_weights, topk_ids, None)
+        self.assertEqual(sorted_hidden_states.shape, (6, 128))
+
+    def test_token_permutation_invalid_topk_when_router_weight(self):
+        self.dispatcher.apply_router_weight_on_input = True
+        hidden_states = torch.randn(3, 128)
+        topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
+
+        with self.assertRaises(AssertionError):
+            self.dispatcher.token_permutation(
+                hidden_states, topk_weights,
+                torch.tensor([[0, 1], [1, 2], [2, 3]]), None)
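What the new tests exercise, in one picture: the dispatcher replicates each token top_k times, sorts the copies by expert id so every expert receives a contiguous slice, runs the expert MLPs on that slice, and scatters the results back to the original rows weighted by the router. A minimal PyTorch-only sketch of that round trip (illustrative, not vllm-ascend code; it roughly mirrors what the fused NPU kernels npu_moe_init_routing and npu_moe_finalize_routing compute):

    import torch

    num_tokens, hidden, top_k, num_experts = 3, 128, 2, 4
    hidden_states = torch.randn(num_tokens, hidden)
    topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])  # expert chosen per token slot
    topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])

    # Permutation: sort the token copies by expert id (stable sort keeps
    # token order within an expert) so each expert's inputs are contiguous.
    flat_experts = topk_ids.reshape(-1)
    sort_idx = torch.argsort(flat_experts, stable=True)
    token_idx = torch.arange(num_tokens).repeat_interleave(top_k)[sort_idx]
    sorted_hidden_states = hidden_states[token_idx]  # (num_tokens * top_k, hidden)
    expert_tokens = torch.bincount(flat_experts, minlength=num_experts).cumsum(0)

    expert_out = sorted_hidden_states  # stand-in for the grouped expert MLP

    # Unpermutation: weight each copy by its router probability and
    # scatter-add back to the original token rows.
    weights = topk_weights.reshape(-1)[sort_idx].unsqueeze(-1)
    final = torch.zeros_like(hidden_states)
    final.index_add_(0, token_idx, expert_out * weights)
    assert final.shape == (num_tokens, hidden)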

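And a minimal sketch of driving TokenDispatcherWithAllGather the way TestTokenDispatcherWithAllGather does, assuming an environment where vllm_ascend and torch_npu are importable; the NPU kernels are patched with the canned outputs from setUp() above, so only the dispatcher's host-side control flow runs:

    from unittest import mock

    import torch

    from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
        TokenDispatcherWithAllGather

    dispatcher = TokenDispatcherWithAllGather(apply_router_weight_on_input=False,
                                              top_k=2,
                                              max_num_tokens=100,
                                              ep_size=2,
                                              num_experts=128,
                                              with_quant=False)

    # Patch the NPU routing kernels with canned outputs, as in setUp().
    with mock.patch("torch_npu.npu_moe_init_routing",
                    return_value=(torch.randn(6, 128),
                                  torch.tensor([0, 1, 2, 3, 4, 5]),
                                  torch.tensor([0, 1, 0, 1, 0, 1]))), \
         mock.patch("torch_npu.npu_moe_compute_expert_tokens",
                    return_value=torch.tensor([3, 3])):
        group_list_type, sorted_hidden_states, expert_tokens = \
            dispatcher.token_permutation(
                torch.randn(3, 128),
                torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]]),
                torch.tensor([[0, 1], [1, 2], [2, 3]]),
                None)
    assert group_list_type == 0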