
Commit dc585f1

[main][prefill optimization] Optimize parallel strategies to reduce communication overhead (#2198)
### What this PR does / why we need it?

1. Shared expert sharding strategy update: switched from TP-aligned sharding to pure DP for the shared experts, enabling more efficient execution.
2. O_Proj AllReduce → ReduceScatter: reduced communication overhead by replacing the AllReduce with a ReduceScatter, which the pure-DP sharding of the shared experts makes possible.
3. AllGather postponed: the AllGather is delayed until after the QKV down projection to reduce synchronization overhead during prefill.

### How was this patch tested?

Added a UT case in `tests/ut/attention/test_mla_v1.py`.

#### How to run

Use the parameter `--additional_config='{"enable_shared_expert_dp": true}'`.

##### a. How to run eager mode

For example:

python -m vllm.entrypoints.openai.api_server --model=/model_path --trust-remote-code -tp 8 -dp 2 --enable_expert_parallel --port 8002 --max-model-len 5120 --max-num-batched-tokens 16384 --enforce-eager --disable-log-requests --additional_config='{"ascend_scheduler_config":{"enabled":true},"enable_shared_expert_dp": true,"chunked_prefill_for_mla":true}'

##### b. How to run graph mode

For example:

python -m vllm.entrypoints.openai.api_server --model=/model_path --trust-remote-code -tp 8 -dp 2 --enable_expert_parallel --port 8002 --max-model-len 5120 --max-num-batched-tokens 16384 --disable-log-requests --additional_config='{"ascend_scheduler_config":{"enabled":true},"enable_shared_expert_dp": true,"chunked_prefill_for_mla":true,"torchair_graph_config":{"enabled":true}}'

- vLLM version: v0.10.0
- vLLM main: vllm-project/vllm@9edd1db

---------

Signed-off-by: Wang Kunpeng <[email protected]>
Signed-off-by: SlightwindSec <[email protected]>
Co-authored-by: SlightwindSec <[email protected]>
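To make item 2 concrete, the sketch below contrasts the two collectives. It is illustrative only, not code from this PR: it assumes an already-initialized `torch.distributed` tensor-parallel process group, and the function and variable names (`project_and_reduce`, `tp_group`, `partial_out`) are made up for the example. The padding trick (`-num_tokens % tp_size`) is the same one the PR uses before its reduce-scatter.

```python
# Illustrative only -- not vllm-ascend code. Assumes torch.distributed is
# initialized and `tp_group` is the tensor-parallel process group.
import torch
import torch.distributed as dist


def project_and_reduce(partial_out: torch.Tensor,
                       tp_group: dist.ProcessGroup,
                       use_reduce_scatter: bool) -> torch.Tensor:
    """Combine per-rank partial o_proj outputs across the TP group."""
    tp_size = dist.get_world_size(tp_group)
    if not use_reduce_scatter:
        # AllReduce: every rank ends up with the full (num_tokens, hidden)
        # tensor, i.e. each rank both sends and receives the whole thing.
        dist.all_reduce(partial_out, group=tp_group)
        return partial_out
    # ReduceScatter: pad the token dim to a multiple of tp_size, then each
    # rank keeps only its contiguous 1/tp_size shard of the summed result,
    # roughly halving the ring traffic compared with AllReduce.
    num_tokens = partial_out.shape[0]
    pad = -num_tokens % tp_size
    if pad:
        partial_out = torch.nn.functional.pad(partial_out, (0, 0, 0, pad))
    shard = torch.empty(partial_out.shape[0] // tp_size,
                        partial_out.shape[1],
                        dtype=partial_out.dtype,
                        device=partial_out.device)
    dist.reduce_scatter_tensor(shard, partial_out, group=tp_group)
    return shard
```

This only pays off when every TP rank can keep working on its own token shard afterwards, which is what the pure-DP shared experts enable.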
Parent: 8181790

6 files changed: 169 additions, 37 deletions


docs/source/user_guide/configuration/additional_config.md

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@ The following table lists the additional configuration options available in vLLM
 | `expert_map_path` | str | `None` | When using expert load balancing for the MOE model, an expert map path needs to be passed in. |
 | `chunked_prefill_for_mla` | bool | `False` | Whether to enable the fused operator-like chunked_prefill. |
 | `kv_cache_dtype` | str | `None` | When using the kv cache quantization method, kv cache dtype needs to be set, currently only int8 is supported. |
+| `enable_shared_expert_dp` | bool | `True` | Run the shared experts in DP. This gives better performance but consumes more memory; if memory is tight, the switch can be turned off manually. |

 The details of each config option are as follows:

tests/ut/attention/test_mla_v1.py

Lines changed: 37 additions & 0 deletions
@@ -691,3 +691,40 @@ def test_forward_decode_without_graph(self, mock_page_attention_mla,
         self.assertEqual(result.shape[2], self.impl.v_head_dim)
         mock_up_proj.assert_called_once()
         mock_page_attention_mla.assert_called_once()
+
+    @patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._forward_prefill")
+    @patch("torch_npu._npu_reshape_and_cache")
+    def test_forward_without_graph(self, _, mock_forward_prefill):
+        self.impl.running_in_graph = False
+        self.impl.torchair_graph_enabled = False
+
+        num_tokens = 100
+        num_blocks = 256
+        block_size = 4
+        rotary_emb_return_value = (torch.randn(num_tokens, 16,
+                                               self.impl.kv_lora_rank),
+                                   torch.randn(0, 1, self.impl.kv_lora_rank))
+        self.impl.rotary_emb.side_effect = lambda *args, **kwargs: rotary_emb_return_value
+        self.impl.o_proj.side_effect = lambda *args, **kwargs: torch.randn(
+            1, num_blocks, 128)
+
+        hidden_states_or_q_c = torch.randn(num_tokens, self.impl.q_lora_rank)
+        hidden_states_or_kv_c_normed = torch.randn(num_tokens,
+                                                   self.impl.kv_lora_rank)
+        k_pe = torch.randn(num_tokens, self.impl.qk_rope_head_dim)
+        kv_cache = (torch.randn(num_blocks, block_size, self.impl.num_heads,
+                                self.impl.kv_lora_rank),
+                    torch.randn(num_blocks, block_size, self.impl.num_heads,
+                                self.impl.qk_rope_head_dim))
+        output = torch.randn(num_tokens, self.impl.num_heads,
+                             self.impl.v_head_dim)
+
+        metadata = MagicMock()
+        metadata.num_decodes = 0
+        metadata.num_prefills = num_tokens
+        mock_forward_prefill.return_value = torch.randn(
+            0, self.impl.num_heads * self.impl.v_head_dim)
+        result = self.impl.forward(None, hidden_states_or_q_c,
+                                   hidden_states_or_kv_c_normed, k_pe,
+                                   kv_cache, metadata, output, False)
+        self.assertEqual(result.shape[0], num_tokens)

vllm_ascend/ascend_config.py

Lines changed: 7 additions & 0 deletions
@@ -47,6 +47,9 @@ def __init__(self, vllm_config):
         self.expert_map_path = additional_config.get("expert_map_path", None)
         self.chunked_prefill_for_mla = additional_config.get(
             "chunked_prefill_for_mla", False)
+        self.enable_shared_expert_dp = additional_config.get(
+            "enable_shared_expert_dp", True
+        ) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel


 class TorchairGraphConfig:
@@ -166,6 +169,10 @@ def check_ascend_config(vllm_config, enforce_eager):
                 raise NotImplementedError(
                     "Torchair graph mode only works with following model types:"
                     f"{TORCHAIR_MODEL_LIST}.")
+        if ascend_config.enable_shared_expert_dp:
+            logger.warning(
+                "enable_shared_expert_dp is not supported for torchair graph mode currently, "
+                "it has been disabled automatically.")
     # aclgraph case
     else:
         # aclgraph doesn't work with deepseek model and only qwen model is well tested.

vllm_ascend/attention/mla_v1.py

Lines changed: 40 additions & 14 deletions
@@ -621,6 +621,7 @@ def __init__(
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
         self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
+        self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp

         # Adapt torch air graph mode with spec decoding.
         speculative_config = get_current_vllm_config().speculative_config
@@ -635,6 +636,8 @@ def _v_up_proj_and_o_proj(self, x, enable_multistream_mla: bool = False):
         x = torch.bmm(x, self.W_UV)
         # Convert from (N, B, V) to (B, N * V)
         x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
+        if hasattr(self, "running_in_graph") and not self.running_in_graph:
+            return x
         MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024  # 16MB
         npu_prefetch(self.o_proj.weight,
                      x,
@@ -905,14 +908,7 @@ def _forward_prefill(
         ] and not ascend_config.chunked_prefill_for_mla:
             attn_output = attn_output_torch

-        current_ms_metadata = get_multistream_comm_context()
-        if current_ms_metadata is None:
-            return self.o_proj(attn_output, is_prefill=True)[0]
-        else:
-            current_ms_metadata.before_comm_event.record()
-            with torch.npu.stream(current_ms_metadata.comm_stream):
-                current_ms_metadata.before_comm_event.wait()
-                return self.o_proj(attn_output, is_prefill=True)[0]
+        return attn_output

     def exec_kv(
         self,
@@ -1249,6 +1245,12 @@ def forward(
                 key_cache=kv_cache[0],
                 value_cache=kv_cache[1],
                 slot_indices=attn_metadata.slot_mapping)
+        if not self.running_in_graph:
+            o_proj_input_shape = (num_actual_toks,
+                                  self.num_heads * self.v_head_dim)
+            o_proj_input = torch.empty(o_proj_input_shape,
+                                       dtype=hidden_states_or_q_c.dtype,
+                                       device=hidden_states_or_q_c.device)
         if has_prefill:
             # FIX: aicore move should be also placed on the comm stream in dbo,
             # otherwise it may affect the accuracy
@@ -1259,11 +1261,12 @@ def forward(
                                                    attn_metadata)
             current_ms_metadata = get_multistream_comm_context()
             if current_ms_metadata is not None:
+                current_ms_metadata.before_comm_event.record()
                 with torch.npu.stream(current_ms_metadata.comm_stream):
-                    output[num_decode_tokens:] = output_prefill
-                    current_ms_metadata.after_comm_event.record()
+                    current_ms_metadata.before_comm_event.wait()
+                    o_proj_input[num_decode_tokens:] = output_prefill
             else:
-                output[num_decode_tokens:] = output_prefill
+                o_proj_input[num_decode_tokens:] = output_prefill

         if has_decode:
             if self.running_in_graph:
@@ -1280,9 +1283,32 @@ def forward(
                 current_ms_metadata = get_multistream_comm_context()
                 if current_ms_metadata is not None:
                     with torch.npu.stream(current_ms_metadata.comm_stream):
-                        output[:num_decode_tokens] = output_decode
-                        current_ms_metadata.after_comm_event.record()
+                        o_proj_input[:num_decode_tokens] = output_decode
                 else:
-                    output[:num_decode_tokens] = output_decode
+                    o_proj_input[:num_decode_tokens] = output_decode

+        current_ms_metadata = get_multistream_comm_context()
+        MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024  # 16MB
+        if current_ms_metadata is None:
+            npu_prefetch(self.o_proj.weight,
+                         o_proj_input,
+                         max_size=MAX_O_PROJ_PREFETCH_SIZE,
+                         enabled=enable_multistream_mla)
+
+            output[...] = self.o_proj(
+                o_proj_input,
+                is_prefill=True,
+                is_force_scatter=self.enable_shared_expert_dp)[0]
+        else:
+            with torch.npu.stream(current_ms_metadata.comm_stream):
+                npu_prefetch(self.o_proj.weight,
+                             o_proj_input,
+                             max_size=MAX_O_PROJ_PREFETCH_SIZE,
+                             enabled=enable_multistream_mla)
+                output[...] = self.o_proj(
+                    o_proj_input,
+                    is_prefill=True,
+                    is_force_scatter=self.enable_shared_expert_dp)[0]
+                current_ms_metadata.after_comm_event.record()
+        del o_proj_input
         return output_padded
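For readers skimming the hunks above: prefill and decode attention results are now staged into a single `o_proj_input` buffer and the output projection runs exactly once, either on the default stream or on the multistream comm stream. The following is a simplified sketch of that control flow, under the assumption that torch_npu is installed so `torch.npu` exposes the CUDA-like stream/event API; `o_proj` and `ms_metadata` are stand-ins for the real objects, not this file's code.

```python
# Sketch only -- condenses the staging-buffer + single-o_proj pattern above.
# Assumes torch_npu so that torch.npu provides streams and events.
import torch


def finish_attention(o_proj, o_proj_input, output, ms_metadata,
                     enable_shared_expert_dp):
    if ms_metadata is None:
        # No multistream context: project on the current (default) stream.
        output[...] = o_proj(o_proj_input,
                             is_prefill=True,
                             is_force_scatter=enable_shared_expert_dp)[0]
        return
    # The compute stream records its progress; the comm stream waits on that
    # event before touching o_proj_input, then signals completion for the
    # layers that consume the projected output.
    ms_metadata.before_comm_event.record()
    with torch.npu.stream(ms_metadata.comm_stream):
        ms_metadata.before_comm_event.wait()
        output[...] = o_proj(o_proj_input,
                             is_prefill=True,
                             is_force_scatter=enable_shared_expert_dp)[0]
        ms_metadata.after_comm_event.record()
```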

vllm_ascend/models/deepseek_v2.py

Lines changed: 69 additions & 11 deletions
@@ -141,7 +141,8 @@ class CustomDeepseekV2RowParallelLinearReplaceAllreduce(RowParallelLinear):
     def forward(
         self,
         input_,
-        is_prefill=True
+        is_prefill=True,
+        is_force_scatter=False
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
         if self.input_is_parallel:
             input_parallel = input_
@@ -160,7 +161,13 @@ def forward(
             input_parallel,
             bias=bias_)
         if self.reduce_results and self.tp_size > 1:
-            if not is_prefill and output_parallel.shape[0] % self.tp_size == 0:
+            num_tokens = output_parallel.shape[0]
+            if is_force_scatter and num_tokens % self.tp_size:
+                output_parallel = nn.functional.pad(
+                    output_parallel, (0, 0, 0, -num_tokens % self.tp_size))
+            if is_force_scatter or (not is_prefill
+                                    and output_parallel.shape[0] % self.tp_size
+                                    == 0):
                 output = tensor_model_parallel_reduce_scatter(output_parallel,
                                                               dim=0)
             else:
@@ -180,7 +187,8 @@ class CustomDeepseekV2RowParallelLinear(RowParallelLinear):
     def forward(
         self,
         input_,
-        is_prefill=True
+        is_prefill=True,
+        is_force_scatter=False
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
         if self.input_is_parallel:
             input_parallel = input_
@@ -347,13 +355,15 @@ def __init__(
             reduce_results = not self.all_reduce_merge
             intermediate_size = (config.moe_intermediate_size *
                                  config.n_shared_experts)
+            enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
             self.shared_experts = CustomDeepseekV2MLP(
                 hidden_size=config.hidden_size,
                 intermediate_size=intermediate_size,
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
                 reduce_results=reduce_results,
-                force_replicate=self.enable_multistream_moe,
+                force_replicate=self.enable_multistream_moe
+                or enable_shared_expert_dp,
                 prefix=f"{prefix}.shared_experts",
             )
         else:
@@ -447,9 +457,11 @@ def __init__(
         self.kv_lora_rank = kv_lora_rank

         self.num_heads = num_heads
-        tp_size = get_tensor_model_parallel_world_size()
-        assert num_heads % tp_size == 0
-        self.num_local_heads = num_heads // tp_size
+        self.tp_size = get_tensor_model_parallel_world_size()
+        assert num_heads % self.tp_size == 0
+        self.num_local_heads = num_heads // self.tp_size
+        self.layers = config.num_hidden_layers
+        self.first_k_dense_replace = config.first_k_dense_replace

         self.scaling = self.qk_head_dim**-0.5
         self.rope_theta = rope_theta
@@ -462,6 +474,7 @@ def __init__(
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
         self.enable_multistream_mla = \
             ascend_config.torchair_graph_config.enable_multistream_mla
+        self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp

         if self.q_lora_rank is not None:
             self.q_a_proj = ReplicatedLinear(self.hidden_size,
@@ -501,8 +514,9 @@ def __init__(
                                              prefix=f"{prefix}.kv_b_proj")
         if (config.n_routed_experts is not None
                 and self.debug_layer_idx >= config.first_k_dense_replace
-                and self.debug_layer_idx % config.moe_layer_freq == 0 and
-                ascend_config.torchair_graph_config.enable_multistream_moe):
+                and self.debug_layer_idx % config.moe_layer_freq == 0
+                and (ascend_config.torchair_graph_config.enable_multistream_moe
+                     or self.enable_shared_expert_dp)):
             self.o_proj = CustomDeepseekV2RowParallelLinearReplaceAllreduce(
                 self.num_heads * self.v_head_dim,
                 self.hidden_size,
@@ -596,13 +610,27 @@ def forward(
             output = output.view(-1, output_shape[-1])
             return output
         else:
-            kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
+            kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0]
+            if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
+                hidden_states_or_q_c = get_tp_group().all_gather(
+                    hidden_states_or_q_c, 0)
+                kv_no_split = get_tp_group().all_gather(kv_no_split, 0)
+
+            kv_c, k_pe = kv_no_split.split(
                 [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
             kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
+            if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace:
+                output_shape = hidden_states.shape
+            else:
+                num_tokens = hidden_states_or_q_c.shape[0]
+                rows = num_tokens // self.tp_size
+                if num_tokens % self.tp_size:
+                    rows += 1
+                output_shape = (rows, hidden_states.shape[1])
             return self.mla_attn(hidden_states_or_q_c,
                                  kv_c_normed,
                                  k_pe,
-                                 output_shape=hidden_states.shape)
+                                 output_shape=output_shape)


 class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
@@ -677,6 +705,8 @@ def __init__(
                                                 eps=config.rms_norm_eps)
         self.routed_scaling_factor = config.routed_scaling_factor
         self.first_k_dense_replace = config.first_k_dense_replace
+        self.tp_group = get_tp_group().device_group
+        self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp

     def forward(
         self,
@@ -731,6 +761,18 @@ def forward(
             # first layer.
             residual *= 1. / self.routed_scaling_factor

+        tp_size = get_tensor_model_parallel_world_size()
+        if self.enable_shared_expert_dp and (
+                self.layer_idx == self.first_k_dense_replace
+                or self.layer_idx == self.layers) and tp_size > 1:
+            num_tokens, _ = residual.shape
+            if num_tokens % tp_size:
+                residual = nn.functional.pad(residual,
+                                             (0, 0, 0, -num_tokens % tp_size))
+            chunk_residual = torch.tensor_split(residual, tp_size, dim=0)
+            tp_rank = get_tensor_model_parallel_rank()
+            residual = chunk_residual[tp_rank]
+
         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
@@ -756,6 +798,22 @@ def forward(
                                                             dim=0)
             residual = tensor_model_parallel_all_gather(residual, dim=0)

+        # for last layer of main model and mtp layer.
+        if self.enable_shared_expert_dp and self.layer_idx >= (
+                self.layers - 1) and tp_size > 1:
+            hidden_states = get_tp_group().all_gather(hidden_states, 0)
+            residual = get_tp_group().all_gather(residual, 0)
+
+            attn_metadata = get_forward_context().attn_metadata
+            if attn_metadata is not None:
+                num_tokens = attn_metadata.num_actual_tokens
+            else:
+                num_tokens = hidden_states.shape[0]
+
+            if num_tokens < hidden_states.shape[0]:
+                hidden_states = hidden_states[:num_tokens]
+                residual = residual[:num_tokens]
+
         return hidden_states, residual

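To make the padding arithmetic in the hunks above concrete, here is a toy CPU-only example, not part of the PR, showing how `-num_tokens % tp_size` pads the token dimension up to a multiple of `tp_size`, how `torch.tensor_split` hands each TP rank an equal shard of the residual, and how trimming after the final all-gather restores the original token count.

```python
# Toy illustration of the pad / split / trim arithmetic used above.
import torch
import torch.nn.functional as F

tp_size, num_tokens, hidden = 4, 10, 8
residual = torch.randn(num_tokens, hidden)

pad = -num_tokens % tp_size                       # 10 tokens, tp_size 4 -> pad 2
residual_padded = F.pad(residual, (0, 0, 0, pad))             # shape (12, 8)
shards = torch.tensor_split(residual_padded, tp_size, dim=0)  # 4 shards of (3, 8)
assert all(s.shape[0] == residual_padded.shape[0] // tp_size for s in shards)

# After the last MoE layer the shards are all-gathered back (emulated here by
# concatenation) and the padding rows are dropped, so downstream code sees the
# original token count again.
restored = torch.cat(shards, dim=0)[:num_tokens]
assert torch.equal(restored, residual)
```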

vllm_ascend/ops/fused_moe.py

Lines changed: 15 additions & 12 deletions
@@ -1268,6 +1268,7 @@ def __init__(
         self.enable_multistream_moe = \
             ascend_config.torchair_graph_config.enable_multistream_moe and \
             self.torchair_graph_enabled
+        self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp

         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
@@ -1408,22 +1409,24 @@ def forward(self,
         else:
             # TODO: Determine if we can remove the padding
             padding_size = tp_size
-            if num_tokens < padding_size:
+            if num_tokens < padding_size and not self.enable_shared_expert_dp:
                 hidden_states = nn.functional.pad(
                     hidden_states, (0, 0, 0, padding_size - num_tokens))
                 router_logits = nn.functional.pad(
                     router_logits, (0, 0, 0, padding_size - num_tokens))
             if tp_size > 1:
-                chunk_hidden_states = torch.tensor_split(hidden_states,
-                                                         tp_size,
-                                                         dim=0)
-                chunk_router_logits = torch.tensor_split(router_logits,
-                                                         tp_size,
-                                                         dim=0)
-                chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)
                 tp_rank = get_tensor_model_parallel_rank()
-                hidden_states = chunk_hidden_states[tp_rank]
-                router_logits = chunk_router_logits[tp_rank]
+                if not self.enable_shared_expert_dp:
+                    chunk_hidden_states = torch.tensor_split(hidden_states,
+                                                             tp_size,
+                                                             dim=0)
+                    chunk_router_logits = torch.tensor_split(router_logits,
+                                                             tp_size,
+                                                             dim=0)
+                    hidden_states = chunk_hidden_states[tp_rank]
+                    router_logits = chunk_router_logits[tp_rank]
+
+                chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)
                 mc2_mask = chunk_mc2_mask[tp_rank]

         if self.dp_size > 1:
@@ -1490,7 +1493,7 @@ def forward(self,
         if (fused_moe_state not in [
                 FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
                 FusedMoEState.NaiveMulticast
-        ] and not replace_allreduce):
+        ] and not replace_allreduce and not self.enable_shared_expert_dp):
             if tp_size > 1:
                 dist.all_gather(list(chunk_hidden_states), e_hidden_states,
                                 self.tp_group)
@@ -1500,7 +1503,7 @@ def forward(self,
             final_hidden_states = e_hidden_states
             if num_tokens < padding_size:
                 final_hidden_states = final_hidden_states[:num_tokens]
-        elif self.dp_size > 1:
+        elif self.dp_size > 1 and not self.enable_shared_expert_dp:
             if fused_moe_state == FusedMoEState.NaiveMulticast:
                 start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
                     self.dp_rank - 1]
