@@ -585,10 +585,13 @@ def forward(
        forward_context = get_forward_context()
        if kv_cache is None:
            kv_cache = self.mla_attn.kv_cache[forward_context.virtual_engine]
+        num_tokens = hidden_states.shape[0]
+        if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
+            # Simulate all gather to calculate output shape
+            num_tokens = num_tokens * self.tp_size
        if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace:
            output_shape = hidden_states.shape
        else:
-            num_tokens = hidden_states.shape[0]
            rows = num_tokens // self.tp_size
            if num_tokens % self.tp_size:
                rows += 1
@@ -659,8 +662,6 @@ def __init__(
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
-            self.mla_moe_communication = ascend_config.torchair_graph_config.enable_multistream_moe \
-                and model_config.use_mla and self.tp_size > 1
        else:
            self.mlp = CustomDeepseekV2MLP(
                hidden_size=config.hidden_size,
@@ -669,7 +670,6 @@ def __init__(
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
-            self.mla_moe_communication = False
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
@@ -689,10 +689,6 @@ def forward(
        replace_allreduce: bool = False,
    ) -> torch.Tensor:
        # Self Attention
-        if attn_metadata is not None and attn_metadata.num_decodes > 0:
-            mla_moe_communication = self.mla_moe_communication and replace_allreduce
-        else:
-            mla_moe_communication = False
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
@@ -704,9 +700,6 @@ def forward(
            # to save npu memory because they're no longer used.
            dispose_tensor(previous_hidden_states)
            dispose_tensor(previous_residual)
-            if mla_moe_communication and self.layer_idx > self.first_k_dense_replace:
-                hidden_states = tensor_model_parallel_all_gather(hidden_states,
-                                                                 dim=0)

        hidden_states = self.self_attn(
            positions=positions,
@@ -715,13 +708,6 @@ def forward(
            attn_metadata=attn_metadata,
        )

-        if mla_moe_communication and residual.shape[0] != hidden_states.shape[
-                0]:
-            chunk_hidden_states = torch.tensor_split(residual,
-                                                     self.tp_size,
-                                                     dim=0)
-            residual = chunk_hidden_states[self.tp_rank]
-
        if hidden_states.dtype == torch.float16:
            # Fix FP16 overflow
            # We scale both hidden_states and residual before
@@ -750,8 +736,7 @@ def forward(

        if isinstance(self.mlp, CustomDeepseekV2MoE):
            hidden_states = self.mlp(hidden_states,
-                                     attn_metadata,
-                                     replace_allreduce=mla_moe_communication)
+                                     attn_metadata)
        else:
            hidden_states = self.mlp(hidden_states)

@@ -764,10 +749,6 @@ def forward(
            # The scaling of DeepseekV2MOE output would be done in the forward
            # of DeepseekV2MOE
            hidden_states *= 1. / self.routed_scaling_factor
-        if mla_moe_communication and self.layer_idx == self.layers - 1:
-            hidden_states = tensor_model_parallel_all_gather(hidden_states,
-                                                             dim=0)
-            residual = tensor_model_parallel_all_gather(residual, dim=0)

        # for last layer of main model and mtp layer.
        if self.enable_shared_expert_dp and self.layer_idx >= (