Commit c611291

lbk-sys and libaokui authored
【main】SP For Qwen3 MoE (#2209)
### What this PR does / why we need it?

Adds sequence parallelism (SP) support for Qwen3 MoE. In scenarios such as AlltoAll, AlltoAllv, and MC2, replacing AllReduce with Reduce-Scatter and AllGather yields computational savings in the norm operations while eliminating one AllGather communication. The feature takes effect during the prefill (P) phase and delivers notable gains in long-sequence scenarios (e.g., 16k–25k), with performance improvements of 5%–10%.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

```
compilation_config={
    "pass_config": {
        "enable_sequence_parallelism": True
    }
},
enable_expert_parallel=True,
```

- vLLM version: v0.10.0
- vLLM main: vllm-project/vllm@9edd1db

---------

Signed-off-by: libaokui <[email protected]>
Co-authored-by: libaokui <[email protected]>
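For reference, a minimal offline-inference sketch mirroring the test configuration added in this PR. The model choice, 2-way tensor parallelism, and sampling settings are illustrative, and passing `compilation_config` / `enable_expert_parallel` as `LLM` keyword arguments is assumed to behave as in vLLM v0.10:

```python
# Hedged sketch: enable the sequence-parallelism compilation pass for Qwen3 MoE
# on two cards, mirroring the e2e test added in this PR. Model choice and
# sampling settings are illustrative, not prescribed by the PR.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen3-30B-A3B",
    tensor_parallel_size=2,
    enable_expert_parallel=True,
    compilation_config={
        "pass_config": {
            "enable_sequence_parallelism": True
        }
    },
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=5))
for out in outputs:
    print(out.outputs[0].text)
```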
1 parent 57b9f02 commit c611291

File tree: 11 files changed, +299 -11 lines changed


.github/workflows/vllm_ascend_test.yaml

Lines changed: 1 addition & 0 deletions
@@ -284,6 +284,7 @@ jobs:
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_alltoallv
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC
+          pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_sp_for_qwen3_moe
           pytest -sv tests/e2e/multicard/test_data_parallel.py
           pytest -sv tests/e2e/multicard/ --ignore=tests/e2e/multicard/test_ilama_lora_tp2.py \
             --ignore=tests/e2e/multicard/test_offline_inference_distributed.py \

tests/e2e/multicard/test_offline_inference_distributed.py

Lines changed: 24 additions & 0 deletions
@@ -234,3 +234,27 @@ def test_models_distributed_DeepSeek_W4A8DYNAMIC():
         },
     ) as vllm_model:
         vllm_model.generate_greedy(prompts, max_tokens)
+
+
+def test_sp_for_qwen3_moe() -> None:
+    example_prompts = [
+        "Hello, my name is",
+    ]
+    sampling_params = SamplingParams(max_tokens=5,
+                                     temperature=0.0,
+                                     top_k=50,
+                                     top_p=0.9)
+
+    with VllmRunner(
+            snapshot_download("Qwen/Qwen3-30B-A3B"),
+            dtype="auto",
+            tensor_parallel_size=2,
+            distributed_executor_backend="mp",
+            compilation_config={
+                "pass_config": {
+                    "enable_sequence_parallelism": True
+                }
+            },
+            enable_expert_parallel=True,
+    ) as vllm_model:
+        vllm_model.generate(example_prompts, sampling_params)

tests/ut/test_platform.py

Lines changed: 1 addition & 0 deletions
@@ -26,6 +26,7 @@ def setUp(self):
         self.mock_vllm_config.cache_config = MagicMock()
         self.mock_vllm_config.scheduler_config = MagicMock()
         self.mock_vllm_config.speculative_config = None
+        self.mock_vllm_config.compilation_config.pass_config.enable_sequence_parallelism = False

         self.mock_ascend_config = MagicMock()
         self.mock_ascend_config.torchair_graph_config.enabled = False

vllm_ascend/attention/attention_v1.py

Lines changed: 5 additions & 2 deletions
@@ -151,6 +151,7 @@ class AscendMetadata:
    slot_mapping: torch.Tensor = None

    enable_dbo_across_dp: bool = False
+   is_only_prefill: bool = False


 class AscendAttentionMetadataBuilder:
@@ -166,7 +167,8 @@ def build(self,
              num_reqs,
              num_actual_tokens,
              max_query_len,
-             enable_dbo_across_dp: bool = False):
+             enable_dbo_across_dp: bool = False,
+             is_only_prefill: bool = False):

        block_table = self.runner.input_batch.block_table[0].get_device_tensor(
        )
@@ -203,7 +205,8 @@ def build(self,
            slot_mapping=slot_mapping,
            attn_mask=attn_mask,
            attn_state=attn_state,
-           enable_dbo_across_dp=enable_dbo_across_dp)
+           enable_dbo_across_dp=enable_dbo_across_dp,
+           is_only_prefill=is_only_prefill)
        return attn_metadata

vllm_ascend/attention/attention_v1_torchair.py

Lines changed: 3 additions & 1 deletion
@@ -223,7 +223,9 @@ def build(self,
              num_actual_tokens,
              max_query_len,
              graph_pad_size: int = -1,
-             enable_dbo_across_dp: bool = False):
+             enable_dbo_across_dp: bool = False,
+             *args,
+             **kwargs):

        device = self.runner.device

vllm_ascend/attention/mla_v1.py

Lines changed: 2 additions & 0 deletions
@@ -384,6 +384,8 @@ def build(
        graph_pad_size: int = -1,
        query_start_loc: torch.Tensor = None,
        enable_dbo_across_dp: bool = False,
+       *args,
+       **kwargs,
    ) -> AscendMLAMetadata:
        assert self._num_decodes + self._num_prefills == num_reqs

vllm_ascend/models/qwen3_moe.py

Lines changed: 120 additions & 7 deletions
@@ -16,14 +16,15 @@
 # limitations under the License.
 # Adapted from vllm/model_executor/models/qwen3_moe.py
 # This file is a part of the vllm-ascend project.
-from typing import Optional
+
+from typing import Optional, Union

 import torch
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, CompilationLevel, VllmConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
                                              get_tp_group)
 from vllm.forward_context import get_forward_context
@@ -44,8 +45,11 @@
 from vllm.model_executor.models.utils import (
     PPMissingLayer, extract_layer_index,
     make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
+from vllm.sequence import IntermediateTensors

 from vllm_ascend.ops.fused_moe import AscendFusedMoE
+from vllm_ascend.ops.sequence_parallel import (MetadataForPadding,
+                                               init_metadata_for_sp)


 class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
@@ -96,6 +100,7 @@ def forward(
         self,
         hidden_states,
         attn_metadata=None,
+        _metadata_for_padding: Optional[MetadataForPadding] = None,
     ):
         if attn_metadata is None:
             attn_metadata = get_forward_context().attn_metadata
@@ -114,6 +119,7 @@
             top_k=self.top_k,
             enable_force_load_balance=enable_force_load_balance,
             shared_experts=None,
+            _metadata_for_padding=_metadata_for_padding,
         )

         return hidden_states
@@ -155,14 +161,14 @@ def __init__(
         layer_idx = extract_layer_index(prefix)
         mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
                            config.mlp_only_layers)
-        use_aclgraph = (vllm_config is not None
-                        and vllm_config.compilation_config.level
-                        == CompilationLevel.PIECEWISE
-                        and not vllm_config.model_config.enforce_eager)
+        self.use_aclgraph = (vllm_config is not None
+                             and vllm_config.compilation_config.level
+                             == CompilationLevel.PIECEWISE
+                             and not vllm_config.model_config.enforce_eager)
         if (layer_idx not in mlp_only_layers) and (
                 config.num_experts > 0 and
                 (layer_idx + 1) % config.decoder_sparse_step == 0):
-            if not use_aclgraph:
+            if not self.use_aclgraph:
                 # FIXME: custom sparse moe block doesn't work with aclgraph.
                 self.mlp = CustomSparseMoeBlock(config=config,
                                                 quant_config=quant_config,
@@ -182,6 +188,60 @@ def __init__(
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)

+        self.enable_sequence_parallelism = (
+            vllm_config.compilation_config.pass_config.
+            enable_sequence_parallelism if vllm_config is not None else False)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor],
+        _metadata_for_padding: Optional[MetadataForPadding] = None,
+    ) -> torch.Tensor:
+
+        # To prevent precision issues during the decoder phase when only prefilling enables SP
+        if not self.enable_sequence_parallelism:
+            self.self_attn.o_proj.reduce_results = True
+        else:
+            self.self_attn.o_proj.reduce_results = not _metadata_for_padding.not_dummy_and_is_prefill if _metadata_for_padding is not None else True
+
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
+                residual = _metadata_for_padding.padding_slice(residual)
+
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
+
+        if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
+            hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
+                hidden_states)
+
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+        )
+
+        if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
+            hidden_states = _metadata_for_padding.padding_aligned_reduce_scatter(
+                hidden_states)
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
+
+        if not self.use_aclgraph:
+            hidden_states = self.mlp(
+                hidden_states, _metadata_for_padding=_metadata_for_padding)
+        else:
+            hidden_states = self.mlp(hidden_states)
+
+        return hidden_states, residual
+

 @support_torch_compile
 class CustomQwen3MoeModel(Qwen3MoeModel):
@@ -216,6 +276,45 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))

+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        _metadata_for_padding: Optional[MetadataForPadding] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
+            residual = None
+        else:
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+        for i in range(self.start_layer, self.end_layer):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                residual,
+                _metadata_for_padding=_metadata_for_padding)
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
+
+        hidden_states, _ = self.norm(hidden_states, residual)
+
+        if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
+            hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
+                hidden_states)
+
+        return hidden_states
+

 class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
     packed_modules_mapping = {
@@ -253,6 +352,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)

+        self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sequence_parallelism
         # Set MoE hyperparameters
         self.expert_weights: list[torch.Tensor] = []

@@ -273,3 +373,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.num_moe_layers = len(self.moe_layers)
         self.num_expert_groups = 1
         self.num_shared_experts = 0
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        _metadata_for_padding = init_metadata_for_sp(
+            input_ids, self.enable_sequence_parallelism)
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
+                                   inputs_embeds, _metadata_for_padding)
+        return hidden_states
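To make the communication pattern in the new decoder-layer forward easier to follow, here is a minimal, self-contained sketch of the same idea using plain torch.distributed collectives: activations stay sharded along the token dimension for the norms and MoE, an all-gather restores the full sequence just before attention, and a reduce-scatter (instead of the usual all-reduce on `o_proj`) re-shards the attention output. This is an illustrative sketch only, not the repo's `MetadataForPadding` implementation, and it assumes an initialized process group and a token count divisible by the TP world size:

```python
# Illustrative sequence-parallel prefill pattern (not the vllm-ascend code).
# Assumes torch.distributed is initialized and tokens % world_size == 0.
import torch
import torch.distributed as dist


def sp_attention_block(hidden_shard: torch.Tensor, attn_fn, group=None) -> torch.Tensor:
    """hidden_shard: [tokens // tp, hidden]; attn_fn runs attention on the full sequence."""
    world_size = dist.get_world_size(group)

    # All-gather the token shards so attention sees the whole sequence.
    full = torch.empty(hidden_shard.shape[0] * world_size,
                       hidden_shard.shape[1],
                       dtype=hidden_shard.dtype,
                       device=hidden_shard.device)
    dist.all_gather_into_tensor(full, hidden_shard, group=group)

    full = attn_fn(full)  # o_proj keeps partial sums (reduce_results=False)

    # Reduce-scatter replaces the all-reduce: each rank gets the summed output
    # for its own token shard, so the following norm/MoE run on 1/tp of tokens.
    out_shard = torch.empty_like(hidden_shard)
    dist.reduce_scatter_tensor(out_shard, full, group=group)
    return out_shard
```

In the real layer, `MetadataForPadding.allgather_unpadding_aligned` and `padding_aligned_reduce_scatter` additionally pad and unpad the token dimension so it divides evenly across ranks.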

vllm_ascend/ops/fused_moe.py

Lines changed: 13 additions & 1 deletion
@@ -47,6 +47,7 @@
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
 from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
     MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
+from vllm_ascend.ops.sequence_parallel import MetadataForPadding
 from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
 from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
                                get_all_reduce_merge_state,
@@ -1347,7 +1348,8 @@ def forward(self,
                 top_k: Optional[int] = None,
                 shared_experts: Optional[Any] = None,
                 gate=None,
-                replace_allreduce: bool = False):
+                replace_allreduce: bool = False,
+                _metadata_for_padding: Optional[MetadataForPadding] = None):

         assert self.quant_method is not None

@@ -1381,7 +1383,17 @@ def forward(self,
             # When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
             shared_hidden_states = shared_experts(hidden_states)

+        mc2_mask = forward_context.mc2_mask
+
+        enable_sp = _metadata_for_padding is not None and _metadata_for_padding.not_dummy_and_is_prefill
         tp_size = get_tensor_model_parallel_world_size()
+        if enable_sp:
+            tp_rank = get_tensor_model_parallel_rank()
+            mc2_mask_sp = _metadata_for_padding.mc2_mask if _metadata_for_padding is not None else forward_context.mc2_mask
+            chunk_mc2_mask = torch.tensor_split(mc2_mask_sp, tp_size, dim=0)
+            mc2_mask = chunk_mc2_mask[tp_rank]
+            replace_allreduce = True
+
         if (fused_moe_state not in [
                 FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
                 FusedMoEState.NaiveMulticast
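When SP is active, the mask handed to MC2 has to match each rank's token shard, so the mask is chunked along the token dimension with `torch.tensor_split` and each TP rank keeps only its own slice. A small illustrative sketch (the rank, world size, and mask length are stand-in values):

```python
# Illustrative sketch of the per-rank mc2_mask chunking used when SP is active.
# tp_size / tp_rank stand in for get_tensor_model_parallel_world_size() / rank().
import torch

tp_size, tp_rank = 2, 0                          # example values
mc2_mask_sp = torch.ones(16, dtype=torch.bool)   # mask over the padded token dim

chunk_mc2_mask = torch.tensor_split(mc2_mask_sp, tp_size, dim=0)
mc2_mask = chunk_mc2_mask[tp_rank]               # this rank's shard of the mask
assert mc2_mask.shape[0] == mc2_mask_sp.shape[0] // tp_size
```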
