Skip to content

Commit 5711faa

Browse files
committed
Add replicatekvhead transform
Signed-off-by: Mamta Singh <mamtsing@qti.qualcomm.com>
1 parent 0662e58 commit 5711faa

File tree

3 files changed

+103
-10
lines changed

3 files changed

+103
-10
lines changed

QEfficient/transformers/models/modeling_auto.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
PrefillOnlyExternalModuleMapperTransform,
5252
PrefillOnlyChunkedTransform,
5353
PrefillOnlyTransform,
54+
ReplicateKVHeadTransform,
5455
RevertPrefillKeepAttentionTransform,
5556
RevertPrefillOnlyTransform,
5657
RevertPrefillOnlyExternalModuleMapperTransform,
@@ -2410,6 +2411,10 @@ def __init__(
24102411
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None
24112412
self.hash_params["max_seq_len_cached"] = max_seq_len_cached
24122413

2414+
self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs)
2415+
if replicate_kv_transformed:
2416+
self.hash_params["config"] = model.config.to_diff_dict()
2417+
24132418
# ---Sampling---
24142419
# Note: SamplerTransform should be applied after all other transforms
24152420
# are done. The role of the sampler is to just add nodes at the output of the

QEfficient/transformers/models/pytorch_transforms.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from types import MethodType
1111
from typing import Callable, Optional, Tuple, Union
1212

13+
import torch
1314
from torch import nn
1415
from transformers.models.codegen.modeling_codegen import (
1516
CodeGenAttention,
@@ -456,6 +457,7 @@
456457
from QEfficient.transformers.post_processing import build_and_attach_mlp, model_type_registry
457458
from QEfficient.transformers.sampler.sampler import sampler_forward
458459
from QEfficient.transformers.spd.spd_transform_forward import tlm_forward
460+
from QEfficient.utils.logging_utils import logger
459461

460462
SPD_TARGET = "target"
461463

@@ -694,6 +696,82 @@ class RevertPrefillOnlyTransform(ModuleMappingTransform):
694696
}
695697

696698

699+
class ReplicateKVHeadTransform:
    """
    Replicates KV heads in attention modules to match the number of KV heads in the target model.
    This transform is used when the source model has fewer KV heads than required in target model.
    """

    @staticmethod
    def _duplicate_weights_for_linear_layer(
        layer: nn.Module, orig_kv_heads: int, repeat: int, head_dim: int, hidden_size: int
    ):
        """Repeat the per-head rows of ``layer``'s weight (and bias) ``repeat`` times.

        The weight is viewed as (orig_kv_heads, head_dim, hidden_size) and each
        head's rows are repeated contiguously via ``torch.repeat_interleave``,
        so head ``i`` of the source appears ``repeat`` times in the result.

        Args:
            layer: Linear layer whose out_features == orig_kv_heads * head_dim.
            orig_kv_heads: Number of KV heads currently in the layer.
            repeat: Replication factor per head.
            head_dim: Per-head output width of the layer.
            hidden_size: Input width of the layer.
        """
        new_kv_heads = orig_kv_heads * repeat  # equals `repeat` for MLA, where orig_kv_heads == 1
        layer.weight.data = torch.repeat_interleave(
            layer.weight.data.view(orig_kv_heads, head_dim, hidden_size), repeat, 0
        ).view(new_kv_heads * head_dim, hidden_size)
        # Expand the bias exactly once. (The original diff duplicated this block;
        # a second pass would fail to `.view(orig_kv_heads, head_dim)` on the
        # already-expanded bias whenever repeat > 1.)
        if layer.bias is not None:
            layer.bias.data = torch.repeat_interleave(
                layer.bias.data.view(orig_kv_heads, head_dim), repeat, 0
            ).view(new_kv_heads * head_dim)

    @staticmethod
    def _get_text_model(model):
        """
        Determine and return the appropriate text_model from a given model object.

        Raises:
            AttributeError: If no language/text submodule can be located.
        """
        # Check for VLMs: text decoder hangs off `language_model`.
        if hasattr(model, "language_model"):
            if hasattr(model.language_model, "model"):
                return model.language_model.model
            return model.language_model
        # Check for CausalLMs: text decoder is `model.model`.
        if hasattr(model, "model"):
            return model.model
        raise AttributeError("No suitable text model found in the provided model.")

    @classmethod
    def apply(cls, model: nn.Module, **kwargs) -> Tuple[nn.Module, bool]:
        """
        Replicates KV heads in attention modules based on provided multiplier.

        Args:
            model: The model to apply the transform to.
            kwargs: Additional arguments for the transformation. Includes:
                - num_kv_heads_repeat: The number of times to repeat the KV heads.

        Returns:
            Tuple of (model, transformed), where ``transformed`` is True only
            when the KV heads were actually replicated (multiplier > 1).
        """
        n_repeat = kwargs.pop("num_kv_heads_repeat", 1)
        transformed = False
        if n_repeat is not None and n_repeat > 1:
            text_model = cls._get_text_model(model)

            # NOTE(review): hard-coded for MLA models, which use a single latent
            # KV head; the generic value would be
            # text_model.config.num_key_value_heads — confirm before reuse.
            orig_kv_heads = 1
            new_kv_heads = n_repeat * orig_kv_heads
            # Record both values on the config so export/compile stages (and the
            # model hash) see the replicated head count.
            text_model.config.orig_kv_heads = orig_kv_heads
            text_model.config.num_key_value_heads = new_kv_heads

            hidden_size = text_model.config.hidden_size

            logger.warning(f"Original KV heads: {orig_kv_heads}")
            logger.warning(f"Modified KV heads: {new_kv_heads}")
            transformed = True
            for block in text_model.layers:
                attn = getattr(block, "cross_attn", getattr(block, "self_attn", None))
                attn.num_key_value_heads = new_kv_heads
                # MLA: the joint latent projection emits
                # kv_lora_rank + qk_rope_head_dim features per KV head.
                head_dim = attn.kv_lora_rank + attn.qk_rope_head_dim

                cls._duplicate_weights_for_linear_layer(
                    attn.kv_a_proj_with_mqa, orig_kv_heads, n_repeat, head_dim, hidden_size
                )
        return model, transformed
