Skip to content

Commit b8b3326

Browse files
authored
Split kv_a_proj_with_mqa weights to get ckv and k_pe
Signed-off-by: Mamta Singh <168400541+quic-mamta@users.noreply.github.com>
1 parent a47fff0 commit b8b3326

File tree

1 file changed

+3
-7
lines changed

1 file changed

+3
-7
lines changed

QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py

Lines changed: 3 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -237,9 +237,7 @@ def __qeff_init__(
237237
fusedqk = torch.bmm(per_head_q_up, per_head_k_up)
238238
# self.register_buffer("fusedqk", fusedqk.detach().clone(), persistent=False)
239239
self.fusedqk = torch.nn.Parameter(fusedqk.detach().clone())
240-
241-
# self.kv_a_proj_with_mqa_ckv = nn.Linear(self.hidden_size, self.config.kv_lora_rank, bias=self.config.attention_bias)
242-
# self.kv_a_proj_with_mqa_k_pe = nn.Linear(self.hidden_size, self.config.qk_rope_head_dim, bias=self.config.attention_bias)
240+
self.kv_a_proj_with_mqa_ckv, self.kv_a_proj_with_mqa_k_pe = self.kv_a_proj_with_mqa.weight.T.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
243241

244242
def fused_forward(
245243
self,
@@ -258,10 +256,8 @@ def fused_forward(
258256
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
259257
bsz, q_len, _ = hidden_states.size()
260258

261-
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
262-
# compressed_kv = self.kv_a_proj_with_mqa_ckv(hidden_states)
263-
# k_pe = self.kv_a_proj_with_mqa_k_pe(hidden_states)
264-
compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
259+
compressed_kv = torch.matmul(hidden_states, self.kv_a_proj_with_mqa_ckv)
260+
k_pe = torch.matmul(hidden_states, self.kv_a_proj_with_mqa_k_pe)
265261
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
266262

267263
q_a_proj_out = self.q_a_layernorm(self.q_a_proj(hidden_states))

0 commit comments

Comments
 (0)