@@ -237,9 +237,9 @@ def __qeff_init__(
237237 fusedqk = torch .bmm (per_head_q_up , per_head_k_up )
238238 # self.register_buffer("fusedqk", fusedqk.detach().clone(), persistent=False)
239239 self .fusedqk = torch .nn .Parameter (fusedqk .detach ().clone ())
240-
241- # self.kv_a_proj_with_mqa_ckv = nn.Linear(self.hidden_size, self.config.kv_lora_rank, bias=self.config.attention_bias )
242- # self.kv_a_proj_with_mqa_k_pe = nn.Linear(self.hidden_size, self.config.qk_rope_head_dim, bias=self.config.attention_bias )
240+ kv_a_proj_with_mqa_ckv , kv_a_proj_with_mqa_k_pe = self . kv_a_proj_with_mqa . weight . T . split ([ self . kv_lora_rank , self . qk_rope_head_dim ], dim = - 1 )
241+ self .kv_a_proj_with_mqa_ckv = torch . nn .Parameter ( kv_a_proj_with_mqa_ckv . detach (). clone () )
242+ self .kv_a_proj_with_mqa_k_pe = torch . nn .Parameter ( kv_a_proj_with_mqa_k_pe . detach (). clone () )
243243
244244 def fused_forward (
245245 self ,
@@ -258,10 +258,8 @@ def fused_forward(
258258 ) -> Tuple [torch .Tensor , Optional [torch .Tensor ], Optional [Tuple [torch .Tensor ]]]:
259259 bsz , q_len , _ = hidden_states .size ()
260260
261- compressed_kv = self .kv_a_proj_with_mqa (hidden_states )
262- # compressed_kv = self.kv_a_proj_with_mqa_ckv(hidden_states)
263- # k_pe = self.kv_a_proj_with_mqa_k_pe(hidden_states)
264- compressed_kv , k_pe = torch .split (compressed_kv , [self .kv_lora_rank , self .qk_rope_head_dim ], dim = - 1 )
261+ compressed_kv = torch .matmul (hidden_states , self .kv_a_proj_with_mqa_ckv )
262+ k_pe = torch .matmul (hidden_states , self .kv_a_proj_with_mqa_k_pe )
265263 k_pe = k_pe .view (bsz , q_len , 1 , self .qk_rope_head_dim ).transpose (1 , 2 )
266264
267265 q_a_proj_out = self .q_a_layernorm (self .q_a_proj (hidden_states ))
0 commit comments