697775
class SpDTransform:
698776
"""
699777
Apply generic QEffForCausalLM forward pass to extract `num_speculative_tokens+1` hidden states before computing logits during decode phase and extract last predicted token during prefill.

examples/run_kimik2.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,15 @@
55
from QEfficient import QEFFAutoModelForCausalLM
66

77
prompt = "Once upon a time,"
8+
num_kv_heads_repeat=4 #TS=4
9+
num_hidden_layers=2
10+
enable_mla=True
11+
mla_absorption_config={"enable": True, "online": True}
812

9-
model_path = "/home/ochougul/.cache/huggingface/hub/models--moonshotai--Kimi-K2-Thinking/snapshots/a51ccc050d73dab088bf7b0e2dd9b30ae85a4e55/"
13+
#model_path = "/home/ochougul/.cache/huggingface/hub/models--moonshotai--Kimi-K2-Thinking/snapshots/a51ccc050d73dab088bf7b0e2dd9b30ae85a4e55/"
14+
model_path ="/home/huggingface_hub/models--moonshotai--Kimi-K2-Thinking/snapshots/612681931a8c906ddb349f8ad0f582cb552189cd"
1015
model = AutoModelForCausalLM.from_pretrained(
11-
model_path, torch_dtype=torch.float32, num_hidden_layers=2, trust_remote_code=True
16+
model_path, torch_dtype=torch.float32, num_hidden_layers=num_hidden_layers, trust_remote_code=True
1217
)
1318
tokenizer = AutoTokenizer.from_pretrained("moonshotai/Kimi-K2-Thinking", trust_remote_code=True)
1419

@@ -27,8 +32,8 @@
2732
out = model(**inputs)
2833
predictions = torch.argmax(out.logits, dim=-1)
2934

30-
qeff_model = QEFFAutoModelForCausalLM(model)
31-
qeff_model.mla(enable_mla=True, mla_absorption_config={"enable": True, "online": True})
35+
qeff_model = QEFFAutoModelForCausalLM(model, num_kv_heads_repeat=num_kv_heads_repeat)
36+
qeff_model.mla(enable_mla=enable_mla, mla_absorption_config=mla_absorption_config)
3237

3338
inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len)
3439
inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1)
@@ -84,17 +89,22 @@
8489
print("Completion:", repr(predicted_string))
8590

8691

92+
93+
prefill_seq_len = 128
94+
ctx_len = 2048
95+
8796
onnx_path = qeff_model.export(
88-
prefill_seq_len=1, enable_mla=True, mla_absorption_config={"enable": True, "online": True}
97+
prefill_seq_len=prefill_seq_len, enable_mla=enable_mla, mla_absorption_config=mla_absorption_config
8998
)
99+
90100
qpc_path = qeff_model.compile(
91-
prefill_seq_len=1,
92-
ctx_len=1024,
93-
enable_mla=True,
94-
mla_absorption_config={"enable": True, "online": True},
101+
prefill_seq_len=prefill_seq_len,
102+
ctx_len=ctx_len,
103+
enable_mla=enable_mla,
104+
mla_absorption_config=mla_absorption_config,
95105
mxfp6_matmul=True,
96106
mxint8_kv_cache=False,
97-
num_devices=1,
107+
num_devices=num_kv_heads_repeat,
98108
num_cores=16,
99109
)
100110

0 commit comments

Comments
 (0)