Skip to content
Draft
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 105 additions & 15 deletions QEfficient/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import Any, Dict, List, Optional, Tuple

import torch
from transformers.cache_utils import DynamicCache, DynamicLayer, EncoderDecoderCache, HybridCache, HybridChunkedCache
from transformers.cache_utils import Cache, CacheLayerMixin, EncoderDecoderCache, HybridCache, HybridChunkedCache

from QEfficient.customop import (
CtxGatherFunc,
Expand Down Expand Up @@ -54,7 +54,47 @@ def _get_invalid_idx_value(cls):
return 0


class QEffDynamicLayer(DynamicLayer):
class QEffDynamicLayer(CacheLayerMixin):
is_sliding = False

def __init__(self) -> None:
    """Create an uninitialized dynamic cache layer; storage is allocated lazily."""
    super().__init__()

def lazy_initialization(self, key_states: torch.Tensor) -> None:
    """Allocate empty key/value storage, inheriting dtype and device from *key_states*."""
    dtype, device = key_states.dtype, key_states.device
    self.dtype = dtype
    self.device = device
    # Start with zero-length tensors; real KV data is written on the first update.
    self.keys = torch.tensor([], dtype=dtype, device=device)
    self.values = torch.tensor([], dtype=dtype, device=device)
    self.is_initialized = True

def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
    """Return ``(kv_length, kv_offset)`` for attention-mask construction.

    The KV length is the cached sequence length plus the number of new
    positions in *cache_position*; a dynamic cache always starts at offset 0.
    """
    num_new_tokens = cache_position.shape[0]
    return self.get_seq_length() + num_new_tokens, 0

def get_seq_length(self) -> int:
    """Return the number of cached positions (0 when nothing is stored yet)."""
    keys = self.keys
    has_content = keys is not None and keys.numel() != 0
    return keys.shape[-2] if has_content else 0

def get_max_cache_shape(self) -> int:
    """Return the maximum cache capacity; -1 signals an unbounded (dynamic) cache."""
    return -1

@classmethod
def from_tensors(cls, key_states: torch.Tensor, value_states: torch.Tensor) -> "QEffDynamicLayer":
    """Alternate constructor: build a layer pre-populated with existing KV tensors."""
    new_layer = cls()
    new_layer.keys, new_layer.values = key_states, value_states
    # Record dtype/device metadata from the provided keys.
    new_layer._mark_initialized(key_states)
    return new_layer

def _mark_initialized(self, reference_states: torch.Tensor) -> None:
    """Record dtype/device from *reference_states* on first use; later calls are no-ops."""
    if self.is_initialized:
        return
    self.dtype = reference_states.dtype
    self.device = reference_states.device
    self.is_initialized = True

def read_only(self, cache_kwargs):
"""
Reads the `key_states` and `value_states` for the layer.
Expand All @@ -68,6 +108,8 @@ def read_only(self, cache_kwargs):
"""
# Gather
k_out, v_out = self.keys, self.values
if k_out is not None:
self._mark_initialized(k_out)
position_ids = cache_kwargs.get("position_ids")
batch_index = cache_kwargs.get("batch_index", None)
ctx_len = cache_kwargs.get("CCL", k_out.shape[2])
Expand Down Expand Up @@ -109,6 +151,8 @@ def read_only_blockedKV(self, start_index, end_index, cache_kwargs):
"""
# Gather
k_out, v_out = self.keys, self.values
if k_out is not None:
self._mark_initialized(k_out)
position_ids = cache_kwargs.get("position_ids")
batch_index = cache_kwargs.get("batch_index", None)
batch, num_kv_heads, _, _ = k_out.shape
Expand Down Expand Up @@ -150,7 +194,9 @@ def write_only(self, key_states, value_states, cache_kwargs):
if self.keys is None:
self.keys = key_states
self.values = value_states
self._mark_initialized(self.keys)
else:
self._mark_initialized(self.keys)
position_ids = cache_kwargs.get("position_ids")
batch_index = cache_kwargs.get("batch_index", None) # Check and fetch batch index value form the kwargs

Expand Down Expand Up @@ -189,8 +235,10 @@ def update(
if self.keys is None:
self.keys = key_states
self.values = value_states
self._mark_initialized(self.keys)
k_out, v_out = self.keys, self.values
else:
self._mark_initialized(self.keys)
position_ids = cache_kwargs.get("position_ids")
batch_index = cache_kwargs.get("batch_index", None) # Check and fetch batch index value form the kwargs

Expand Down Expand Up @@ -252,8 +300,10 @@ def update3D(
if self.keys is None:
self.keys = key_states
self.values = value_states
self._mark_initialized(self.keys)
k_out, v_out = self.keys, self.values
else:
self._mark_initialized(self.keys)
position_ids = cache_kwargs.get("position_ids")
batch_index = cache_kwargs.get("batch_index", None)

Expand Down Expand Up @@ -293,7 +343,7 @@ def update3D(
return k_out, v_out


class QEffDynamicCache(DynamicCache):
class QEffDynamicCache(Cache):
"""
A cache that grows dynamically as more tokens are generated. This is the default for generative models.

Expand All @@ -307,15 +357,46 @@ class QEffDynamicCache(DynamicCache):
"""

def __init__(self, ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None, *args, **kwargs):
    """Initialize the dynamic cache, forcing ``QEffDynamicLayer`` as the layer class.

    The visible span interleaves the pre- and post-change variants of this
    method (a diff artifact); this body keeps only the coherent new version.

    Parameters
    ----------
    ddp_cache_data : optional iterable of (key, value) tensor pairs used to
        pre-populate one layer per pair (DDP warm-start path).
    *args, **kwargs : forwarded to ``transformers.cache_utils.Cache``.
    """
    # Remove cache-layer construction args if present to avoid duplicate arguments.
    kwargs.pop("layer_classes", None)
    kwargs.pop("layers", None)
    kwargs.pop("layer_class_to_replicate", None)

    try:
        # transformers>=4.57 renamed the constructor keyword.
        Cache.__init__(self, *args, layer_class_to_replicate=QEffDynamicLayer, **kwargs)
    except TypeError:
        # transformers<=4.56
        Cache.__init__(self, *args, layer_classes=QEffDynamicLayer, **kwargs)
    if ddp_cache_data is not None:
        for key_states, value_states in ddp_cache_data:
            self.layers.append(QEffDynamicLayer.from_tensors(key_states, value_states))

def append_new_layers(self, layer_idx: int) -> None:
    """Grow ``self.layers`` with fresh ``QEffDynamicLayer`` objects so index *layer_idx* exists."""
    missing = layer_idx + 1 - len(self.layers)
    for _ in range(missing):
        self.layers.append(QEffDynamicLayer())

@classmethod
def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "QEffDynamicCache":
    """Build a cache from the legacy tuple-of-(key, value)-pairs format."""
    cache = cls()
    if past_key_values is None:
        return cache
    for layer_idx, (key_states, value_states) in enumerate(past_key_values):
        cache.update(key_states, value_states, layer_idx)
    return cache

def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
    """Export the cache as the legacy tuple of per-layer (keys, values) pairs."""
    return tuple((layer.keys, layer.values) for layer in self.layers)

def get_seq_length(self, layer_idx: Optional[int] = 0, cache_position: Optional[torch.LongTensor] = None) -> int:
    """
    Keep backward-compatible call shape while deferring to upstream implementation.

    ``cache_position`` is accepted only for signature compatibility with older
    callers and is ignored; the lookup is delegated to the base class.
    """
    return super().get_seq_length(layer_idx)

def read_only(self, layer_idx, cache_kwargs):
"""
Reads the `key_states` and `value_states` for the layer `layer_idx`.
Expand Down Expand Up @@ -405,10 +486,7 @@ def from_legacy_cache(
cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
) -> "EncoderDecoderCache":
"""Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`."""
cache = cls(
self_attention_cache=QEffDynamicCache(),
cross_attention_cache=QEffDynamicCache(),
)
cache = cls(QEffDynamicCache(), QEffDynamicCache())
if past_key_values is not None:
for layer_idx in range(len(past_key_values)):
key_states, value_states = past_key_values[layer_idx][:2]
Expand All @@ -419,6 +497,18 @@ def from_legacy_cache(
cache.is_updated[layer_idx] = True
return cache

def to_legacy_cache(self):
    """Flatten self- and cross-attention caches into the legacy tuple format.

    Layers with a cross-attention entry yield the concatenated pair
    (self-attn KV + cross-attn KV); trailing layers without one yield
    just the self-attention pair.
    """
    self_attn = self.self_attention_cache.to_legacy_cache()
    cross_attn = self.cross_attention_cache.to_legacy_cache()
    n_cross = len(cross_attn)

    merged = []
    for idx, self_layer in enumerate(self_attn):
        entry = self_layer + cross_attn[idx] if idx < n_cross else self_layer
        merged.append(entry)
    return tuple(merged)


# TODO: This class will be deprecated in the future.
class QEffHybridCache(HybridCache):
Expand Down Expand Up @@ -447,7 +537,7 @@ def __len__(self):
"""
return len(self.key_cache)

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
def get_seq_length(self, layer_idx: Optional[int] = 0, cache_position: Optional[torch.LongTensor] = None) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
# TODO: deprecate this function in favor of `cache_position`
is_empty_layer = (
Expand Down Expand Up @@ -531,7 +621,7 @@ def __len__(self):
"""
return len(self.key_cache)

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
def get_seq_length(self, layer_idx: Optional[int] = 0, cache_position: Optional[torch.LongTensor] = None) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
# TODO: deprecate this function in favor of `cache_position`
is_empty_layer = (
Expand Down Expand Up @@ -663,7 +753,7 @@ def __len__(self):
"""
return len(self.key_cache)

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
def get_seq_length(self, layer_idx: Optional[int] = 0, cache_position: Optional[torch.LongTensor] = None) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
# TODO: deprecate this function in favor of `cache_position`
is_empty_layer = (
Expand Down Expand Up @@ -783,7 +873,7 @@ def __len__(self):
"""
return len(self.key_cache)

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
def get_seq_length(self, layer_idx: Optional[int] = 0, cache_position: Optional[torch.LongTensor] = None) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
# TODO: deprecate this function in favor of `cache_position`
is_empty_layer = (
Expand Down
1 change: 1 addition & 0 deletions QEfficient/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@
]
)


# This is for supporting different seq_len for different layers for Sliding window attn, chunked attn etc.
DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH = {"gemma3", "llama4", "gemma3_text", "llama4_text"}

Expand Down
31 changes: 15 additions & 16 deletions QEfficient/transformers/models/falcon/modeling_falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,6 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

return (
self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling,
self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling,
)


def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Expand Down Expand Up @@ -108,9 +98,6 @@ class QEffFalconAttention(FalconAttention):
- add new args position idx for the cache_kwargs for kv retention
"""

def __qeff_init__(self):
self.rotary_emb = QEffFalconRotaryEmbedding(config=self.config)

def forward(
self,
hidden_states: torch.Tensor,
Expand All @@ -125,6 +112,8 @@ def forward(
use_cache: bool = False,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
cos_cached: Optional[torch.Tensor] = None,
sin_cached: Optional[torch.Tensor] = None,
):
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
Expand All @@ -137,9 +126,8 @@ def forward(
key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)

kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len)
query_layer, key_layer = qeff_apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids)
# kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
query_layer, key_layer = qeff_apply_rotary_pos_emb(query_layer, key_layer, cos_cached, sin_cached, position_ids)

if layer_past is not None:
cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids}
Expand Down Expand Up @@ -184,6 +172,8 @@ def forward(
use_cache: bool = False,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
sin_cached=None,
cos_cached=None,
**kwargs,
):
residual = hidden_states
Expand All @@ -208,6 +198,8 @@ def forward(
use_cache=use_cache,
output_attentions=output_attentions,
cache_position=cache_position,
sin_cached=sin_cached,
cos_cached=cos_cached,
)

if not self.config.new_decoder_architecture:
Expand Down Expand Up @@ -245,6 +237,11 @@ class QEffFalconModel(FalconModel):
- update causal attention mask
"""

def __qeff_init__(self) -> None:
    """Precompute scaled rotary sin/cos tables once at the model level.

    The tables are shared with every attention layer via the forward pass
    (``sin_cached``/``cos_cached`` kwargs) instead of being recomputed per layer.
    """
    self.rotary_emb = QEffFalconRotaryEmbedding(config=self.config)
    # NOTE(review): wrapping constant tables in nn.Parameter makes them trainable
    # and part of state_dict — presumably intentional for ONNX export; confirm a
    # non-trainable buffer (register_buffer) was not intended instead.
    self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling)
    self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * self.rotary_emb.attention_scaling)

def forward(
self,
input_ids: torch.LongTensor = None,
Expand Down Expand Up @@ -322,6 +319,8 @@ def forward(
output_attentions=output_attentions,
alibi=alibi,
cache_position=cache_position,
sin_cached=self.sin_cached,
cos_cached=self.cos_cached,
)

hidden_states = outputs[0]
Expand Down
Loading
Loading