@@ -237,9 +237,7 @@ def __qeff_init__(
237237 fusedqk = torch .bmm (per_head_q_up , per_head_k_up )
238238 # self.register_buffer("fusedqk", fusedqk.detach().clone(), persistent=False)
239239 self .fusedqk = torch .nn .Parameter (fusedqk .detach ().clone ())
240-
241- # self.kv_a_proj_with_mqa_ckv = nn.Linear(self.hidden_size, self.config.kv_lora_rank, bias=self.config.attention_bias)
242- # self.kv_a_proj_with_mqa_k_pe = nn.Linear(self.hidden_size, self.config.qk_rope_head_dim, bias=self.config.attention_bias)
240+ self .kv_a_proj_with_mqa_ckv , self .kv_a_proj_with_mqa_k_pe = self .kv_a_proj_with_mqa .weight .T .split ([self .kv_lora_rank , self .qk_rope_head_dim ], dim = - 1 )
243241
244242 def fused_forward (
245243 self ,
@@ -258,10 +256,8 @@ def fused_forward(
258256 ) -> Tuple [torch .Tensor , Optional [torch .Tensor ], Optional [Tuple [torch .Tensor ]]]:
259257 bsz , q_len , _ = hidden_states .size ()
260258
261- compressed_kv = self .kv_a_proj_with_mqa (hidden_states )
262- # compressed_kv = self.kv_a_proj_with_mqa_ckv(hidden_states)
263- # k_pe = self.kv_a_proj_with_mqa_k_pe(hidden_states)
264- compressed_kv , k_pe = torch .split (compressed_kv , [self .kv_lora_rank , self .qk_rope_head_dim ], dim = - 1 )
259+ compressed_kv = torch .matmul (hidden_states , self .kv_a_proj_with_mqa_ckv )
260+ k_pe = torch .matmul (hidden_states , self .kv_a_proj_with_mqa_k_pe )
265261 k_pe = k_pe .view (bsz , q_len , 1 , self .qk_rope_head_dim ).transpose (1 , 2 )
266262
267263 q_a_proj_out = self .q_a_layernorm (self .q_a_proj (hidden_states ))
0 commit comments