Commit da4797f

Add replicatekvhead transform

Signed-off-by: Mamta Singh <mamtsing@qti.qualcomm.com>
1 parent 0662e58 commit da4797f

5 files changed, +164 -90 lines changed

QEfficient/transformers/cache_utils.py

Lines changed: 3 additions & 3 deletions
@@ -302,10 +302,10 @@ def update_ckv(self, compressed_kv, cache_kwargs):
         position_ids = cache_kwargs.get("position_ids")
         batch_index = cache_kwargs.get("batch_index", None) # TODO: add support later

-        self.ckv = CtxScatterFunc3D.apply(self.ckv, position_ids, compressed_kv)
+        self.ckv = CtxScatterFunc.apply(self.ckv, position_ids, compressed_kv)

         ckv_out = self.ckv
-        ctx_len = ckv_out.shape[1]
+        ctx_len = ckv_out.shape[-2]
         ctx_indices = torch.arange(ctx_len)[None, ...]
         gather_limit = position_ids.max(1, keepdim=True).values
         invalid_mask = ctx_indices > gather_limit
@@ -315,7 +315,7 @@ def update_ckv(self, compressed_kv, cache_kwargs):
         invalid_idx_value = 0
         ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)

-        ckv_out = CtxGatherFunc3D.apply(ckv_out, ctx_indices)
+        ckv_out = CtxGatherFunc.apply(ckv_out, ctx_indices, ctx_len)
         ckv_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), ckv_out)
         return ckv_out

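Note: the switch from the 3D scatter/gather ops to the head-aware variants reflects the new cache layout. With replicated KV heads the compressed-KV cache is stored as (batch, num_kv_heads, ctx_len, kv_lora_rank) instead of (batch, ctx_len, kv_lora_rank), which is also why ctx_len is now read from shape[-2]. A minimal sketch of the equivalent indexing with plain torch ops follows; the helper names are illustrative assumptions, the committed code uses QEfficient's CtxScatterFunc / CtxGatherFunc.

import torch

# Illustrative only: mimics the indexing of the 4D compressed-KV cache
# (batch, num_kv_heads, ctx_len, kv_lora_rank) assumed by this change.
def scatter_ckv_sketch(cache, position_ids, new_kv):
    # cache: (bs, n_kv, ctx_len, d); new_kv: (bs, n_kv, q_len, d); position_ids: (bs, q_len)
    idx = position_ids[:, None, :, None].expand(-1, cache.shape[1], -1, cache.shape[-1])
    return cache.scatter(2, idx, new_kv)

def gather_ckv_sketch(cache, ctx_indices):
    # ctx_indices: (1, ctx_len); read positions back, broadcast over batch and heads
    idx = ctx_indices[:, None, :, None].expand(cache.shape[0], cache.shape[1], -1, cache.shape[-1])
    return cache.gather(2, idx)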
QEfficient/transformers/models/deepseek_v3/modeling_deepseek_qeff.py

Lines changed: 53 additions & 68 deletions
@@ -214,26 +214,23 @@ def __qeff_init__(
             -1, self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim
         ).split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
         q_up = q_up.reshape(-1, self.num_heads * self.qk_nope_head_dim).unsqueeze(0)
-        # self.register_buffer("q_up", q_up.detach().clone(), persistent=False)
+
         self.q_up = torch.nn.Parameter(q_up.detach().clone())
         q_rope = q_rope.reshape(-1, self.num_heads * self.qk_rope_head_dim).unsqueeze(0)
-        # self.register_buffer("q_rope", q_rope.detach().clone(), persistent=False)
+
         self.q_rope = torch.nn.Parameter(q_rope.detach().clone())
         k_up, v_up = self.kv_b_proj.weight.T.view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split(
             [self.qk_nope_head_dim, self.v_head_dim], dim=-1
         )
         k_up = k_up.reshape(-1, self.num_heads * self.qk_nope_head_dim).unsqueeze(0)
         v_up = v_up.reshape(-1, self.num_heads * self.v_head_dim).unsqueeze(0)
-        # self.register_buffer("k_up", k_up.detach().clone(), persistent=False)
-        # self.register_buffer("v_up", v_up.detach().clone(), persistent=False)
+
         self.k_up = torch.nn.Parameter(k_up.detach().clone())
         self.v_up = torch.nn.Parameter(v_up.detach().clone())
         per_head_q_up = self.q_up.squeeze(0).view(-1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1)
         per_head_k_up = (
             self.k_up.squeeze(0).view(-1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1).transpose(1, 2)
         )
-        # self.register_buffer("per_head_q_up", per_head_q_up.detach().clone(), persistent=False)
-        # self.register_buffer("per_head_k_up", per_head_k_up.detach().clone(), persistent=False)
         self.per_head_q_up = torch.nn.Parameter(per_head_q_up.detach().clone())
         self.per_head_k_up = torch.nn.Parameter(per_head_k_up.detach().clone())

@@ -243,12 +240,7 @@ def __qeff_init__(
             out = torch.cat((out,x), 0)
         fusedqk = out.reshape(self.num_heads, -1, self.kv_lora_rank)

-        #fusedqk = torch.bmm(per_head_q_up, per_head_k_up)
-        # self.register_buffer("fusedqk", fusedqk.detach().clone(), persistent=False)
         self.fusedqk = torch.nn.Parameter(fusedqk.detach().clone())
-        kv_a_proj_with_mqa_ckv, kv_a_proj_with_mqa_k_pe = self.kv_a_proj_with_mqa.weight.T.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-        self.kv_a_proj_with_mqa_ckv = torch.nn.Parameter(kv_a_proj_with_mqa_ckv.detach().clone())
-        self.kv_a_proj_with_mqa_k_pe = torch.nn.Parameter(kv_a_proj_with_mqa_k_pe.detach().clone())

     def fused_forward(
         self,
@@ -267,69 +259,79 @@ def fused_forward(
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()

-        compressed_kv = torch.matmul(hidden_states, self.kv_a_proj_with_mqa_ckv)
-        k_pe = torch.matmul(hidden_states, self.kv_a_proj_with_mqa_k_pe)
-        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
+        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
+        compressed_kv = compressed_kv.view(bsz, q_len, -1, self.kv_lora_rank+self.qk_rope_head_dim).transpose(1, 2)
+        compressed_kv, k_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)

         q_a_proj_out = self.q_a_layernorm(self.q_a_proj(hidden_states))
         q_pe = torch.bmm(q_a_proj_out, self.q_rope)
         q_pe = q_pe.view(bsz, q_len, self.num_heads, self.qk_rope_head_dim).transpose(1, 2)
         q_nope = torch.bmm(q_a_proj_out, self.q_up)
         q_nope = q_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2)

+        compressed_kv = self.kv_a_layernorm(compressed_kv)
         cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index}
         if compressed_kvs is not None:
             compressed_kv = compressed_kvs.update_ckv(compressed_kv, self.layer_idx, cache_kwargs)

-        kva = self.kv_a_layernorm(compressed_kv)
-        k_nope = torch.bmm(kva, self.k_up)
-        k_nope = k_nope.view(bsz, -1, self.num_heads, self.qk_nope_head_dim).transpose(1, 2)
-        value_states = torch.bmm(kva, self.v_up)
-        value_states = value_states.view(bsz, -1, self.num_heads, self.qk_nope_head_dim).transpose(1, 2)
-
-        cos, sin = self.rotary_emb(value_states, seq_len=32 * 1024)
-        q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
-
-        if compressed_kvs is not None:
-            k_pe = compressed_kvs.update_k_pe(k_pe, self.layer_idx, cache_kwargs)
+        kva = compressed_kv

         if mla_absorption is not None:
             enable_absorption = mla_absorption.get("enable", False)
             absorb_online = mla_absorption.get("online", False)
         else:
             enable_absorption = False

+        n_head_ckv = compressed_kv.shape[1]
+        p = self.num_heads//n_head_ckv
+
+        value_out = []
+        for i in range(n_head_ckv):
+            value_states_ph = torch.matmul(kva[:,i,:,:], self.v_up[:, :, i*p*self.v_head_dim: (i+1)*p*self.v_head_dim])
+            value_states_ph = value_states_ph.view(bsz, -1, p, self.qk_nope_head_dim).transpose(1, 2)
+            value_out.append(value_states_ph)
+        value_states = torch.cat(value_out, dim=1)
+
+        cos, sin = self.rotary_emb(value_states_ph, seq_len=32 * 1024)
+        q_pe, k_pe = orig_apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
+
+        if compressed_kvs is not None:
+            k_pe = compressed_kvs.update_k_pe(k_pe, self.layer_idx, cache_kwargs)
+
         x = []
-        for i in range(self.num_heads):
-            if enable_absorption:
-                if absorb_online:
-                    if i==0:
-                        print("online absorption")
-                    out = torch.matmul(self.per_head_q_up[i,:,:], self.per_head_k_up[i,:,:])
-                    out = out.reshape(1, -1, self.kv_lora_rank)
-                    out2 = torch.matmul(q_a_proj_out.unsqueeze(1), out)
+        for k in range(n_head_ckv):
+            k_nope = torch.matmul(kva[:,k,:,:], self.k_up[:, :, k*p*self.qk_nope_head_dim: (k+1)*p*self.qk_nope_head_dim])
+            k_nope = k_nope.view(bsz, -1, p, self.qk_nope_head_dim).transpose(1, 2)
+
+            for i in range(k*p, (k+1)*p):
+                if enable_absorption:
+                    if absorb_online:
+                        if i==0:
+                            print("online absorption")
+                        out = torch.matmul(self.per_head_q_up[i,:,:], self.per_head_k_up[i,:,:])
+                        out = out.reshape(1, -1, self.kv_lora_rank)
+                        out2 = torch.matmul(q_a_proj_out.unsqueeze(1), out)
+                    else:
+                        if i==0:
+                            print("using fused qk")
+                        out2 = torch.matmul(q_a_proj_out.unsqueeze(1), self.fusedqk[i,:,:])
+
+                    out3 = torch.cat((out2, q_pe[:,i,:,:].unsqueeze(1)), -1)
+                    kva_kpe = torch.cat((kva[:,k,:,:],k_pe[:,k,:,:]), -1).unsqueeze(1)
+                    attn_weights = torch.matmul(out3, kva_kpe.transpose(2,3)) * self.softmax_scale
                 else:
                     if i==0:
-                        print("using fused qk")
-                    out2 = torch.matmul(q_a_proj_out.unsqueeze(1), self.fusedqk[i,:,:])
-
-                out3 = torch.cat((out2, q_pe[:,i,:,:].unsqueeze(1)), -1)
-                kva_kpe = torch.cat((kva,k_pe.squeeze(1)), -1)
-                attn_weights = torch.matmul(out3, kva_kpe.transpose(1, 2).unsqueeze(1)) * self.softmax_scale
-            else:
-                if i==0:
-                    print("no absorption")
-                query_states = torch.cat((q_nope[:,i,:,:], q_pe[:,i,:,:]), -1)
-                key_states = torch.cat((k_nope[:,i,:,:].unsqueeze(1), k_pe), -1)
-                attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale
+                        print("no absorption")
+                    query_states = torch.cat((q_nope[:,i,:,:], q_pe[:,i,:,:]), -1).unsqueeze(1)
+                    key_states = torch.cat((k_nope[:,i%p,:,:], k_pe[:,k,:,:]), -1).unsqueeze(1)
+                    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale

-            if attention_mask is not None: # no matter the length, we just slice it
-                attn_weights = torch.where(attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights)
+                if attention_mask is not None: # no matter the length, we just slice it
+                    attn_weights = torch.where(attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights)

-            attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_pe.dtype)
-            attn_output = torch.matmul(attn_weights, value_states[:,i,:,:])
-
-            x.append(attn_output)
+                attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_pe.dtype)
+                attn_output = torch.matmul(attn_weights, value_states[:,i,:,:])
+                x.append(attn_output)

         attn_output = torch.cat(x, dim=1)

@@ -455,23 +457,6 @@ def forward(self, hidden_states):
         hidden_states = hidden_states + self.shared_experts(residuals)
         return hidden_states

-    # def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
-    #     final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
-    #     expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts))
-    #     expert_mask = expert_mask.permute(2, 0, 1)
-
-    #     for expert_idx in range(len(self.experts)):
-    #         expert = self.experts[expert_idx]
-    #         mask = expert_mask[expert_idx]
-    #         expert_output = expert(hidden_states) * (((topk_weights * mask).sum(1))[:, None])
-    #         expert_output = torch.where(
-    #             (topk_weights * mask).sum(1).to(torch.bool)[:, None],
-    #             expert_output,
-    #             torch.tensor(0.0),
-    #         )
-    #         final_hidden_states = final_hidden_states + expert_output
-    #     return final_hidden_states.type(hidden_states.dtype)
-

 class QEffPrefillOnlyDeepseekV3MoE(nn.Module):
     def __qeff_init__(

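The rewritten fused_forward above walks the num_heads query heads in groups of p = num_heads // n_head_ckv, where n_head_ckv is the number of replicated compressed-KV heads, so query head i attends against keys and values re-expanded from KV head k = i // p. A standalone sketch of that grouping and of the per-group value re-expansion, under assumed shapes (this is not the committed code):

import torch

bsz, ctx_len = 1, 32
num_heads, n_head_ckv = 16, 4               # assumed: 4 replicated compressed-KV heads
kv_lora_rank, v_head_dim = 512, 128
p = num_heads // n_head_ckv                 # query heads served by each KV head

kva = torch.randn(bsz, n_head_ckv, ctx_len, kv_lora_rank)      # cached, normalized compressed KV
v_up = torch.randn(1, kv_lora_rank, num_heads * v_head_dim)    # value up-projection, as in self.v_up

value_out = []
for k in range(n_head_ckv):
    # group k re-expands values only from KV head k, using the v_up columns
    # that belong to query heads k*p .. (k+1)*p - 1
    v_slice = v_up[:, :, k * p * v_head_dim:(k + 1) * p * v_head_dim]
    v_k = torch.matmul(kva[:, k, :, :], v_slice)                # (bsz, ctx_len, p * v_head_dim)
    value_out.append(v_k.view(bsz, ctx_len, p, v_head_dim).transpose(1, 2))

value_states = torch.cat(value_out, dim=1)                      # (bsz, num_heads, ctx_len, v_head_dim)
assert value_states.shape == (bsz, num_heads, ctx_len, v_head_dim)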
QEfficient/transformers/models/modeling_auto.py

Lines changed: 9 additions & 3 deletions
@@ -51,6 +51,7 @@
     PrefillOnlyExternalModuleMapperTransform,
     PrefillOnlyChunkedTransform,
     PrefillOnlyTransform,
+    ReplicateKVHeadTransform,
     RevertPrefillKeepAttentionTransform,
     RevertPrefillOnlyTransform,
     RevertPrefillOnlyExternalModuleMapperTransform,
@@ -2410,6 +2411,11 @@ def __init__(
         self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None
         self.hash_params["max_seq_len_cached"] = max_seq_len_cached

+        if self.model.config.model_type in {"kimi_k2"}:
+            self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs)
+            if replicate_kv_transformed:
+                self.hash_params["config"] = model.config.to_diff_dict()
+
         # ---Sampling---
         # Note: SamplerTransform should be applied after all other transforms
         # are done. The role of the sampler is to just add nodes at the output of the
@@ -2746,11 +2752,11 @@ def export(
             output_names = [v for v in output_names if "past" not in v]
             example_inputs["compressed_kvs"] = [[] for _ in range(self.num_layers)]
             for i in range(self.num_layers):
-                ckv = torch.zeros((bs, seq_len, self.model.config.kv_lora_rank), dtype=torch.float32)
-                k_pe = torch.zeros((bs, 1, seq_len, self.model.config.qk_rope_head_dim), dtype=torch.float32)
+                ckv = torch.zeros((bs, 4, seq_len, self.model.config.kv_lora_rank), dtype=torch.float32)
+                k_pe = torch.zeros((bs, 4, seq_len, self.model.config.qk_rope_head_dim), dtype=torch.float32)
                 example_inputs["compressed_kvs"][i].append(ckv)
                 example_inputs["compressed_kvs"][i].append(k_pe)
-                dynamic_axes[f"compressed_kv.{i}"] = {0: "batch_size", 1: "ctx_len"}
+                dynamic_axes[f"compressed_kv.{i}"] = {0: "batch_size", 2: "ctx_len"}
                 dynamic_axes[f"k_pe.{i}"] = {0: "batch_size", 2: "ctx_len"}
                 output_names.append(f"compressed_kv.{i}_RetainedState")
                 output_names.append(f"k_pe.{i}_RetainedState")

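The export change above gives the per-layer retained-state inputs an explicit KV-head axis (hard-coded to 4 in this commit, presumably the replication factor), which is why the "ctx_len" dynamic axis moves from dim 1 to dim 2. A rough sketch of the resulting shapes for one layer, with assumed DeepSeek-style config values:

import torch

# Assumed DeepSeek-V3-style values, for illustration only
bs, seq_len = 1, 128
kv_lora_rank, qk_rope_head_dim = 512, 64
n_kv = 4  # replicated KV heads, matching the literal 4 used in export()

# axes:             0    1      2        3
ckv = torch.zeros((bs, n_kv, seq_len, kv_lora_rank), dtype=torch.float32)
k_pe = torch.zeros((bs, n_kv, seq_len, qk_rope_head_dim), dtype=torch.float32)

# with the head axis present, "ctx_len" refers to axis 2 for both inputs
dynamic_axes = {
    "compressed_kv.0": {0: "batch_size", 2: "ctx_len"},
    "k_pe.0": {0: "batch_size", 2: "ctx_len"},
}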
QEfficient/transformers/models/pytorch_transforms.py

Lines changed: 75 additions & 0 deletions
@@ -10,6 +10,7 @@
 from types import MethodType
 from typing import Callable, Optional, Tuple, Union

+import torch
 from torch import nn
 from transformers.models.codegen.modeling_codegen import (
     CodeGenAttention,
@@ -456,6 +457,7 @@
 from QEfficient.transformers.post_processing import build_and_attach_mlp, model_type_registry
 from QEfficient.transformers.sampler.sampler import sampler_forward
 from QEfficient.transformers.spd.spd_transform_forward import tlm_forward
+from QEfficient.utils.logging_utils import logger

 SPD_TARGET = "target"

@@ -694,6 +696,79 @@ class RevertPrefillOnlyTransform(ModuleMappingTransform):
     }


+class ReplicateKVHeadTransform:
+    """
+    Replicates KV heads in attention modules to match the number of KV heads in the target model.
+    This transform is used when the source model has fewer KV heads than required in target model.
+    """
+
+    def _duplicate_weights_for_linear_layer(
+        layer: nn.Module, orig_kv_heads: int, repeat: int, dim: int, hidden_size: int
+    ):
+        new_kv_heads = repeat #for mla
+
+        layer.weight.data = torch.repeat_interleave(
+            layer.weight.data.view(orig_kv_heads, dim, hidden_size), repeat, 0
+        ).view(new_kv_heads * dim, hidden_size)
+
+        if layer.bias is not None:
+            layer.bias.data = torch.repeat_interleave(
+                layer.bias.data.view(orig_kv_heads, dim), repeat, 0
+            ).view(new_kv_heads * dim)
+
+    def _get_text_model(model):
+        """
+        Determine and return the appropriate text_model from a given model object.
+        """
+        # Check for VLMs
+        if hasattr(model, "language_model"):
+            if hasattr(model.language_model, "model"):
+                return model.language_model.model
+            else:
+                return model.language_model
+        # Check for CausalLMs
+        if hasattr(model, "model"):
+            return model.model
+
+        raise AttributeError("No suitable text model found in the provided model.")
+
+    @classmethod
+    def apply(cls, model: nn.Module, **kwargs) -> nn.Module:
+        """
+        Replicates KV heads in attention modules based on provided multiplier.
+
+        Args:
+            model: The model to apply the transform to.
+            kwargs: Additional arguments for the transformation. Includes:
+                - num_kv_heads_repeat: The number of times to repeat the KV heads.
+        """
+        n_repeat = kwargs.pop("num_kv_heads_repeat", 1)
+        transformed = False
+        if n_repeat is not None and n_repeat > 1:
+            text_model = cls._get_text_model(model)
+
+            orig_kv_heads = 1 # for mla #text_model.config.num_key_value_heads
+            new_kv_heads = n_repeat*orig_kv_heads
+            text_model.config.orig_kv_heads = orig_kv_heads
+            text_model.config.num_key_value_heads = new_kv_heads
+
+            num_attention_heads = text_model.config.num_attention_heads
+            hidden_size = text_model.config.hidden_size
+
+            logger.warning(f"Original KV heads: {orig_kv_heads}")
+            logger.warning(f"Modified KV heads: {new_kv_heads}")
+            transformed = True
+            for block in text_model.layers:
+                attn = getattr(block, "cross_attn", getattr(block, "self_attn", None))
+                attn.num_key_value_heads = new_kv_heads
+                head_dim = attn.kv_lora_rank + attn.qk_rope_head_dim
+
+                cls._duplicate_weights_for_linear_layer(
+                    attn.kv_a_proj_with_mqa, orig_kv_heads, n_repeat, head_dim, hidden_size
+                )
+        return model, transformed
+
+
 class SpDTransform:
     """
     Apply generic QEffForCausalLM forward pass to extract `num_speculative_tokens+1` hidden states before computing logits during decode phase and extract last predicted token during prefill.

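A usage sketch for the new transform. The ReplicateKVHeadTransform.apply call is exactly the classmethod added above, while `model` is assumed to be an already-loaded MLA-style causal LM (e.g. the kimi_k2 case wired up in modeling_auto.py):

from QEfficient.transformers.models.pytorch_transforms import ReplicateKVHeadTransform

# `model` is assumed to be a loaded HF causal LM whose attention blocks expose
# kv_a_proj_with_mqa, kv_lora_rank and qk_rope_head_dim (the MLA case this commit targets).
model, transformed = ReplicateKVHeadTransform.apply(model, num_kv_heads_repeat=4)

if transformed:
    # the text model's config now reports 4 KV heads, and each layer's
    # kv_a_proj_with_mqa weight has been repeat_interleave'd 4x along its output rows
    print(model.model.config.num_key_value_heads)  # 4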