@@ -214,26 +214,23 @@ def __qeff_init__(
214214 - 1 , self .num_heads , self .qk_nope_head_dim + self .qk_rope_head_dim
215215 ).split ([self .qk_nope_head_dim , self .qk_rope_head_dim ], dim = - 1 )
216216 q_up = q_up .reshape (- 1 , self .num_heads * self .qk_nope_head_dim ).unsqueeze (0 )
217- # self.register_buffer("q_up", q_up.detach().clone(), persistent=False)
217+
218218 self .q_up = torch .nn .Parameter (q_up .detach ().clone ())
219219 q_rope = q_rope .reshape (- 1 , self .num_heads * self .qk_rope_head_dim ).unsqueeze (0 )
220- # self.register_buffer("q_rope", q_rope.detach().clone(), persistent=False)
220+
221221 self .q_rope = torch .nn .Parameter (q_rope .detach ().clone ())
222222 k_up , v_up = self .kv_b_proj .weight .T .view (- 1 , self .num_heads , self .qk_nope_head_dim + self .v_head_dim ).split (
223223 [self .qk_nope_head_dim , self .v_head_dim ], dim = - 1
224224 )
225225 k_up = k_up .reshape (- 1 , self .num_heads * self .qk_nope_head_dim ).unsqueeze (0 )
226226 v_up = v_up .reshape (- 1 , self .num_heads * self .v_head_dim ).unsqueeze (0 )
227- # self.register_buffer("k_up", k_up.detach().clone(), persistent=False)
228- # self.register_buffer("v_up", v_up.detach().clone(), persistent=False)
227+
229228 self .k_up = torch .nn .Parameter (k_up .detach ().clone ())
230229 self .v_up = torch .nn .Parameter (v_up .detach ().clone ())
231230 per_head_q_up = self .q_up .squeeze (0 ).view (- 1 , self .num_heads , self .qk_nope_head_dim ).transpose (0 , 1 )
232231 per_head_k_up = (
233232 self .k_up .squeeze (0 ).view (- 1 , self .num_heads , self .qk_nope_head_dim ).transpose (0 , 1 ).transpose (1 , 2 )
234233 )
235- # self.register_buffer("per_head_q_up", per_head_q_up.detach().clone(), persistent=False)
236- # self.register_buffer("per_head_k_up", per_head_k_up.detach().clone(), persistent=False)
237234 self .per_head_q_up = torch .nn .Parameter (per_head_q_up .detach ().clone ())
238235 self .per_head_k_up = torch .nn .Parameter (per_head_k_up .detach ().clone ())
239236
@@ -243,12 +240,7 @@ def __qeff_init__(
243240 out = torch .cat ((out ,x ), 0 )
244241 fusedqk = out .reshape (self .num_heads , - 1 , self .kv_lora_rank )
245242
246- #fusedqk = torch.bmm(per_head_q_up, per_head_k_up)
247- # self.register_buffer("fusedqk", fusedqk.detach().clone(), persistent=False)
248243 self .fusedqk = torch .nn .Parameter (fusedqk .detach ().clone ())
249- kv_a_proj_with_mqa_ckv , kv_a_proj_with_mqa_k_pe = self .kv_a_proj_with_mqa .weight .T .split ([self .kv_lora_rank , self .qk_rope_head_dim ], dim = - 1 )
250- self .kv_a_proj_with_mqa_ckv = torch .nn .Parameter (kv_a_proj_with_mqa_ckv .detach ().clone ())
251- self .kv_a_proj_with_mqa_k_pe = torch .nn .Parameter (kv_a_proj_with_mqa_k_pe .detach ().clone ())
252244
253245 def fused_forward (
254246 self ,
@@ -267,69 +259,79 @@ def fused_forward(
267259 ) -> Tuple [torch .Tensor , Optional [torch .Tensor ], Optional [Tuple [torch .Tensor ]]]:
268260 bsz , q_len , _ = hidden_states .size ()
269261
270- compressed_kv = torch . matmul (hidden_states , self . kv_a_proj_with_mqa_ckv )
271- k_pe = torch . matmul ( hidden_states , self .kv_a_proj_with_mqa_k_pe )
272- k_pe = k_pe . view ( bsz , q_len , 1 , self .qk_rope_head_dim ). transpose ( 1 , 2 )
262+ compressed_kv = self . kv_a_proj_with_mqa (hidden_states )
263+ compressed_kv = compressed_kv . view ( bsz , q_len , - 1 , self .kv_lora_rank + self . qk_rope_head_dim ). transpose ( 1 , 2 )
264+ compressed_kv , k_pe = compressed_kv . split ([ self . kv_lora_rank , self .qk_rope_head_dim ], dim = - 1 )
273265
274266 q_a_proj_out = self .q_a_layernorm (self .q_a_proj (hidden_states ))
275267 q_pe = torch .bmm (q_a_proj_out , self .q_rope )
276268 q_pe = q_pe .view (bsz , q_len , self .num_heads , self .qk_rope_head_dim ).transpose (1 , 2 )
277269 q_nope = torch .bmm (q_a_proj_out , self .q_up )
278270 q_nope = q_nope .view (bsz , q_len , self .num_heads , self .qk_nope_head_dim ).transpose (1 , 2 )
279271
272+ compressed_kv = self .kv_a_layernorm (compressed_kv )
280273 cache_kwargs = {"position_ids" : position_ids , "batch_index" : batch_index }
281274 if compressed_kvs is not None :
282275 compressed_kv = compressed_kvs .update_ckv (compressed_kv , self .layer_idx , cache_kwargs )
283276
284- kva = self .kv_a_layernorm (compressed_kv )
285- k_nope = torch .bmm (kva , self .k_up )
286- k_nope = k_nope .view (bsz , - 1 , self .num_heads , self .qk_nope_head_dim ).transpose (1 , 2 )
287- value_states = torch .bmm (kva , self .v_up )
288- value_states = value_states .view (bsz , - 1 , self .num_heads , self .qk_nope_head_dim ).transpose (1 , 2 )
289-
290- cos , sin = self .rotary_emb (value_states , seq_len = 32 * 1024 )
291- q_pe , k_pe = orig_apply_rotary_pos_emb (q_pe , k_pe , cos , sin , position_ids )
292-
293- if compressed_kvs is not None :
294- k_pe = compressed_kvs .update_k_pe (k_pe , self .layer_idx , cache_kwargs )
277+ kva = compressed_kv
295278
296279 if mla_absorption is not None :
297280 enable_absorption = mla_absorption .get ("enable" , False )
298281 absorb_online = mla_absorption .get ("online" , False )
299282 else :
300283 enable_absorption = False
301284
285+ n_head_ckv = compressed_kv .shape [1 ]
286+ p = self .num_heads // n_head_ckv
287+
288+ value_out = []
289+ for i in range (n_head_ckv ):
290+ value_states_ph = torch .matmul (kva [:,i ,:,:], self .v_up [:, :, i * p * self .v_head_dim : (i + 1 )* p * self .v_head_dim ])
291+ value_states_ph = value_states_ph .view (bsz , - 1 , p , self .qk_nope_head_dim ).transpose (1 , 2 )
292+ value_out .append (value_states_ph )
293+ value_states = torch .cat (value_out , dim = 1 )
294+
295+ cos , sin = self .rotary_emb (value_states_ph , seq_len = 32 * 1024 )
296+ q_pe , k_pe = orig_apply_rotary_pos_emb (q_pe , k_pe , cos , sin , position_ids )
297+
298+ if compressed_kvs is not None :
299+ k_pe = compressed_kvs .update_k_pe (k_pe , self .layer_idx , cache_kwargs )
300+
302301 x = []
303- for i in range (self .num_heads ):
304- if enable_absorption :
305- if absorb_online :
306- if i == 0 :
307- print ("online absorption" )
308- out = torch .matmul (self .per_head_q_up [i ,:,:], self .per_head_k_up [i ,:,:])
309- out = out .reshape (1 , - 1 , self .kv_lora_rank )
310- out2 = torch .matmul (q_a_proj_out .unsqueeze (1 ), out )
302+ for k in range (n_head_ckv ):
303+ k_nope = torch .matmul (kva [:,k ,:,:], self .k_up [:, :, k * p * self .qk_nope_head_dim : (k + 1 )* p * self .qk_nope_head_dim ])
304+ k_nope = k_nope .view (bsz , - 1 , p , self .qk_nope_head_dim ).transpose (1 , 2 )
305+
306+ for i in range (k * p , (k + 1 )* p ):
307+ if enable_absorption :
308+ if absorb_online :
309+ if i == 0 :
310+ print ("online absorption" )
311+ out = torch .matmul (self .per_head_q_up [i ,:,:], self .per_head_k_up [i ,:,:])
312+ out = out .reshape (1 , - 1 , self .kv_lora_rank )
313+ out2 = torch .matmul (q_a_proj_out .unsqueeze (1 ), out )
314+ else :
315+ if i == 0 :
316+ print ("using fused qk" )
317+ out2 = torch .matmul (q_a_proj_out .unsqueeze (1 ), self .fusedqk [i ,:,:])
318+
319+ out3 = torch .cat ((out2 , q_pe [:,i ,:,:].unsqueeze (1 )), - 1 )
320+ kva_kpe = torch .cat ((kva [:,k ,:,:],k_pe [:,k ,:,:]), - 1 ).unsqueeze (1 )
321+ attn_weights = torch .matmul (out3 , kva_kpe .transpose (2 ,3 )) * self .softmax_scale
311322 else :
312323 if i == 0 :
313- print ("using fused qk" )
314- out2 = torch .matmul (q_a_proj_out .unsqueeze (1 ), self .fusedqk [i ,:,:])
315-
316- out3 = torch .cat ((out2 , q_pe [:,i ,:,:].unsqueeze (1 )), - 1 )
317- kva_kpe = torch .cat ((kva ,k_pe .squeeze (1 )), - 1 )
318- attn_weights = torch .matmul (out3 , kva_kpe .transpose (1 , 2 ).unsqueeze (1 )) * self .softmax_scale
319- else :
320- if i == 0 :
321- print ("no absorption" )
322- query_states = torch .cat ((q_nope [:,i ,:,:], q_pe [:,i ,:,:]), - 1 )
323- key_states = torch .cat ((k_nope [:,i ,:,:].unsqueeze (1 ), k_pe ), - 1 )
324- attn_weights = torch .matmul (query_states , key_states .transpose (2 , 3 )) * self .softmax_scale
324+ print ("no absorption" )
325+ query_states = torch .cat ((q_nope [:,i ,:,:], q_pe [:,i ,:,:]), - 1 ).unsqueeze (1 )
326+ key_states = torch .cat ((k_nope [:,i % p ,:,:], k_pe [:,k ,:,:]), - 1 ).unsqueeze (1 )
327+ attn_weights = torch .matmul (query_states , key_states .transpose (2 , 3 )) * self .softmax_scale
325328
326- if attention_mask is not None : # no matter the length, we just slice it
327- attn_weights = torch .where (attention_mask , torch .tensor (MIN_MASKED_ATTENTION_VALUE , dtype = torch .float32 ), attn_weights )
329+ if attention_mask is not None : # no matter the length, we just slice it
330+ attn_weights = torch .where (attention_mask , torch .tensor (MIN_MASKED_ATTENTION_VALUE , dtype = torch .float32 ), attn_weights )
328331
329- attn_weights = F .softmax (attn_weights , dim = - 1 , dtype = torch .float32 ).to (q_pe .dtype )
330- attn_output = torch .matmul (attn_weights , value_states [:,i ,:,:])
331-
332- x .append (attn_output )
332+ attn_weights = F .softmax (attn_weights , dim = - 1 , dtype = torch .float32 ).to (q_pe .dtype )
333+ attn_output = torch .matmul (attn_weights , value_states [:,i ,:,:])
334+ x .append (attn_output )
333335
334336 attn_output = torch .cat (x , dim = 1 )
335337
@@ -455,23 +457,6 @@ def forward(self, hidden_states):
455457 hidden_states = hidden_states + self .shared_experts (residuals )
456458 return hidden_states
457459
458- # def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
459- # final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
460- # expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts))
461- # expert_mask = expert_mask.permute(2, 0, 1)
462-
463- # for expert_idx in range(len(self.experts)):
464- # expert = self.experts[expert_idx]
465- # mask = expert_mask[expert_idx]
466- # expert_output = expert(hidden_states) * (((topk_weights * mask).sum(1))[:, None])
467- # expert_output = torch.where(
468- # (topk_weights * mask).sum(1).to(torch.bool)[:, None],
469- # expert_output,
470- # torch.tensor(0.0),
471- # )
472- # final_hidden_states = final_hidden_states + expert_output
473- # return final_hidden_states.type(hidden_states.dtype)
474-
475460
476461class QEffPrefillOnlyDeepseekV3MoE (nn .Module ):
477462 def __qeff_init__ (
0 commit comments