@@ -437,6 +437,7 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        decoder_layer=None,
     ) -> None:
         nn.Module.__init__(self)
         self.hidden_size = hidden_size
@@ -473,50 +474,44 @@ def __init__(
                                              self.q_lora_rank,
                                              bias=False,
                                              quant_config=quant_config,
-                                             prefix=f"{prefix}.q_a_proj",
-                                             return_bias=False)
+                                             prefix=f"{prefix}.q_a_proj")
             self.q_a_layernorm = RMSNorm(self.q_lora_rank,
                                          eps=config.rms_norm_eps)
             self.q_b_proj = ColumnParallelLinear(q_lora_rank,
                                                  self.num_heads *
                                                  self.qk_head_dim,
                                                  bias=False,
                                                  quant_config=quant_config,
-                                                 prefix=f"{prefix}.q_b_proj",
-                                                 return_bias=False)
+                                                 prefix=f"{prefix}.q_b_proj")
         else:
             self.q_proj = ColumnParallelLinear(self.hidden_size,
                                                self.num_heads *
                                                self.qk_head_dim,
                                                bias=False,
                                                quant_config=quant_config,
-                                               prefix=f"{prefix}.q_proj",
-                                               return_bias=False)
+                                               prefix=f"{prefix}.q_proj")

         self.kv_a_proj_with_mqa = ReplicatedLinear(
             self.hidden_size,
             self.kv_lora_rank + self.qk_rope_head_dim,
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.kv_a_proj_with_mqa",
-            return_bias=False)
+            prefix=f"{prefix}.kv_a_proj_with_mqa")
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
                                       eps=config.rms_norm_eps)
         self.kv_b_proj = ColumnParallelLinear(
             self.kv_lora_rank,
             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.kv_b_proj",
-            return_bias=False)
+            prefix=f"{prefix}.kv_b_proj")

         if oproj_tp_enable():
             self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
                                             self.hidden_size,
                                             bias=False,
                                             quant_config=quant_config,
-                                            prefix=f"{prefix}.o_proj",
-                                            return_bias=False)
+                                            prefix=f"{prefix}.o_proj")
         elif (config.n_routed_experts is not None
               and self.debug_layer_idx >= config.first_k_dense_replace
               and self.debug_layer_idx % config.moe_layer_freq == 0
@@ -527,16 +522,14 @@ def __init__(
                 self.hidden_size,
                 bias=False,
                 quant_config=quant_config,
-                prefix=f"{prefix}.o_proj",
-                return_bias=False)
+                prefix=f"{prefix}.o_proj")
         else:
             self.o_proj = TorchairDeepseekV2RowParallelLinear(
                 self.num_heads * self.v_head_dim,
                 self.hidden_size,
                 bias=False,
                 quant_config=quant_config,
-                prefix=f"{prefix}.o_proj",
-                return_bias=False)
+                prefix=f"{prefix}.o_proj")

         if rope_scaling:
             rope_scaling["rope_type"] = 'deepseek_yarn'
@@ -592,7 +585,7 @@ def forward(
         enable_multistream_mla = (self.enable_multistream_mla
                                   and attn_metadata is not None
                                   and not forward_context.with_prefill
-                                  and not attn_metadata.is_prefill)
+                                  and attn_metadata.num_decodes > 0)
         forward_kwargs = {"enable_multistream_mla": enable_multistream_mla}
         if self.q_lora_rank is not None:
             npu_prefetch(self.q_a_proj.weight,
@@ -950,8 +943,11 @@ def forward(
         replace_allreduce: bool = False,
     ) -> torch.Tensor:
         # Self Attention
-        if attn_metadata is not None and not attn_metadata.is_prefill:
-            mla_moe_communication = self.mla_moe_communication and replace_allreduce
+        if attn_metadata is not None:
+            decoding_condition_met = (
+                not attn_metadata.is_prefill if self.use_sfa else
+                attn_metadata.num_decodes > 0 if self.use_mla else False)
+            mla_moe_communication = decoding_condition_met and self.mla_moe_communication and replace_allreduce
         else:
             mla_moe_communication = False

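A hedged reading of the new gating in the hunk above, written out as a standalone helper; use_sfa, use_mla, is_prefill, and num_decodes mirror the attribute names in the diff, while the helper function itself is illustrative only and not part of the change:

def decoding_condition_met(use_sfa: bool, use_mla: bool,
                           is_prefill: bool, num_decodes: int) -> bool:
    # Spells out the chained conditional expression
    # "not is_prefill if use_sfa else num_decodes > 0 if use_mla else False".
    if use_sfa:
        # The SFA path keeps the previous test: any non-prefill batch counts as decode.
        return not is_prefill
    if use_mla:
        # The MLA path now requires at least one decode request in the batch.
        return num_decodes > 0
    # Neither backend: the MLA/MoE communication optimization stays disabled.
    return False


# An MLA batch with no decode requests no longer enables the optimization,
# while the SFA path behaves as it did before.
assert decoding_condition_met(use_sfa=False, use_mla=True,
                              is_prefill=False, num_decodes=0) is False
assert decoding_condition_met(use_sfa=True, use_mla=False,
                              is_prefill=False, num_decodes=0) is True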