Commit a6bb502

[2/N][Feat] Add MC2 communication method for MoE layers (#2469)
### What this PR does / why we need it?

The MC2 communication method replaces the previous all-gather approach when the number of input tokens is small. The key changes include:

- A new `AscendFusedMoE` layer that handles token splitting, local computation, and final aggregation via all-gather.
- Logic in the model runner to dynamically select between the new MC2 method and the existing all-gather method based on the number of input tokens (sketched below).
- Sharding the MoE communication mask across tensor-parallel ranks.

### Does this PR introduce _any_ user-facing change?

None.

### How was this patch tested?

Test case fixed.

- vLLM version: v0.10.1.1
- vLLM main: vllm-project/vllm@b00e69f

---------

Signed-off-by: Yizhou Liu <[email protected]>
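The selection logic mentioned in the second bullet can be pictured with a minimal sketch. The threshold constant and helper name below are illustrative assumptions, not code from this PR; the runner inspects the token count for the step and records the chosen method name on the forward context.

```python
# Minimal sketch of the selection logic, assuming a simple token-count
# threshold. MC2_TOKEN_THRESHOLD and select_moe_comm_method are hypothetical
# names used only for illustration.
MC2_TOKEN_THRESHOLD = 256  # assumed cut-over point


def select_moe_comm_method(num_input_tokens: int) -> str:
    """Return the MoE communication method name for this forward pass."""
    if num_input_tokens <= MC2_TOKEN_THRESHOLD:
        return "mc2"        # fused dispatch/combine path for small batches
    return "allgather"      # existing all-gather path for larger batches
```

The returned string would then be passed to `set_ascend_forward_context(..., moe_comm_method=...)`, which (see the `ascend_forward_context.py` diff below) stores it on the forward context as `moe_comm_method_name`.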
1 parent 5d8ec28 commit a6bb502

File tree

11 files changed (+502, -406 lines)

tests/e2e/multicard/moe/test_moe_comm.py

Lines changed: 41 additions & 27 deletions
@@ -18,29 +18,30 @@
 
 import pytest
 import torch
-from transformers import PretrainedConfig
-from vllm import forward_context
 
-from vllm_ascend.distributed import moe_comm_method
-from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
-                                                     NativeAllGatherCommImpl)
+from vllm.model_executor.layers.fused_moe.config import (  # isort: skip
+    FusedMoEConfig, FusedMoEParallelConfig)
+
+from vllm_ascend.distributed.moe_comm_method import (  # isort: skip
+    AllGatherCommImpl, NativeAllGatherCommImpl)
 
 
 @pytest.mark.parametrize("num_tokens", [16, 128])
 @pytest.mark.parametrize("hidden_size", [64, 128])
 @pytest.mark.parametrize("global_num_experts", [8, 16])
+@pytest.mark.parametrize("num_local_experts", [4, 8])
 @pytest.mark.parametrize("top_k_num", [2, 4])
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
-@pytest.mark.parametrize("num_local_experts", [4, 8])
 @pytest.mark.parametrize("ep_rank", [0, 1])
 def test_all_gather_comm_impl(
     num_tokens,
     hidden_size,
     global_num_experts,
+    num_local_experts,
     top_k_num,
     dtype,
-    num_local_experts,
     ep_rank,
+    mocker,
 ):
     """
     Tests the AllGatherCommImpl against the NativeAllGatherCommImpl.
@@ -56,23 +57,37 @@ def test_all_gather_comm_impl(
         "num_local_experts cannot be greater than global_num_experts")
 
     device = torch.device("npu")
-    hf_config = PretrainedConfig(
-        num_experts_per_tok=top_k_num,
+
+    # mock get_tensor_model_parallel_rank to return ep_rank
+    mocker.patch(
+        "vllm.model_executor.layers.fused_moe.config.get_tensor_model_parallel_rank",
+        return_value=ep_rank,
+    )
+
+    # make moe config
+    parallel_config = SimpleNamespace(
+        enable_expert_parallel=num_local_experts < global_num_experts)
+    moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
+        tp_size_=max(2, global_num_experts // num_local_experts),
+        dp_size_=1,
+        vllm_parallel_config=parallel_config,
+    )
+
+    moe_config = FusedMoEConfig(
         num_experts=global_num_experts,
+        experts_per_token=top_k_num,
+        hidden_dim=hidden_size,
+        num_local_experts=num_local_experts,
+        moe_parallel_config=moe_parallel_config,
+        in_dtype=dtype,
+        quant_config=None,  # No quantization in this test
+        max_num_tokens=num_tokens,
     )
 
     # Instantiate implementations
-    native_impl = NativeAllGatherCommImpl(device, dtype, hf_config)
-
-    all_gather_impl = AllGatherCommImpl(device, dtype, hf_config)
+    native_impl = NativeAllGatherCommImpl(moe_config)
 
-    # TODO: Find out if this is the correct way to mock the forward context and ep group
-    # Mock get_forward_context to return an object with moe_comm_method
-    forward_context._forward_context = SimpleNamespace(
-        moe_comm_method=all_gather_impl)
-    # Mock get_ep_group to return a fake group with the specified ep_rank
-    fake_ep_group = SimpleNamespace(rank_in_group=ep_rank)
-    moe_comm_method.get_ep_group = lambda: fake_ep_group
+    all_gather_impl = AllGatherCommImpl(moe_config)
 
     # --- Input Data ---
     hidden_states = torch.randn(num_tokens,
@@ -103,27 +118,26 @@ def test_all_gather_comm_impl(
         native_permuted_hidden,
         native_expert_tokens,
         _,
-    ) = native_impl._pre_process(hidden_states, topk_ids, topk_weights,
-                                 expert_map, num_experts)
+    ) = native_impl.permute(hidden_states, topk_ids, topk_weights, expert_map,
+                            num_experts)
     # Simulate MLP output
     native_mlp_output = torch.randn_like(native_permuted_hidden)
-    native_impl._post_process(native_mlp_output, native_hidden_states_out)
+    native_impl.unpermute(native_mlp_output, native_hidden_states_out)
 
     # --- Run AllGather Implementation ---
     all_gather_hidden_states_out = hidden_states.clone()
     (
         all_gather_permuted_hidden,
         all_gather_expert_tokens,
         _,
-    ) = torch.ops.vllm.moe_comm_pre_process(hidden_states, topk_ids,
-                                            topk_weights, expert_map,
-                                            num_experts)
+    ) = all_gather_impl.permute(hidden_states, topk_ids, topk_weights,
+                                expert_map, num_experts)
 
     # Use the same simulated MLP output for a fair comparison
     all_gather_mlp_output = native_mlp_output.clone()
 
-    torch.ops.vllm.moe_comm_post_process(all_gather_mlp_output,
-                                         all_gather_hidden_states_out)
+    all_gather_impl.unpermute(all_gather_mlp_output,
+                              all_gather_hidden_states_out)
 
     # --- Assertions ---
     # Define tolerance based on dtype
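The updated test drives both implementations through a plain permute → expert MLP → unpermute method API, replacing the earlier `_pre_process`/`_post_process` helpers and the `torch.ops.vllm` custom ops. Below is a condensed sketch of that call order, reusing the argument names from the test; `expert_mlp` is a hypothetical stand-in for the real expert computation.

```python
from vllm_ascend.distributed.moe_comm_method import AllGatherCommImpl


def run_with_comm_impl(moe_config, hidden_states, topk_ids, topk_weights,
                       expert_map, num_experts, expert_mlp):
    """Sketch of the permute -> expert compute -> unpermute flow from the test."""
    impl = AllGatherCommImpl(moe_config)       # or NativeAllGatherCommImpl(moe_config)
    output = hidden_states.clone()             # buffer that unpermute() fills in place

    permuted_hidden, expert_tokens, _ = impl.permute(
        hidden_states, topk_ids, topk_weights, expert_map, num_experts)

    mlp_output = expert_mlp(permuted_hidden)   # the test fakes this with randn_like

    impl.unpermute(mlp_output, output)         # combine expert outputs back into token order
    return output
```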

tests/ut/distributed/test_communicator.py

Lines changed: 1 addition & 67 deletions
@@ -1,5 +1,5 @@
 import unittest
-from unittest.mock import MagicMock, Mock, patch
+from unittest.mock import MagicMock, patch
 
 import torch
 import torch.distributed as dist
@@ -87,69 +87,3 @@ def patched_all_to_all(output_tensor_list,
         output = comm.all_to_all(input_, scatter_dim=0, gather_dim=0)
 
         assert output.tolist() == [[10, 20], [50, 60]]
-
-    @patch("vllm.config.get_current_vllm_config", return_value=None)
-    @patch("torch.npu.current_device", return_value=MagicMock())
-    @patch("torch.npu.set_device", return_value=MagicMock())
-    @patch("torch.distributed.get_process_group_ranks",
-           return_value={
-               0: 0,
-               1: 1
-           })
-    @patch("torch.distributed.get_group_rank", return_value={0: 0, 1: 1})
-    @patch("torch.distributed.is_initialized", return_value=True)
-    @patch("torch.distributed.get_rank", return_value=1)
-    @patch("torch.distributed.is_initialized", return_value=True)
-    @patch("torch.distributed.get_backend", return_value="hccl")
-    @patch("torch.distributed.get_rank", return_value=1)
-    @patch("torch.distributed.get_world_size", return_value=2)
-    @patch("torch.distributed.get_process_group_ranks", return_value=[0, 1])
-    @patch("torch.npu.device")
-    def test_dispatch(self, *_):
-        comm = NPUCommunicator(cpu_group=dist.group.WORLD)
-        comm.all2all_manager = Mock()
-        hidden_states = torch.randn(2, 4, 8)
-        router_logits = torch.randn(2, 4, 2)
-
-        mock_dispatch_result = (torch.randn(2, 4, 8), torch.randn(2, 4, 2))
-        comm.all2all_manager.dispatch.return_value = mock_dispatch_result
-
-        result_hidden, result_logits = comm.dispatch(hidden_states,
-                                                     router_logits)
-
-        assert torch.allclose(result_hidden, mock_dispatch_result[0])
-        assert torch.allclose(result_logits, mock_dispatch_result[1])
-
-        comm.all2all_manager.dispatch.assert_called_once_with(
-            hidden_states, router_logits)
-
-    @patch("vllm.config.get_current_vllm_config", return_value=None)
-    @patch("torch.npu.current_device", return_value=MagicMock())
-    @patch("torch.npu.set_device", return_value=MagicMock())
-    @patch("torch.distributed.get_process_group_ranks",
-           return_value={
-               0: 0,
-               1: 1
-           })
-    @patch("torch.distributed.get_group_rank", return_value={0: 0, 1: 1})
-    @patch("torch.distributed.is_initialized", return_value=True)
-    @patch("torch.distributed.get_rank", return_value=1)
-    @patch("torch.distributed.is_initialized", return_value=True)
-    @patch("torch.distributed.get_backend", return_value="hccl")
-    @patch("torch.distributed.get_rank", return_value=1)
-    @patch("torch.distributed.get_world_size", return_value=2)
-    @patch("torch.distributed.get_process_group_ranks", return_value=[0, 1])
-    @patch("torch.npu.device")
-    def test_combine(self, *_):
-        comm = NPUCommunicator(cpu_group=dist.group.WORLD)
-        comm.all2all_manager = Mock()
-        hidden_states = torch.randn(2, 4, 8)
-
-        mock_combine_result = torch.randn(2, 4, 8)
-        comm.all2all_manager.combine.return_value = mock_combine_result
-
-        result = comm.combine(hidden_states)
-
-        assert torch.allclose(result, mock_combine_result)
-
-        comm.all2all_manager.combine.assert_called_once_with(hidden_states)

tests/ut/test_utils.py

Lines changed: 2 additions & 2 deletions
@@ -289,13 +289,13 @@ def test_register_ascend_customop(self, mock_ascend_rmsnorm,
         # ascend custom op is not registered
         utils.register_ascend_customop()
         # should call register_oot three
-        self.assertEqual(mock_customop.register_oot.call_count, 8)
+        self.assertEqual(mock_customop.register_oot.call_count, 9)
         self.assertTrue(utils._ASCEND_CUSTOMOP_IS_REIGISTERED)
 
         # ascend custom op is already registered
         utils.register_ascend_customop()
         # should not register_oot again, thus only called three in this ut
-        self.assertEqual(mock_customop.register_oot.call_count, 8)
+        self.assertEqual(mock_customop.register_oot.call_count, 9)
 
 
 class TestProfileExecuteDuration(TestBase):

vllm_ascend/ascend_forward_context.py

Lines changed: 2 additions & 3 deletions
@@ -11,7 +11,6 @@
                                   set_forward_context)
 
 import vllm_ascend.envs as envs_ascend
-from vllm_ascend.distributed.moe_comm_method import MoECommMethod
 
 
 class FusedMoEState(Enum):
@@ -57,7 +56,7 @@ def set_ascend_forward_context(
         with_prefill: bool = True,
         in_profile_run: bool = False,
         reserved_mc2_mask: Optional[torch.Tensor] = None,
-        moe_comm_method: Optional[MoECommMethod] = None,
+        moe_comm_method: str = "",
         num_actual_tokens: Optional[int] = None,
         aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
         batch_descriptor: Optional[BatchDescriptor] = None):
@@ -75,7 +74,7 @@ def set_ascend_forward_context(
             batch_descriptor=batch_descriptor,
     ):
         forward_context = get_forward_context()
-        forward_context.moe_comm_method = moe_comm_method
+        forward_context.moe_comm_method_name = moe_comm_method + "commimpl"
        forward_context.with_prefill = with_prefill
         ep_size = (get_ep_group().world_size if
                    vllm_config.parallel_config.enable_expert_parallel else 1)
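After this change the forward context carries a string key (the method name plus the `"commimpl"` suffix) rather than a `MoECommMethod` object, presumably so the concrete implementation can be resolved by name later in the MoE layer. The sketch below shows one way such a lookup could work; the registry dict, the `get_comm_impl` helper, and the example key are assumptions for illustration, not the actual vllm_ascend code.

```python
from vllm_ascend.distributed.moe_comm_method import AllGatherCommImpl

# Hypothetical name -> class registry; only AllGatherCommImpl is listed
# because it appears in this diff. An MC2 implementation would register
# the same way.
_COMM_IMPL_REGISTRY = {
    "allgathercommimpl": AllGatherCommImpl,
}


def get_comm_impl(forward_context, moe_config):
    """Resolve the comm impl from the name stored on the forward context."""
    name = forward_context.moe_comm_method_name  # e.g. "allgather" + "commimpl"
    return _COMM_IMPL_REGISTRY[name](moe_config)
```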

vllm_ascend/distributed/communicator.py

Lines changed: 0 additions & 21 deletions
@@ -20,7 +20,6 @@
 import torch.distributed as dist
 from vllm.distributed.device_communicators.base_device_communicator import \
     DeviceCommunicatorBase
-from vllm.utils import logger
 
 
 class NPUCommunicator(DeviceCommunicatorBase):
@@ -35,12 +34,6 @@ def __init__(self,
         # init device according to rank
         self.device = torch.npu.current_device()
 
-        if self.use_all2all:
-            from vllm.distributed.device_communicators.all2all import \
-                NaiveAll2AllManager
-            self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
-            logger.info("Using naive all2all manager.")
-
     def all_to_all(self,
                    input_: torch.Tensor,
                    scatter_dim: int = 0,
@@ -80,17 +73,3 @@ def all_to_all(self,
         dist.all_to_all(output_list, input_list, group=self.device_group)
         output_tensor = torch.cat(output_list, dim=gather_dim).contiguous()
         return output_tensor
-
-    # TODO: Add ut for dispatch and combine
-    def dispatch(
-            self, hidden_states: torch.Tensor,
-            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-        assert self.all2all_manager is not None
-        hidden_states, router_logits = self.all2all_manager.dispatch(
-            hidden_states, router_logits)
-        return hidden_states, router_logits
-
-    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        assert self.all2all_manager is not None
-        hidden_states = self.all2all_manager.combine(hidden_states)
-        return hidden_states
