|
10 | 10 | from types import MethodType |
11 | 11 | from typing import Callable, Optional, Tuple, Union |
12 | 12 |
|
| 13 | +import torch |
13 | 14 | from torch import nn |
14 | 15 | from transformers.models.codegen.modeling_codegen import ( |
15 | 16 | CodeGenAttention, |
|
456 | 457 | from QEfficient.transformers.post_processing import build_and_attach_mlp, model_type_registry |
457 | 458 | from QEfficient.transformers.sampler.sampler import sampler_forward |
458 | 459 | from QEfficient.transformers.spd.spd_transform_forward import tlm_forward |
| 460 | +from QEfficient.utils.logging_utils import logger |
459 | 461 |
|
460 | 462 | SPD_TARGET = "target" |
461 | 463 |
|
@@ -694,6 +696,82 @@ class RevertPrefillOnlyTransform(ModuleMappingTransform): |
694 | 696 | } |
695 | 697 |
|
696 | 698 |
|
class ReplicateKVHeadTransform:
    """
    Replicates KV heads in attention modules to match the number of KV heads in the target model.
    This transform is used when the source model has fewer KV heads than required in target model.
    """

    @staticmethod
    def _duplicate_weights_for_linear_layer(
        layer: nn.Module, orig_kv_heads: int, repeat: int, head_dim: int, hidden_size: int
    ) -> None:
        """
        Repeat the KV-head slices of ``layer``'s weight (and bias, if present) in place.

        The weight is interpreted as ``(orig_kv_heads, head_dim, hidden_size)`` and each
        head slice is repeated ``repeat`` times along the head axis, producing a weight of
        shape ``(orig_kv_heads * repeat * head_dim, hidden_size)``.

        Args:
            layer: Linear layer whose weight/bias tensors are modified in place.
            orig_kv_heads: Number of KV heads currently encoded in the weight.
            repeat: Replication factor applied to each KV head.
            head_dim: Per-head projection dimension.
            hidden_size: Input feature dimension of the layer.
        """
        new_kv_heads = orig_kv_heads * repeat

        layer.weight.data = torch.repeat_interleave(
            layer.weight.data.view(orig_kv_heads, head_dim, hidden_size), repeat, 0
        ).view(new_kv_heads * head_dim, hidden_size)
        # The bias must be expanded exactly once: a second pass would try to
        # re-view the already-expanded bias as (orig_kv_heads, head_dim) and fail.
        if layer.bias is not None:
            layer.bias.data = torch.repeat_interleave(
                layer.bias.data.view(orig_kv_heads, head_dim), repeat, 0
            ).view(new_kv_heads * head_dim)

    @staticmethod
    def _get_text_model(model):
        """
        Determine and return the appropriate text_model from a given model object.

        Raises:
            AttributeError: If no text model can be located on ``model``.
        """
        # Check for VLMs
        if hasattr(model, "language_model"):
            if hasattr(model.language_model, "model"):
                return model.language_model.model
            return model.language_model
        # Check for CausalLMs
        if hasattr(model, "model"):
            return model.model

        raise AttributeError("No suitable text model found in the provided model.")

    @classmethod
    def apply(cls, model: nn.Module, **kwargs) -> Tuple[nn.Module, bool]:
        """
        Replicates KV heads in attention modules based on provided multiplier.

        Args:
            model: The model to apply the transform to.
            kwargs: Additional arguments for the transformation. Includes:
                - num_kv_heads_repeat: The number of times to repeat the KV heads.

        Returns:
            Tuple of the (possibly modified) model and a bool indicating whether
            any replication was performed.
        """
        n_repeat = kwargs.pop("num_kv_heads_repeat", 1)
        transformed = False
        if n_repeat is not None and n_repeat > 1:
            text_model = cls._get_text_model(model)

            # NOTE(review): hard-coded to 1 for MLA-style attention; for GQA models
            # this should presumably come from text_model.config.num_key_value_heads —
            # confirm before using this transform outside MLA.
            orig_kv_heads = 1
            new_kv_heads = n_repeat * orig_kv_heads
            text_model.config.orig_kv_heads = orig_kv_heads
            text_model.config.num_key_value_heads = new_kv_heads

            hidden_size = text_model.config.hidden_size

            logger.warning(f"Original KV heads: {orig_kv_heads}")
            logger.warning(f"Modified KV heads: {new_kv_heads}")
            transformed = True
            for block in text_model.layers:
                attn = getattr(block, "cross_attn", getattr(block, "self_attn", None))
                if attn is None:
                    raise AttributeError(
                        "Decoder block has neither 'cross_attn' nor 'self_attn'; cannot replicate KV heads."
                    )
                attn.num_key_value_heads = new_kv_heads
                # MLA layout: compressed KV (kv_lora_rank) plus the rotary key part.
                head_dim = attn.kv_lora_rank + attn.qk_rope_head_dim

                cls._duplicate_weights_for_linear_layer(
                    attn.kv_a_proj_with_mqa, orig_kv_heads, n_repeat, head_dim, hidden_size
                )
        return model, transformed
| 773 | + |
| 774 | + |
697 | 775 | class SpDTransform: |
698 | 776 | """ |
699 | 777 | Apply generic QEffForCausalLM forward pass to extract `num_speculative_tokens+1` hidden states before computing logits during decode phase and extract last predicted token during prefill. |
|
0 commit comments