From 1870c8ef3e245098bbedcb9a225497f4d0c3b9c3 Mon Sep 17 00:00:00 2001 From: wangxiaoxin-sherie Date: Thu, 14 Aug 2025 12:36:39 +0800 Subject: [PATCH] refactor mc2/allgather tokendispatch. Signed-off-by: wangxiaoxin-sherie --- tests/ut/ops/test_token_dispatcher.py | 293 +++++++++- .../ops/moe_dispatcher/token_dispatcher.py | 511 +++++++++++++++++- 2 files changed, 802 insertions(+), 2 deletions(-) diff --git a/tests/ut/ops/test_token_dispatcher.py b/tests/ut/ops/test_token_dispatcher.py index 3a42b93c42..2a3383cd20 100644 --- a/tests/ut/ops/test_token_dispatcher.py +++ b/tests/ut/ops/test_token_dispatcher.py @@ -15,12 +15,17 @@ # limitations under the License. # This file is a part of the vllm-ascend project. +import unittest +from unittest import mock + import pytest +import torch from pytest_mock import MockerFixture from tests.ut.base import PytestBase from vllm_ascend.ops.moe_dispatcher.token_dispatcher import ( - MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig) + AscendSocVersion, MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig, + TokenDispatcherWithAllGather, TokenDispatcherWithMC2) from vllm_ascend.utils import adapt_patch # noqa E402 @@ -63,3 +68,289 @@ def test_initialization(self, dispatcher, config): assert dispatcher.ep_rank == 0 assert dispatcher.ep_size == 2 assert dispatcher.overlap_stream is not None + + +class TestTokenDispatcherWithMC2(unittest.TestCase): + + def setUp(self): + self.mc2_group = mock.MagicMock() + self.mc2_group.device_group.return_value._get_backend.return_value.get_hccl_comm_name.return_value = "hccl_123" + self.mc2_group.rank_in_group = 0 + self.mc2_group.world_size = 8 + self.mc2_group_patch = mock.patch( + "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_mc2_group", + return_value=self.mc2_group) + self.mc2_group_patch.start() + + self.rank_group_patch = mock.patch("torch.distributed.get_rank", + return_value=0) + self.rank_group_patch.start() + + # Mock get_forward_context().mc2_mask + self.forward_context = mock.MagicMock() + self.forward_context.mc2_mask = torch.tensor([1, 0, 1]) + self.forward_context_patch = mock.patch( + "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_forward_context", + return_value=self.forward_context) + self.forward_context_patch.start() + + # Mock get_ascend_soc_version() + self.ascend_soc_version_patch = mock.patch( + "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ascend_soc_version", + return_value=AscendSocVersion.A3) + self.ascend_soc_version_patch.start() + + # Mock get_ascend_config() + self.ascend_config = mock.MagicMock() + self.ascend_config.torchair_graph_config.enabled = False + self.ascend_config_patch = mock.patch( + "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ascend_config", + return_value=self.ascend_config) + self.ascend_config_patch.start() + + kwargs = {"with_quant": False, "top_k": 8, "num_experts": 128} + self.dispatcher = TokenDispatcherWithMC2(**kwargs) + + def tearDown(self): + self.mc2_group_patch.stop() + self.forward_context_patch.stop() + self.ascend_soc_version_patch.stop() + self.ascend_config_patch.stop() + + def test_init(self): + # self.assertEqual(self.dispatcher.moe_all_to_all_group_name, "hccl_123") + self.assertEqual(self.dispatcher.ep_rank_id, 0) + self.assertEqual(self.dispatcher.ep_world_size, 8) + self.assertFalse(self.dispatcher.torchair_graph_enabled) + self.assertFalse(self.dispatcher.with_quant) + self.assertTrue(self.dispatcher.enable_dispatch_v2) + self.assertTrue(self.dispatcher.need_extra_args) + 
self.assertTrue(self.dispatcher.a3_need_extra_args) + + def test_get_permute_mc2_kwargs_without_quant(self): + hidden_states = torch.randn(10, 128) + topk_ids = torch.randint(0, 8, (10, 1)) + topk_weights = torch.randn(10, 1) + expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) + + kwargs = self.dispatcher.get_permute_mc2_kwargs( + hidden_states, topk_weights, topk_ids, expert_map) + self.assertIn("x", kwargs) + self.assertIn("expert_ids", kwargs) + self.assertEqual(kwargs["moe_expert_num"], 8) + + def test_token_permutation_dispatch(self): + hidden_states = torch.randn(10, 128) + topk_weights = torch.randn(10, 1) + topk_ids = torch.randint(0, 8, (10, 1)) + expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) + + with mock.patch("torch_npu.npu_moe_distribute_dispatch_v2", + return_value=(torch.randn(10, 128), ) * + 5) as mock_dispatch: + output = self.dispatcher.token_permutation(hidden_states, + topk_weights, topk_ids, + expert_map) + mock_dispatch.assert_called_once() + self.assertEqual(output[0], 1) # group_list_type == 1 + + def test_token_permutation_with_shared_experts_and_quant(self): + self.shared_experts = mock.MagicMock() + self.shared_experts.gate_up_proj.return_value = (torch.randn(10, 128), + torch.tensor(1.0)) + self.shared_experts.act_fn.return_value = torch.randn(10, 128) + self.dispatcher.with_quant = False + self.dispatcher.shared_act = torch.randn(10, 128) + self.dispatcher.swiglu_out_scale = torch.tensor(1.0) + self.hidden_states = torch.randn(10, 128) + self.topk_weights = torch.randn(10, 1) + + with mock.patch("torch_npu.npu_moe_distribute_dispatch_v2", + return_value=(torch.randn(10, 128), ) * 5): + with mock.patch( + "vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_stream_switch", + autospec=True): + with mock.patch( + "vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_wait_tensor", + autospec=True) as mock_wait: + self.dispatcher.token_permutation( + self.hidden_states, + self.topk_weights, + torch.randint(0, 8, (10, 1)), + torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), + shared_experts=self.shared_experts) + mock_wait.assert_any_call(self.hidden_states, + self.topk_weights) + + def test_get_unpermute_mc_kwargs_with_quant(self): + self.dispatcher.with_quant = True + hidden_states = torch.randn(10, 128) + self.dispatcher.topk_ids = torch.randint(0, 8, (10, 1)) + self.dispatcher.topk_weights = torch.randint(0, 8, (10, 1)) + self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) + self.dispatcher.ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) + self.dispatcher.need_extra_args = True + self.dispatcher.enable_dispatch_v2 = True + self.dispatcher.output = torch.randint(0, 8, (10, 1)) + + kwargs = self.dispatcher.get_unpermute_mc_kwargs(hidden_states) + self.assertIn("tp_send_counts", kwargs) + + def test_token_unpermutation_with_shared_experts(self): + self.dispatcher.shared_experts = mock.MagicMock() + self.dispatcher.shared_experts.down_proj.return_value = (torch.randn( + 10, 128), torch.tensor(1.0)) + self.dispatcher.shared_act = torch.randn(10, 128) + self.dispatcher.with_quant = True + self.dispatcher.topk_ids = torch.randint(0, 8, (10, 1)) + self.dispatcher.topk_weights = torch.randint(0, 8, (10, 1)) + self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) + self.dispatcher.ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) + self.dispatcher.need_extra_args = True + self.dispatcher.enable_dispatch_v2 = True + self.dispatcher.swiglu_out_scale = torch.randint(0, 8, (10, 1)) + self.dispatcher.output = torch.randint(0, 8, (10, 1)) 
+ self.hidden_states = torch.randn(10, 128) + + with mock.patch("torch_npu.npu_moe_distribute_combine_v2", + return_value=torch.randn(10, 128)): + with mock.patch( + "vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_stream_switch", + autospec=True): + with mock.patch( + "vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_wait_tensor", + autospec=True): + self.dispatcher.token_unpermutation(self.hidden_states) + + +class TestTokenDispatcherWithAllGather(unittest.TestCase): + + def setUp(self): + # Mock dependencies + kwargs = { + "apply_router_weight_on_input": False, + "top_k": 2, + "max_num_tokens": 100, + "ep_size": 2, + "num_experts": 128, + "with_quant": False, + } + self.dispatcher = TokenDispatcherWithAllGather(**kwargs) + + # Mock NPU functions + self.patcher_moe_init_routing = mock.patch( + 'torch_npu.npu_moe_init_routing') + self.mock_moe_init_routing = self.patcher_moe_init_routing.start() + self.mock_moe_init_routing.return_value = ( + torch.randn(6, 128), # sorted_hidden_states + torch.tensor([0, 1, 2, 3, 4, 5]), # expanded_row_idx + torch.tensor([0, 1, 0, 1, 0, 1]) # expanded_expert_idx + ) + + self.patcher_moe_compute_expert_tokens = mock.patch( + 'torch_npu.npu_moe_compute_expert_tokens') + self.mock_moe_compute_expert_tokens = self.patcher_moe_compute_expert_tokens.start( + ) + self.mock_moe_compute_expert_tokens.return_value = torch.tensor( + [3, 3]) # expert_tokens + + self.patcher_moe_finalize_routing = mock.patch( + 'torch_npu.npu_moe_finalize_routing') + self.mock_moe_finalize_routing = self.patcher_moe_finalize_routing.start( + ) + self.mock_moe_finalize_routing.return_value = torch.randn(3, 128) + + def tearDown(self): + self.patcher_moe_init_routing.stop() + self.patcher_moe_compute_expert_tokens.stop() + self.patcher_moe_finalize_routing.stop() + + def test_token_permutation_without_expert_map(self): + hidden_states = torch.randn(3, 128) + topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]]) + topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]]) + + group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher.token_permutation( + hidden_states, topk_weights, topk_ids, None) + + # Verify npu_moe_init_routing is called + self.mock_moe_init_routing.assert_called_once() + args, kwargs = self.mock_moe_init_routing.call_args + + self.assertEqual(group_list_type, 0) + + def test_token_permutation_with_quant(self): + kwargs = { + "apply_router_weight_on_input": False, + "top_k": 2, + "max_num_tokens": 100, + "ep_size": 2, + "num_experts": 128, + "with_quant": True, + } + self.dispatcher_quant = TokenDispatcherWithAllGather(**kwargs) + + hidden_states = torch.randn(3, 128) + topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]]) + topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]]) + + group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher_quant.token_permutation( + hidden_states, topk_weights, topk_ids, None) + + # Verify quant mode returns group_list_type=1 + self.assertEqual(group_list_type, 0) + + def test_token_unpermutation_with_expert_map(self): + self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3]) + self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1]) + self.dispatcher.sorted_weights = torch.tensor( + [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]) + self.dispatcher.original_shape = (3, 128) + self.dispatcher.mask = torch.tensor([0, 1, 1, 0]) + hidden_states = torch.randn(6, 128) + + final_hidden_states = self.dispatcher.token_unpermutation( + hidden_states) + + # Verify index_add_ is applied correctly + 
self.assertEqual(final_hidden_states.shape, (3, 128)) + + def test_token_unpermutation_without_expert_map(self): + self.dispatcher.with_quant = False + self.dispatcher.expanded_row_idx = torch.tensor([0, 1, 1, 1, 1, 1]) + self.dispatcher.topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]]) + self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1]) + self.dispatcher.sorted_weights = torch.tensor( + [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]) + self.dispatcher.original_shape = (3, 128) + self.dispatcher.mask = torch.tensor([0, 1, 1, 0]) + hidden_states = torch.randn(6, 128) + + final_hidden_states = self.dispatcher.token_unpermutation( + hidden_states) + + # Verify npu_moe_finalize_routing is called + self.mock_moe_finalize_routing.assert_called_once() + args, kwargs = self.mock_moe_finalize_routing.call_args + + self.assertEqual(final_hidden_states.shape, (3, 128)) + + def test_token_permutation_with_router_weight(self): + self.dispatcher.apply_router_weight_on_input = True + hidden_states = torch.randn(3, 128) + topk_weights = torch.tensor([[0.7], [0.6], [0.5]]) # topk=1 + topk_ids = torch.tensor([[0], [1], [2]]) + + group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher.token_permutation( + hidden_states, topk_weights, topk_ids, None) + self.assertEqual(sorted_hidden_states.shape, (6, 128)) + + def test_token_permutation_invalid_topk_when_router_weight(self): + self.dispatcher.apply_router_weight_on_input = True + hidden_states = torch.randn(3, 128) + topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]]) + + with self.assertRaises(AssertionError): + self.dispatcher.token_permutation( + hidden_states, topk_weights, + torch.tensor([[0, 1], [1, 2], [2, 3]]), None) diff --git a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py index 402e8fb93a..b3cc46c4fd 100644 --- a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py +++ b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py @@ -20,17 +20,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional + +from abc import ABC, abstractmethod +from typing import Any, Optional import torch import torch_npu from vllm.distributed.parallel_state import get_ep_group +from vllm.forward_context import get_forward_context +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.distributed.tensor_parallel import ( all_gather_last_dim_from_tensor_parallel_region, all_to_all_hp2sp, all_to_all_sp2hp, gather_from_sequence_parallel_region, reduce_scatter_last_dim_to_tensor_parallel_region) from vllm_ascend.ops.comm_utils import async_all_to_all +from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor +from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version class MoEDispatcherConfig: @@ -451,3 +458,505 @@ def alltoall_token_unpermutation2(permutated_local_input_tokens): self.num_global_tokens_per_local_expert_cpu = None return output, None + + +class MoETokenDispatcher(ABC): + + def __init__(self, **kwargs) -> None: + """ + Initialize the MoE Token Dispatcher. 
+ """ + self.top_k = kwargs.get("top_k") + self.num_experts = kwargs.get("num_experts") + + @property + def ep_group(self): + """Get expert model parallel group.""" + return get_ep_group().device_group + + @property + def ep_rank(self): + return get_ep_group().rank_in_group + + @property + def ep_size(self): + return get_ep_group().world_size + + @abstractmethod + def token_permutation( + self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + expert_map: torch.Tensor, + log2phy: torch.Tensor = None, + global_redundant_expert_num: int = 0, + shared_gate_up: Optional[Any] = None, + shared_dequant_scale: Optional[Any] = None, + shared_experts: Optional[Any] = None, + ): + raise NotImplementedError("Dispatch function not implemented.") + + @abstractmethod + def token_unpermutation(self, + hidden_states: torch.Tensor, + bias: torch.Tensor = None): + raise NotImplementedError("Restore function not implemented.") + + +class TokenDispatcherWithMC2(MoETokenDispatcher): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + device_group = get_mc2_group().device_group + # TODO: Try local_rank = ep_group.rank_in_group + local_rank = torch.distributed.get_rank(group=device_group) + backend = device_group._get_backend(torch.device("npu")) + self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank) + self.ep_rank_id = get_mc2_group().rank_in_group + self.ep_world_size = get_mc2_group().world_size + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + self.with_quant = kwargs.get("with_quant") + self.enable_dispatch_v2 = hasattr(torch_npu, + "npu_moe_distribute_dispatch_v2") + self.need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3 + or self.torchair_graph_enabled) + + # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine + self.a3_need_extra_args = \ + get_ascend_soc_version() == AscendSocVersion.A3 + self.output = None + self.dynamic_scale = None + self.assist_info_for_combine = None + self.ep_recv_counts = None + self.shared_act = None + self.topk_ids = None + self.topk_weights = None + self.shared_experts = None + + def get_permute_mc2_kwargs(self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + expert_map: torch.Tensor, + global_redundant_expert_num: int = 0): + quant_mode = 0 + forward_context = get_forward_context() + mc2_mask = forward_context.mc2_mask + if self.with_quant: + if (expert_map is not None): + moe_expert_num = len(expert_map) + global_redundant_expert_num + else: + moe_expert_num = global_redundant_expert_num + else: + moe_expert_num = len(expert_map) + kwargs_mc2 = { + "x": hidden_states, + "expert_ids": topk_ids, + "expert_shard_type": 0, + "shared_expert_rank_num": 0, + "moe_expert_num": moe_expert_num, + "global_bs": 0, + } + + stage1_kwargs = { + "scales": None, + "quant_mode": quant_mode, + "group_ep": self.moe_all_to_all_group_name, + "ep_world_size": self.ep_world_size, + "ep_rank_id": self.ep_rank_id, + } + if self.need_extra_args: + stage1_kwargs.update({ + "group_tp": self.moe_all_to_all_group_name, + "tp_world_size": 1, + "tp_rank_id": 0, + }) + if self.a3_need_extra_args and self.enable_dispatch_v2: + stage1_kwargs.update({ + "x_active_mask": mc2_mask, + }) + + kwargs_mc2.update(stage1_kwargs) + return kwargs_mc2 + + def token_permutation( + self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + expert_map: 
torch.Tensor, + log2phy: torch.Tensor = None, + global_redundant_expert_num: int = 0, + shared_gate_up: Optional[Any] = None, + shared_dequant_scale: Optional[Any] = None, + shared_experts: Optional[Any] = None, + ): + self.expert_map = expert_map + self.topk_ids = topk_ids + self.topk_weights = topk_weights + self.shared_experts = shared_experts + + kwargs_mc2 = self.get_permute_mc2_kwargs(hidden_states, topk_weights, + topk_ids, expert_map, + global_redundant_expert_num) + self.output = torch_npu.npu_moe_distribute_dispatch_v2( + **kwargs_mc2 + ) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch( + **kwargs_mc2) + # comm_stream.wait_stream(torch.npu.current_stream()) + expand_x, self.dynamic_scale, self.assist_info_for_combine, \ + expert_token_nums, self.ep_recv_counts = self.output[0:5] + + if self.with_quant: + if shared_experts is not None: + with npu_stream_switch("moe_secondary", 0): + npu_wait_tensor(shared_gate_up, expand_x) + shared_act_out = shared_experts.act_fn( + (shared_gate_up, shared_dequant_scale)) + self.shared_act, self.swiglu_out_scale = \ + shared_act_out[0], shared_act_out[1] + + else: + if shared_experts is not None: + with npu_stream_switch("moe_secondary", 0): + npu_wait_tensor(hidden_states, topk_weights) + shared_gate_up, _ = shared_experts.gate_up_proj( + hidden_states) + npu_wait_tensor(shared_gate_up, expand_x) + self.shared_act = shared_experts.act_fn(shared_gate_up) + group_list_type = 1 + return group_list_type, expand_x, expert_token_nums + + def get_unpermute_mc_kwargs(self, hidden_states: torch.Tensor): + assert self.expert_map is not None + assert self.topk_weights is not None + assert self.topk_ids is not None + assert self.output is not None + moe_expert_num = len(self.expert_map) + forward_context = get_forward_context() + mc2_mask = forward_context.mc2_mask + # moeCombine + kwargs_mc2 = { + "expand_x": hidden_states, + "expert_ids": self.topk_ids, + "expert_scales": self.topk_weights.to(torch.float32), + "expert_shard_type": 0, + "shared_expert_rank_num": 0, + "moe_expert_num": moe_expert_num, + "global_bs": 0, + } + if self.with_quant: + tp_recv_counts = torch.empty(1, + dtype=torch.int32, + device=hidden_states.device) + else: + tp_recv_counts = self.output[5] + stage3_kwargs = { + "ep_send_counts": self.ep_recv_counts, + "group_ep": self.moe_all_to_all_group_name, + "ep_world_size": self.ep_world_size, + "ep_rank_id": self.ep_rank_id, + } + if self.enable_dispatch_v2: + stage3_kwargs.update({ + "assist_info_for_combine": + self.assist_info_for_combine, + }) + else: + stage3_kwargs.update({ + "expand_idx": self.assist_info_for_combine, + }) + if self.need_extra_args: + stage3_kwargs.update({ + "tp_send_counts": tp_recv_counts, + "group_tp": self.moe_all_to_all_group_name, + "tp_world_size": 1, + "tp_rank_id": 0, + }) + if self.a3_need_extra_args and self.enable_dispatch_v2: + stage3_kwargs.update({ + "x_active_mask": mc2_mask, + }) + kwargs_mc2.update(stage3_kwargs) + return kwargs_mc2 + + def token_unpermutation(self, + hidden_states: torch.Tensor, + bias: torch.Tensor = None): + + kwargs_mc2 = self.get_unpermute_mc_kwargs(hidden_states) + hidden_states = torch_npu.npu_moe_distribute_combine_v2( + **kwargs_mc2 + ) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine( + **kwargs_mc2) + if self.shared_experts is None: + return hidden_states + else: + if self.with_quant: + with npu_stream_switch("moe_secondary", 0): + npu_wait_tensor(self.shared_act, hidden_states) + shared_hidden_states, _ = 
self.shared_experts.down_proj( + (self.shared_act, self.swiglu_out_scale)) + else: + with npu_stream_switch("moe_secondary", 0): + npu_wait_tensor(self.shared_act, hidden_states) + shared_hidden_states, _ = self.shared_experts.down_proj( + self.shared_act) + return hidden_states, shared_hidden_states + + +class TokenDispatcherWithAllGather(MoETokenDispatcher): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.apply_router_weight_on_input = kwargs.get( + "apply_router_weight_on_input") + self.top_k = kwargs.get("top_k") + self.max_num_tokens = kwargs.get("max_num_tokens") + ep_size = kwargs.get("ep_size") + if ep_size is not None: + self.num_experts_local = self.num_experts // ep_size + self.with_quant = kwargs.get("with_quant") + self.sorted_weights = None + self.expanded_row_idx = None + self.sorted_token_indices = None + self.original_shape = None + self.mask = None + self.expert_map = None + self.topk_weights = None + self.topk_ids = None + + def token_permutation( + self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + expert_map: torch.Tensor, + log2phy: torch.Tensor = None, + global_redundant_expert_num: int = 0, + shared_gate_up: Optional[Any] = None, + shared_dequant_scale: Optional[Any] = None, + shared_experts: Optional[Any] = None, + ): + self.original_shape = hidden_states.shape + # assert len(original_shape) == 2 + + num_tokens = hidden_states.shape[:-1].numel() + dtype = hidden_states.dtype + device = hidden_states.device + self.expert_map = expert_map + self.topk_weights = topk_weights + self.topk_ids = topk_ids + # assert dtype in [torch.float32, torch.float16, torch.bfloat16 + # ], "Only float32, float16, and bfsloat16 are supported" + + if self.apply_router_weight_on_input: + assert (topk_weights.dim() == 2 + ), "`topk_weights` should be in shape (num_tokens, topk)" + _, topk = topk_weights.shape + assert ( + topk == 1 + ), "Only support topk=1 when `apply_router_weight_on_input` is True" + hidden_states = hidden_states * \ + topk_weights.to(hidden_states.dtype) + + if expert_map is not None: + # Generate token indices and flatten + token_indices = (torch.arange( + num_tokens, device=device, + dtype=torch.int64).unsqueeze(1).expand(-1, + self.top_k).reshape(-1)) + + # Flatten token-to-expert mappings and map to local experts + weights_flat = topk_weights.view(-1) + experts_flat = topk_ids.view(-1) + local_experts_flat = expert_map[experts_flat] + + # Filter valid token-expert pairs + self.mask = local_experts_flat != -1 + filtered_weights = torch.where( + self.mask, weights_flat, + torch.zeros_like(weights_flat)).to(dtype) + filtered_experts = torch.where( + self.mask, local_experts_flat, + torch.full_like(local_experts_flat, + self.num_experts_local)).to(topk_ids.dtype) + + # Sort by local expert IDs + sort_indices = torch.argsort(filtered_experts.view(torch.float32)) + self.sorted_token_indices = token_indices[sort_indices] + self.sorted_weights = filtered_weights[sort_indices] + + # Compute token counts with minlength of num_experts + # This is equivalent to but faster than: + # >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1] + token_counts = torch.zeros(self.num_experts_local + 1, + device=device, + dtype=torch.int64) + ones = torch.ones_like(filtered_experts, dtype=torch.int64) + token_counts.scatter_add_(0, filtered_experts.to(torch.int64), + ones) + token_counts = token_counts[:self.num_experts_local] + + # Rearrange hidden_states + sorted_hidden_states = 
hidden_states[self.sorted_token_indices] + if self.with_quant: + group_list_type = 1 + else: + expert_tokens = torch.cumsum(token_counts, + dim=0, + dtype=torch.int64) + group_list_type = 0 + else: + row_idx_len = num_tokens * self.top_k + row_idx = (torch.arange(0, + row_idx_len, + dtype=torch.int32, + device=device).view(self.top_k, + -1).permute( + 1, 0).contiguous()) + active_num = self.max_num_tokens if self.max_num_tokens is not None else num_tokens + sorted_hidden_states, self.expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( + hidden_states, + row_idx=row_idx, + expert_idx=topk_ids, + active_num=active_num) + + expert_tokens = torch_npu.npu_moe_compute_expert_tokens( + expanded_expert_idx, self.num_experts_local) + expert_tokens = expert_tokens.to(torch.int64) + group_list_type = 0 + return group_list_type, sorted_hidden_states, expert_tokens + + def token_unpermutation(self, + hidden_states: torch.Tensor, + bias: torch.Tensor = None): + assert self.mask is not None + assert self.sorted_token_indices is not None + assert self.sorted_weights is not None + assert self.original_shape is not None + dtype = hidden_states.dtype + device = hidden_states.device + if self.expert_map is not None: + weighted_down_out = hidden_states * \ + self.sorted_weights.unsqueeze(1) + + final_hidden_states = torch.zeros(*self.original_shape, + device=hidden_states.device, + dtype=hidden_states.dtype) + + # TODO: npu_grouped_matmul output random values at [num_valid_tokens:, ...] + # This created multiple NaN and index_add_ will mix them up which harms accuracy + # remove this mask and filter after it being fixed + num_valid_tokens = self.mask.sum() + valid_token_mask = torch.arange( + 0, self.sorted_token_indices.shape[0], + device=device).unsqueeze(1) < num_valid_tokens + valid_output = torch.where( + valid_token_mask, weighted_down_out, + torch.zeros_like(weighted_down_out)).to(dtype) + final_hidden_states.index_add_(0, self.sorted_token_indices, + valid_output) + else: + if self.with_quant: + final_hidden_states = torch_npu.npu_moe_finalize_routing( + hidden_states, + skip1=None, + skip2=None, + bias=None, + scales=self.topk_weights, + expanded_src_to_dst_row=self.expanded_row_idx, + export_for_source_row=self.topk_ids, + ) + if len(self.original_shape) == 3: + final_hidden_states = final_hidden_states.view( + self.original_shape) + else: + scales = torch.ones_like( + self.topk_weights + ) if self.apply_router_weight_on_input else self.topk_weights + # TODO: Reorder device memory 2 times here, replace the current + # implementation here when suitable operators become available. 
+                final_hidden_states = torch_npu.npu_moe_finalize_routing(
+                    hidden_states,
+                    skip1=None,
+                    skip2=None,
+                    bias=None,
+                    scales=scales,
+                    expanded_src_to_dst_row=self.expanded_row_idx,
+                    export_for_source_row=self.topk_ids,
+                )
+
+        return final_hidden_states
+
+
+# mypy: disable-error-code="override"
+class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher):
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.apply_router_weight_on_input = kwargs.get(
+            "apply_router_weight_on_input")
+        ep_size = kwargs.get("ep_size")
+        self.local_ep = ep_size
+        self.top_k = kwargs.get("top_k")
+        assert self.local_ep is not None
+        self.local_num_experts = self.num_experts // self.local_ep
+        self.local_num_group = self.top_k // self.local_ep
+        self.bsz = None
+
+    def token_permutation(
+        self,
+        hidden_states: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        expert_map: torch.Tensor,
+        log2phy: torch.Tensor = None,
+        global_redundant_expert_num: int = 0,
+        shared_gate_up: Optional[Any] = None,
+        shared_dequant_scale: Optional[Any] = None,
+        shared_experts: Optional[Any] = None,
+    ):
+
+        if self.apply_router_weight_on_input:
+            assert (topk_weights.dim() == 2
+                    ), "`topk_weights` should be in shape (num_tokens, topk)"
+            _, topk = topk_weights.shape
+            assert (
+                topk == 1
+            ), "Only support topk=1 when `apply_router_weight_on_input` is True"
+            hidden_states = hidden_states * \
+                topk_weights.to(hidden_states.dtype)
+
+        self.bsz, _ = hidden_states.shape
+        flatten_topk_ids = topk_ids.view(-1)
+        self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
+        self.sorted_topk_ids = self.sorted_topk_ids.to(torch.int32)
+        self.sorted_hidden_states = hidden_states.index_select(
+            0, self.sorted_topk_ids // self.local_num_group)
+
+        experts_id = torch.arange(0,
+                                  self.local_num_experts,
+                                  dtype=topk_ids.dtype,
+                                  device=topk_ids.device)
+        num_tokens_per_expert = (
+            flatten_topk_ids.unsqueeze(-1) == experts_id).to(
+                torch.float32).sum(0)
+        self.topk_scales = topk_weights.view(-1).index_select(
+            0, self.sorted_topk_ids).unsqueeze(-1)
+        group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
+        return hidden_states, group_list
+
+    def token_unpermutation(self,
+                            hidden_states: torch.Tensor,
+                            bias: torch.Tensor = None):
+        assert self.local_ep is not None
+        unsorted_topk_ids = torch.argsort(self.sorted_topk_ids.float()).to(
+            torch.int32)
+        unsorted_hidden_states = hidden_states.index_select(
+            0, unsorted_topk_ids)
+        final_hidden_states = unsorted_hidden_states.reshape(
+            self.bsz, self.top_k // self.local_ep, -1).sum(1)
+        return final_hidden_states
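
The expert_map branch of TokenDispatcherWithAllGather boils down to a sort-by-local-expert permutation followed by a weighted index_add_ recombination. The CPU-only sketch below reproduces that round trip with plain torch ops so the data flow can be checked without an NPU; the names, shapes, and the identity "expert" are illustrative assumptions, not part of the patch.

import torch

num_tokens, hidden_size, top_k = 4, 8, 2
num_global_experts, num_local_experts = 4, 2

hidden_states = torch.randn(num_tokens, hidden_size)
topk_ids = torch.randint(0, num_global_experts, (num_tokens, top_k))
topk_weights = torch.softmax(torch.randn(num_tokens, top_k), dim=-1)
# This rank owns global experts 0 and 1; -1 marks experts hosted elsewhere.
expert_map = torch.tensor([0, 1, -1, -1])

# Flatten (token, expert) pairs and map global expert ids to local ones.
token_indices = torch.arange(num_tokens).repeat_interleave(top_k)
local_experts = expert_map[topk_ids.view(-1)]
mask = local_experts != -1

# Invalid pairs get a sentinel expert id so they sort to the end,
# and a zero weight so they do not contribute to the output.
filtered_experts = torch.where(
    mask, local_experts, torch.full_like(local_experts, num_local_experts))
filtered_weights = torch.where(mask, topk_weights.view(-1),
                               torch.zeros_like(topk_weights.view(-1)))

# "Permutation": gather token rows grouped by local expert id.
sort_idx = torch.argsort(filtered_experts)
sorted_hidden_states = hidden_states[token_indices[sort_idx]]
tokens_per_expert = torch.bincount(
    filtered_experts, minlength=num_local_experts + 1)[:num_local_experts]

# Stand-in for the grouped expert MLP (identity here).
expert_output = sorted_hidden_states

# "Unpermutation": weight each row and scatter-add it back to its source token.
output = torch.zeros(num_tokens, hidden_size)
output.index_add_(0, token_indices[sort_idx],
                  expert_output * filtered_weights[sort_idx].unsqueeze(1))
print(output.shape, tokens_per_expert.tolist())

In the dispatcher itself the sort index and weights are cached on the instance (sorted_token_indices, sorted_weights) during token_permutation so that token_unpermutation can reuse them after the grouped matmul.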
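
The patch covers TokenDispatcherWithMC2 and TokenDispatcherWithAllGather with unit tests but leaves UnquantizedTokenDispatcherWithFusedExpertsMoge untested. A minimal round-trip test along the same lines could look like the sketch below; it assumes an environment where vllm_ascend and torch_npu import successfully and that the base-class initializer is reached via super().__init__ (so num_experts is populated). The test class and method names are hypothetical.

import unittest

import torch

from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
    UnquantizedTokenDispatcherWithFusedExpertsMoge)


class TestTokenDispatcherWithFusedExpertsMoge(unittest.TestCase):

    def test_permutation_round_trip(self):
        kwargs = {
            "apply_router_weight_on_input": False,
            "top_k": 2,
            "ep_size": 1,
            "num_experts": 4,
        }
        dispatcher = UnquantizedTokenDispatcherWithFusedExpertsMoge(**kwargs)

        hidden_states = torch.randn(3, 8)
        topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
        topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]], dtype=torch.int32)

        _, group_list = dispatcher.token_permutation(hidden_states,
                                                     topk_weights, topk_ids,
                                                     None)
        # One cumulative token count per local expert.
        self.assertEqual(group_list.shape[0], 4)

        # Feed the permuted rows straight back as a stand-in expert output.
        restored = dispatcher.token_unpermutation(
            dispatcher.sorted_hidden_states)
        self.assertEqual(restored.shape, (3, 8))


if __name__ == "__main__":
    unittest.main()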