@@ -528,17 +528,15 @@ def __init__(
                 bias=False,
                 quant_config=quant_config,
                 prefix=f"{prefix}.o_proj",
-                return_bias=False
-            )
+                return_bias=False)
         else:
             self.o_proj = TorchairDeepseekV2RowParallelLinear(
                 self.num_heads * self.v_head_dim,
                 self.hidden_size,
                 bias=False,
                 quant_config=quant_config,
                 prefix=f"{prefix}.o_proj",
-                return_bias=False
-            )
+                return_bias=False)
 
         if rope_scaling:
             rope_scaling["rope_type"] = 'deepseek_yarn'
@@ -738,10 +736,10 @@ def __init__(
             return_bias=False,
         )
         if (config.n_routed_experts is not None
-            and self.debug_layer_idx >= config.first_k_dense_replace
-            and self.debug_layer_idx % config.moe_layer_freq == 0
-            and (ascend_config.multistream_overlap_shared_expert
-                 or self.enable_shared_expert_dp)):
+                and self.debug_layer_idx >= config.first_k_dense_replace
+                and self.debug_layer_idx % config.moe_layer_freq == 0
+                and (ascend_config.multistream_overlap_shared_expert
+                     or self.enable_shared_expert_dp)):
             self.o_proj = TorchairDeepseekV2RowParallelLinearReplaceAllreduce(
                 self.num_heads * self.v_head_dim,
                 self.hidden_size,
@@ -827,8 +825,10 @@ def forward(
                 attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
         forward_context = get_forward_context()
         if not self.torchair_graph_enabled:
-            if forward_context.attn_metadata is not None and isinstance(forward_context.attn_metadata, dict):
-                attn_metadata = next(iter(forward_context.attn_metadata.values()), None)
+            if forward_context.attn_metadata is not None and isinstance(
+                    forward_context.attn_metadata, dict):
+                attn_metadata = next(
+                    iter(forward_context.attn_metadata.values()), None)
         else:
             attn_metadata = forward_context.attn_metadata
         if kv_cache is None:
@@ -843,7 +843,9 @@ def forward(
         # need_gather_q_kv = True
         if not self.enable_shared_expert_dp or self.debug_layer_idx != self.first_k_dense_replace:
             output_shape = hidden_states.shape
-        if self.enable_shared_expert_dp and (self.debug_layer_idx == self.first_k_dense_replace or self.debug_layer_idx == self.layers):
+        if self.enable_shared_expert_dp and (
+                self.debug_layer_idx == self.first_k_dense_replace
+                or self.debug_layer_idx == self.layers):
             rows = num_tokens // self.tp_size
             if num_tokens % self.tp_size:
                 rows += 1