
Commit d0db29f

linfeng-yuan and wangxiyuan authored and committed
fix broken deepseek_r1
Signed-off-by: linfeng-yuan <[email protected]>
1 parent 5f15a01 commit d0db29f

File tree

1 file changed: 15 additions, 19 deletions

vllm_ascend/torchair/models/torchair_deepseek_v2.py

Lines changed: 15 additions & 19 deletions
@@ -437,6 +437,7 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        decoder_layer=None,
     ) -> None:
         nn.Module.__init__(self)
         self.hidden_size = hidden_size
@@ -473,50 +474,44 @@ def __init__(
                                              self.q_lora_rank,
                                              bias=False,
                                              quant_config=quant_config,
-                                             prefix=f"{prefix}.q_a_proj",
-                                             return_bias=False)
+                                             prefix=f"{prefix}.q_a_proj")
             self.q_a_layernorm = RMSNorm(self.q_lora_rank,
                                          eps=config.rms_norm_eps)
             self.q_b_proj = ColumnParallelLinear(q_lora_rank,
                                                  self.num_heads *
                                                  self.qk_head_dim,
                                                  bias=False,
                                                  quant_config=quant_config,
-                                                 prefix=f"{prefix}.q_b_proj",
-                                                 return_bias=False)
+                                                 prefix=f"{prefix}.q_b_proj")
         else:
             self.q_proj = ColumnParallelLinear(self.hidden_size,
                                                self.num_heads *
                                                self.qk_head_dim,
                                                bias=False,
                                                quant_config=quant_config,
-                                               prefix=f"{prefix}.q_proj",
-                                               return_bias=False)
+                                               prefix=f"{prefix}.q_proj")
 
         self.kv_a_proj_with_mqa = ReplicatedLinear(
             self.hidden_size,
             self.kv_lora_rank + self.qk_rope_head_dim,
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.kv_a_proj_with_mqa",
-            return_bias=False)
+            prefix=f"{prefix}.kv_a_proj_with_mqa")
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
                                       eps=config.rms_norm_eps)
         self.kv_b_proj = ColumnParallelLinear(
             self.kv_lora_rank,
             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.kv_b_proj",
-            return_bias=False)
+            prefix=f"{prefix}.kv_b_proj")
 
         if oproj_tp_enable():
             self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
                                             self.hidden_size,
                                             bias=False,
                                             quant_config=quant_config,
-                                            prefix=f"{prefix}.o_proj",
-                                            return_bias=False)
+                                            prefix=f"{prefix}.o_proj")
         elif (config.n_routed_experts is not None
               and self.debug_layer_idx >= config.first_k_dense_replace
               and self.debug_layer_idx % config.moe_layer_freq == 0
@@ -527,16 +522,14 @@ def __init__(
                 self.hidden_size,
                 bias=False,
                 quant_config=quant_config,
-                prefix=f"{prefix}.o_proj",
-                return_bias=False)
+                prefix=f"{prefix}.o_proj")
         else:
             self.o_proj = TorchairDeepseekV2RowParallelLinear(
                 self.num_heads * self.v_head_dim,
                 self.hidden_size,
                 bias=False,
                 quant_config=quant_config,
-                prefix=f"{prefix}.o_proj",
-                return_bias=False)
+                prefix=f"{prefix}.o_proj")
 
         if rope_scaling:
             rope_scaling["rope_type"] = 'deepseek_yarn'
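
Note on the return_bias removals above: vLLM's parallel linear layers conventionally return an (output, bias) tuple from forward(), and the return_bias constructor kwarg is presumably not accepted by the vLLM build this plugin targets, so the portable form drops the kwarg and unpacks the tuple at call sites. A toy, self-contained sketch of that caller-side convention; TupleLinear is a hypothetical stand-in, not vLLM's actual class:

    import torch
    import torch.nn as nn

    class TupleLinear(nn.Module):
        """Toy stand-in for a vLLM-style parallel linear layer whose
        forward() always returns an (output, bias) tuple."""

        def __init__(self, in_features: int, out_features: int):
            super().__init__()
            self.weight = nn.Parameter(torch.empty(out_features, in_features))
            nn.init.xavier_uniform_(self.weight)

        def forward(self, x: torch.Tensor):
            # bias=False in the layers above, so the bias slot is None.
            return torch.nn.functional.linear(x, self.weight), None

    layer = TupleLinear(16, 32)
    out, _ = layer(torch.randn(2, 16))  # unpack; ignore the absent bias
    assert out.shape == (2, 32)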
@@ -592,7 +585,7 @@ def forward(
592585
enable_multistream_mla = (self.enable_multistream_mla
593586
and attn_metadata is not None
594587
and not forward_context.with_prefill
595-
and not attn_metadata.is_prefill)
588+
and attn_metadata.num_decodes > 0)
596589
forward_kwargs = {"enable_multistream_mla": enable_multistream_mla}
597590
if self.q_lora_rank is not None:
598591
npu_prefetch(self.q_a_proj.weight,
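
On the hunk above: the multistream-MLA gate now requires attn_metadata.num_decodes > 0 instead of not attn_metadata.is_prefill, i.e. it keys off an explicit decode count in the torchair attention metadata. A minimal sketch of the resulting behavior; FakeAttnMetadata and multistream_mla_enabled are hypothetical names, not the plugin's API:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class FakeAttnMetadata:
        num_decodes: int
        is_prefill: bool

    def multistream_mla_enabled(enabled: bool,
                                attn_metadata: Optional[FakeAttnMetadata],
                                with_prefill: bool) -> bool:
        # Mirrors the rewritten condition: require at least one decode
        # request rather than testing the batch-level is_prefill flag.
        return (enabled and attn_metadata is not None
                and not with_prefill and attn_metadata.num_decodes > 0)

    assert multistream_mla_enabled(True, FakeAttnMetadata(4, False), False)
    assert not multistream_mla_enabled(True, FakeAttnMetadata(0, False), False)
    assert not multistream_mla_enabled(True, None, False)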
@@ -950,8 +943,11 @@ def forward(
         replace_allreduce: bool = False,
     ) -> torch.Tensor:
         # Self Attention
-        if attn_metadata is not None and not attn_metadata.is_prefill:
-            mla_moe_communication = self.mla_moe_communication and replace_allreduce
+        if attn_metadata is not None:
+            decoding_condition_met = (
+                not attn_metadata.is_prefill if self.use_sfa else
+                attn_metadata.num_decodes > 0 if self.use_mla else False)
+            mla_moe_communication = decoding_condition_met and self.mla_moe_communication and replace_allreduce
         else:
             mla_moe_communication = False
 
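The decoding_condition_met expression in the final hunk is a chained conditional, which groups to the right: SFA metadata is tested via is_prefill, MLA metadata via num_decodes, and anything else disables the path. An equivalent explicit form, as a standalone function for illustration only:

    def decoding_condition(use_sfa: bool, use_mla: bool,
                           is_prefill: bool, num_decodes: int) -> bool:
        # Same truth table as the one-liner in the diff:
        #   not is_prefill if use_sfa else \
        #       num_decodes > 0 if use_mla else False
        if use_sfa:
            return not is_prefill
        if use_mla:
            return num_decodes > 0
        return False

    assert decoding_condition(True, False, is_prefill=False, num_decodes=0)
    assert decoding_condition(False, True, is_prefill=True, num_decodes=2)
    assert not decoding_condition(False, False, is_prefill=False, num_decodes=2)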