Skip to content

Commit c16a0c7

Browse files
quic-mamta and mamtsing
authored and committed
Split kv_a_proj_with_mqa weights to get ckv and k_pe
Signed-off-by: Mamta Singh <mamtsing@qti.qualcomm.com>
1 parent a47fff0 commit c16a0c7

File tree

1 file changed

+5
-7
lines changed

1 file changed

+5
-7
lines changed

QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -237,9 +237,9 @@ def __qeff_init__(
237237
fusedqk = torch.bmm(per_head_q_up, per_head_k_up)
238238
# self.register_buffer("fusedqk", fusedqk.detach().clone(), persistent=False)
239239
self.fusedqk = torch.nn.Parameter(fusedqk.detach().clone())
240-
241-
# self.kv_a_proj_with_mqa_ckv = nn.Linear(self.hidden_size, self.config.kv_lora_rank, bias=self.config.attention_bias)
242-
# self.kv_a_proj_with_mqa_k_pe = nn.Linear(self.hidden_size, self.config.qk_rope_head_dim, bias=self.config.attention_bias)
240+
kv_a_proj_with_mqa_ckv, kv_a_proj_with_mqa_k_pe = self.kv_a_proj_with_mqa.weight.T.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
241+
self.kv_a_proj_with_mqa_ckv = torch.nn.Parameter(kv_a_proj_with_mqa_ckv.detach().clone())
242+
self.kv_a_proj_with_mqa_k_pe = torch.nn.Parameter(kv_a_proj_with_mqa_k_pe.detach().clone())
243243

244244
def fused_forward(
245245
self,
@@ -258,10 +258,8 @@ def fused_forward(
258258
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
259259
bsz, q_len, _ = hidden_states.size()
260260

261-
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
262-
# compressed_kv = self.kv_a_proj_with_mqa_ckv(hidden_states)
263-
# k_pe = self.kv_a_proj_with_mqa_k_pe(hidden_states)
264-
compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
261+
compressed_kv = torch.matmul(hidden_states, self.kv_a_proj_with_mqa_ckv)
262+
k_pe = torch.matmul(hidden_states, self.kv_a_proj_with_mqa_k_pe)
265263
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
266264

267265
q_a_proj_out = self.q_a_layernorm(self.q_a_proj(hidden_states))

0 commit comments

Comments (0)