
Commit 3f167fd

fix mla output shape bug
Signed-off-by: whx-sjtu <[email protected]>
1 parent 6fad01b commit 3f167fd

File tree

1 file changed (+5 lines, -24 lines)


vllm_ascend/models/deepseek_v2.py

Lines changed: 5 additions & 24 deletions
@@ -585,10 +585,13 @@ def forward(
         forward_context = get_forward_context()
         if kv_cache is None:
             kv_cache = self.mla_attn.kv_cache[forward_context.virtual_engine]
+        num_tokens = hidden_states.shape[0]
+        if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.num_layers:
+            # Simulate all gather to calculate output shape
+            num_tokens = num_tokens * self.tp_size
         if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace:
             output_shape = hidden_states.shape
         else:
-            num_tokens = hidden_states.shape[0]
             rows = num_tokens // self.tp_size
             if num_tokens % self.tp_size:
                 rows += 1
@@ -659,8 +662,6 @@ def __init__(
                 quant_config=quant_config,
                 prefix=f"{prefix}.mlp",
             )
-            self.mla_moe_communication = ascend_config.torchair_graph_config.enable_multistream_moe \
-                and model_config.use_mla and self.tp_size > 1
         else:
             self.mlp = CustomDeepseekV2MLP(
                 hidden_size=config.hidden_size,
@@ -669,7 +670,6 @@ def __init__(
                 quant_config=quant_config,
                 prefix=f"{prefix}.mlp",
             )
-            self.mla_moe_communication = False
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
@@ -689,10 +689,6 @@ def forward(
         replace_allreduce: bool = False,
     ) -> torch.Tensor:
         # Self Attention
-        if attn_metadata is not None and attn_metadata.num_decodes > 0:
-            mla_moe_communication = self.mla_moe_communication and replace_allreduce
-        else:
-            mla_moe_communication = False
         if residual is None:
             residual = hidden_states
             hidden_states = self.input_layernorm(hidden_states)
@@ -704,9 +700,6 @@ def forward(
             # to save npu memory because they're no longer used.
             dispose_tensor(previous_hidden_states)
             dispose_tensor(previous_residual)
-        if mla_moe_communication and self.layer_idx > self.first_k_dense_replace:
-            hidden_states = tensor_model_parallel_all_gather(hidden_states,
-                                                             dim=0)

         hidden_states = self.self_attn(
             positions=positions,
@@ -715,13 +708,6 @@ def forward(
             attn_metadata=attn_metadata,
         )

-        if mla_moe_communication and residual.shape[0] != hidden_states.shape[
-                0]:
-            chunk_hidden_states = torch.tensor_split(residual,
-                                                     self.tp_size,
-                                                     dim=0)
-            residual = chunk_hidden_states[self.tp_rank]
-
         if hidden_states.dtype == torch.float16:
             # Fix FP16 overflow
             # We scale both hidden_states and residual before
@@ -750,8 +736,7 @@ def forward(

         if isinstance(self.mlp, CustomDeepseekV2MoE):
             hidden_states = self.mlp(hidden_states,
-                                     attn_metadata,
-                                     replace_allreduce=mla_moe_communication)
+                                     attn_metadata)
         else:
             hidden_states = self.mlp(hidden_states)

@@ -764,10 +749,6 @@ def forward(
             # The scaling of DeepseekV2MOE output would be done in the forward
             # of DeepseekV2MOE
             hidden_states *= 1. / self.routed_scaling_factor
-        if mla_moe_communication and self.layer_idx == self.layers - 1:
-            hidden_states = tensor_model_parallel_all_gather(hidden_states,
-                                                             dim=0)
-            residual = tensor_model_parallel_all_gather(residual, dim=0)

         # for last layer of main model and mtp layer.
         if self.enable_shared_expert_dp and self.layer_idx >= (
