Commit 992271b

[1/N][Feat] Support MoE models with ACL Graph and refactor MoE communication logic (#2125)
### What this PR does / why we need it?

This PR refactors the MoE (Mixture of Experts) communication logic by introducing a strategy pattern. It defines an abstract base class, `MoECommMethod`, which encapsulates different communication strategies for MoE layers. By decoupling the MoE implementation from any single communication method, this change makes it simpler to add, replace, or optimize communication strategies in the future.

Plan / Roadmap:

1. Introduce `MoECommMethod`, implement `AllGatherImpl`, and adapt ACL Graph handling to cover all scenarios (this PR).
2. Implement `MC2CommImpl` and `AllToAllCommImpl` to optimize performance in specific scenarios.
3. Enable W8A8 / Int8 models to use `unified_fused_experts`.

Other notes:

* Data-parallel (DP) communication currently does not work with vLLM's dispatch/combine mechanisms; an alternative approach is required to resolve this incompatibility.

- vLLM version: v0.10.0
- vLLM main: vllm-project/vllm@f7ad6a1

Signed-off-by: Yizhou Liu <[email protected]>
1 parent 1a70564 commit 992271b
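For orientation, the following is a minimal, hypothetical sketch of the strategy-pattern interface described above. The constructor and method signatures are inferred from the unit test added by this commit; the real `MoECommMethod` in `vllm_ascend/distributed/moe_comm_method.py` may differ in its details.

```python
# Hypothetical sketch only -- not the code from this PR.
from abc import ABC, abstractmethod
from typing import Optional, Tuple

import torch


class MoECommMethod(ABC):
    """Encapsulates one MoE communication strategy (e.g. all-gather, MC2)."""

    def __init__(self, device: torch.device, dtype: torch.dtype, hf_config):
        self.device = device
        self.dtype = dtype
        self.hf_config = hf_config

    @abstractmethod
    def _pre_process(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        expert_map: Optional[torch.Tensor],
        num_experts: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Gather/permute tokens so each local expert sees its inputs."""

    @abstractmethod
    def _post_process(self, mlp_output: torch.Tensor,
                      hidden_states: torch.Tensor) -> None:
        """Un-permute/combine expert outputs back into hidden_states."""
```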

File tree

7 files changed: +764 −26 lines changed

Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.

from types import SimpleNamespace

import pytest
import torch
from transformers import PretrainedConfig
from vllm import forward_context

from vllm_ascend.distributed import moe_comm_method
from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
                                                     NativeAllGatherCommImpl)


@pytest.mark.parametrize("num_tokens", [16, 128])
@pytest.mark.parametrize("hidden_size", [64, 128])
@pytest.mark.parametrize("global_num_experts", [8, 16])
@pytest.mark.parametrize("top_k_num", [2, 4])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("num_local_experts", [4, 8])
@pytest.mark.parametrize("ep_rank", [0, 1])
def test_all_gather_comm_impl(
    num_tokens,
    hidden_size,
    global_num_experts,
    top_k_num,
    dtype,
    num_local_experts,
    ep_rank,
):
    """
    Tests the AllGatherCommImpl against the NativeAllGatherCommImpl.

    This test compares the outputs of the NPU-optimized AllGatherCommImpl
    with a native PyTorch implementation (NativeAllGatherCommImpl) to ensure
    correctness across various configurations.
    """
    if top_k_num > global_num_experts:
        pytest.skip("top_k_num cannot be greater than global_num_experts")
    if num_local_experts > global_num_experts:
        pytest.skip(
            "num_local_experts cannot be greater than global_num_experts")

    device = torch.device("npu")
    hf_config = PretrainedConfig(
        num_experts_per_tok=top_k_num,
        num_experts=global_num_experts,
    )

    # Instantiate implementations
    native_impl = NativeAllGatherCommImpl(device, dtype, hf_config)
    all_gather_impl = AllGatherCommImpl(device, dtype, hf_config)

    # TODO: Find out if this is the correct way to mock the forward context and ep group
    # Mock get_forward_context to return an object with moe_comm_method
    forward_context._forward_context = SimpleNamespace(
        moe_comm_method=all_gather_impl)
    # Mock get_ep_group to return a fake group with the specified ep_rank
    fake_ep_group = SimpleNamespace(rank_in_group=ep_rank)
    moe_comm_method.get_ep_group = lambda: fake_ep_group

    # --- Input Data ---
    hidden_states = torch.randn(num_tokens,
                                hidden_size,
                                device=device,
                                dtype=dtype)
    topk_ids = torch.randint(0,
                             global_num_experts, (num_tokens, top_k_num),
                             device=device,
                             dtype=torch.int32)
    topk_weights = torch.rand(num_tokens, top_k_num, device=device).to(dtype)
    topk_weights = torch.nn.functional.softmax(topk_weights, dim=1)

    num_experts = global_num_experts

    expert_map = None
    if num_local_experts < global_num_experts:
        # Create a map where some experts are local and some are not
        expert_map = torch.full((global_num_experts, ), -1, device=device)
        expert_map[ep_rank * num_local_experts:(ep_rank + 1) *
                   num_local_experts] = torch.arange(num_local_experts,
                                                     device=device)
        num_experts = num_local_experts

    # --- Run Native Implementation (Golden Reference) ---
    native_hidden_states_out = hidden_states.clone()
    (
        native_permuted_hidden,
        native_expert_tokens,
        _,
    ) = native_impl._pre_process(hidden_states, topk_ids, topk_weights,
                                 expert_map, num_experts)
    # Simulate MLP output
    native_mlp_output = torch.randn_like(native_permuted_hidden)
    native_impl._post_process(native_mlp_output, native_hidden_states_out)

    # --- Run AllGather Implementation ---
    all_gather_hidden_states_out = hidden_states.clone()
    (
        all_gather_permuted_hidden,
        all_gather_expert_tokens,
        _,
    ) = torch.ops.vllm.moe_comm_pre_process(hidden_states, topk_ids,
                                            topk_weights, expert_map,
                                            num_experts)

    # Use the same simulated MLP output for a fair comparison
    all_gather_mlp_output = native_mlp_output.clone()

    torch.ops.vllm.moe_comm_post_process(all_gather_mlp_output,
                                         all_gather_hidden_states_out)

    # --- Assertions ---
    # Define tolerance based on dtype
    atol = 1e-3 if dtype == torch.float16 else 1e-2
    rtol = 1e-3 if dtype == torch.float16 else 1e-2

    # 1. Compare expert_tokens from pre_process
    assert torch.allclose(native_expert_tokens.to(
        all_gather_expert_tokens.device),
                          all_gather_expert_tokens,
                          atol=atol,
                          rtol=rtol), "Expert tokens do not match."

    # 2. Compare permuted_hidden_states from pre_process
    num_valid_tokens = native_expert_tokens.sum()
    assert torch.allclose(native_permuted_hidden[:num_valid_tokens].to(
        all_gather_permuted_hidden.device),
                          all_gather_permuted_hidden[:num_valid_tokens],
                          atol=atol,
                          rtol=rtol), "Permuted hidden states do not match."

    # 3. Compare final hidden_states from post_process
    assert torch.allclose(native_hidden_states_out.to(
        all_gather_hidden_states_out.device),
                          all_gather_hidden_states_out,
                          atol=atol,
                          rtol=rtol), "Final hidden states do not match."

vllm_ascend/ascend_forward_context.py

Lines changed: 13 additions & 8 deletions
@@ -5,11 +5,12 @@
 
 import torch
 from vllm.config import VllmConfig
-from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
+from vllm.distributed import (get_dp_group, get_ep_group,
+                              get_tensor_model_parallel_world_size)
 from vllm.forward_context import get_forward_context, set_forward_context
 
 import vllm_ascend.envs as envs
-from vllm_ascend.platform import NPUPlatform
+from vllm_ascend.distributed.moe_comm_method import MoECommMethod
 
 
 class FusedMoEState(Enum):
@@ -54,6 +55,8 @@ def set_ascend_forward_context(
         num_tokens_across_dp: Optional[torch.Tensor] = None,
         with_prefill: bool = True,
         in_profile_run: bool = False,
+        reserved_mc2_mask: Optional[torch.Tensor] = None,
+        moe_comm_method: Optional[MoECommMethod] = None,
         num_actual_tokens: Optional[int] = None,
 ):
     """A context manager that stores the current forward context,
@@ -66,6 +69,7 @@ def set_ascend_forward_context(
             num_tokens=num_tokens,
             num_tokens_across_dp=num_tokens_across_dp):
         forward_context = get_forward_context()
+        forward_context.moe_comm_method = moe_comm_method
         forward_context.with_prefill = with_prefill
         ep_size = (get_ep_group().world_size if
                    vllm_config.parallel_config.enable_expert_parallel else 1)
@@ -97,16 +101,17 @@ def set_ascend_forward_context(
         if num_tokens is not None:
             if num_actual_tokens is None:
                 num_actual_tokens = num_tokens
-            tp_world_size = get_tp_group().world_size
+            tp_world_size = get_tensor_model_parallel_world_size()
             # NOTE: token num which need to pad to when mc2
             forward_context.padded_num_tokens = math.ceil(
                 max_tokens_across_dp / tp_world_size) * tp_world_size
 
-            mc2_mask = torch.zeros(forward_context.padded_num_tokens,
-                                   dtype=torch.bool,
-                                   device=NPUPlatform.device_type)
-            mc2_mask[:num_actual_tokens] = True
-            forward_context.mc2_mask = mc2_mask
+            if reserved_mc2_mask is not None:
+                mc2_mask = reserved_mc2_mask[:forward_context.
+                                             padded_num_tokens]
+                mc2_mask[:num_actual_tokens] = True
+                mc2_mask[num_actual_tokens:] = False
+                forward_context.mc2_mask = mc2_mask
 
         try:
             yield
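The mc2_mask change above slices a pre-allocated `reserved_mc2_mask` buffer instead of allocating a fresh boolean mask on every step, presumably to keep shapes and allocations stable for ACL Graph capture. A small self-contained sketch of the padding arithmetic, using made-up example values:

```python
# Illustrative sketch of the padding/mask logic; all values below are examples.
import math

import torch

max_tokens_across_dp = 150   # example value
tp_world_size = 4            # example value
num_actual_tokens = 130      # example value

# Round the token count up to a multiple of the TP world size.
padded_num_tokens = math.ceil(max_tokens_across_dp / tp_world_size) * tp_world_size
assert padded_num_tokens == 152

# Reuse a buffer allocated once (here on CPU for illustration) rather than
# creating a new mask per forward pass.
reserved_mc2_mask = torch.zeros(4096, dtype=torch.bool)
mc2_mask = reserved_mc2_mask[:padded_num_tokens]
mc2_mask[:num_actual_tokens] = True    # real tokens
mc2_mask[num_actual_tokens:] = False   # padding tokens
```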
