diff --git a/QEfficient/transformers/models/falcon/modeling_falcon.py b/QEfficient/transformers/models/falcon/modeling_falcon.py
index 4ebb2fb96..9647d5afe 100644
--- a/QEfficient/transformers/models/falcon/modeling_falcon.py
+++ b/QEfficient/transformers/models/falcon/modeling_falcon.py
@@ -108,9 +108,6 @@ class QEffFalconAttention(FalconAttention):
     - add new args position idx for the cache_kwargs for kv retention
     """
 
-    def __qeff_init__(self):
-        self.rotary_emb = QEffFalconRotaryEmbedding(config=self.config)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -125,6 +122,7 @@ def forward(
         use_cache: bool = False,
         output_attentions: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
+        rotary_emb: Optional[object] = None,
     ):
         fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
         num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
@@ -138,7 +136,7 @@ def forward(
         value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
 
         kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
-        cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len)
+        cos, sin = rotary_emb(value_layer, seq_len=kv_seq_len)
         query_layer, key_layer = qeff_apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids)
 
         if layer_past is not None:
@@ -184,6 +182,7 @@ def forward(
         use_cache: bool = False,
         output_attentions: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
+        rotary_emb=None,
         **kwargs,
     ):
         residual = hidden_states
@@ -208,6 +207,7 @@ def forward(
             use_cache=use_cache,
             output_attentions=output_attentions,
             cache_position=cache_position,
+            rotary_emb=rotary_emb,
         )
 
         if not self.config.new_decoder_architecture:
@@ -305,6 +305,8 @@ def forward(
         all_self_attentions = () if output_attentions else None
         all_hidden_states = () if output_hidden_states else None
 
+        rotary_emb = QEffFalconRotaryEmbedding(config=self.config)
+
         for i, block in enumerate(self.h):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
@@ -322,6 +324,7 @@ def forward(
                 output_attentions=output_attentions,
                 alibi=alibi,
                 cache_position=cache_position,
+                rotary_emb=rotary_emb,
             )
 
             hidden_states = outputs[0]
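Every file in this diff repeats the same refactor: the rotary-embedding module that each attention class used to build for itself in `__qeff_init__` is now created once in the model `forward` and threaded down to every decoder layer as a `rotary_emb` argument. Presumably the point is that the decoder layers become structurally identical (no layer-owned rotary buffers), which is what allows the ONNX exporter to emit one shared subfunction for all of them; the new test at the end of this diff asserts exactly that. A minimal, self-contained sketch of the pattern, using toy stand-ins rather than the QEfficient classes:

```python
import torch
import torch.nn as nn


def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


class ToyRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len=256, base=10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        freqs = torch.outer(torch.arange(max_seq_len).float(), inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos())
        self.register_buffer("sin_cached", emb.sin())

    def forward(self, x, seq_len):
        # Identical (cos, sin) tables no matter which layer asks for them.
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]


class ToyAttention(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim)
        # No self.rotary_emb here: the rotary module arrives as a forward()
        # argument, so every layer owns exactly the same set of parameters.

    def forward(self, hidden_states, rotary_emb=None):
        q = self.q_proj(hidden_states)
        cos, sin = rotary_emb(q, seq_len=q.shape[1])
        return q * cos + rotate_half(q) * sin


class ToyModel(nn.Module):
    def __init__(self, dim=16, n_layers=2):
        super().__init__()
        self.layers = nn.ModuleList(ToyAttention(dim) for _ in range(n_layers))
        self.dim = dim

    def forward(self, hidden_states):
        rotary_emb = ToyRotaryEmbedding(self.dim)  # built once per forward, shared by all layers
        for layer in self.layers:
            hidden_states = layer(hidden_states, rotary_emb=rotary_emb)
        return hidden_states


out = ToyModel()(torch.randn(1, 8, 16))  # [seq, dim] tables broadcast over [batch, seq, dim]
```

Note that the module is now rebuilt on every `forward` call; for tracing and ONNX export this presumably runs once, so the cos/sin tables end up as constants in the graph rather than per-layer buffers.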
diff --git a/QEfficient/transformers/models/gemma/modeling_gemma.py b/QEfficient/transformers/models/gemma/modeling_gemma.py
index 260d1857a..f554e7d8c 100644
--- a/QEfficient/transformers/models/gemma/modeling_gemma.py
+++ b/QEfficient/transformers/models/gemma/modeling_gemma.py
@@ -128,9 +128,6 @@ class QEffGemmaAttention(GemmaAttention):
     - add new args cache idx for the kv retention
     """
 
-    def __qeff_init__(self):
-        self.rotary_emb = QEffGemmaRotaryEmbedding(config=self.config)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -140,6 +137,7 @@ def forward(
         comp_ctx_lengths: Optional[torch.LongTensor] = None,
         batch_index: Optional[torch.LongTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        rotary_emb: Optional[object] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
@@ -150,7 +148,7 @@ def forward(
         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
 
         kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         if past_key_value is not None:
@@ -194,6 +192,7 @@ def forward(
         batch_index: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
+        rotary_emb=None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
@@ -223,6 +222,7 @@ def forward(
             batch_index=batch_index,
             use_cache=use_cache,
             cache_position=cache_position,
+            rotary_emb=rotary_emb,
             **kwargs,
         )
         hidden_states = residual + hidden_states
@@ -297,6 +297,7 @@ def forward(
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
 
+        rotary_emb = QEffGemmaRotaryEmbedding(config=self.config)
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
@@ -310,6 +311,7 @@ def forward(
                 batch_index=batch_index,
                 use_cache=use_cache,
                 cache_position=cache_position,
+                rotary_emb=rotary_emb,
                 **kwargs,
             )
diff --git a/QEfficient/transformers/models/gemma2/modeling_gemma2.py b/QEfficient/transformers/models/gemma2/modeling_gemma2.py
index 6dee8c85d..640065587 100644
--- a/QEfficient/transformers/models/gemma2/modeling_gemma2.py
+++ b/QEfficient/transformers/models/gemma2/modeling_gemma2.py
@@ -135,9 +135,6 @@ class QEffGemma2Attention(Gemma2Attention):
     - add new args cache idx for the kv retention
     """
 
-    def __qeff_init__(self):
-        self.rotary_emb = QEffGemma2RotaryEmbedding(config=self.config)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -147,6 +144,7 @@ def forward(
         comp_ctx_lengths: Optional[torch.LongTensor] = None,
         batch_index: Optional[torch.LongTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        rotary_emb: Optional[object] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
@@ -157,7 +155,7 @@ def forward(
         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
 
         kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         if past_key_value is not None:
@@ -208,6 +206,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
+        rotary_emb=None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
@@ -241,6 +240,7 @@ def forward(
             output_attentions=output_attentions,
             use_cache=use_cache,
             cache_position=cache_position,
+            rotary_emb=rotary_emb,
             **kwargs,
         )
         hidden_states = self.post_attention_layernorm(hidden_states)
@@ -341,6 +341,7 @@ def forward(
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
 
+        rotary_emb = QEffGemma2RotaryEmbedding(config=self.config)
         for decoder_layer in self.layers:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
@@ -355,6 +356,7 @@ def forward(
                 output_attentions=output_attentions,
                 use_cache=use_cache,
                 cache_position=cache_position,
+                rotary_emb=rotary_emb,
                 **kwargs,
             )
diff --git a/QEfficient/transformers/models/gemma3/modeling_gemma3.py b/QEfficient/transformers/models/gemma3/modeling_gemma3.py
index f98bae225..4f8380738 100644
--- a/QEfficient/transformers/models/gemma3/modeling_gemma3.py
+++ b/QEfficient/transformers/models/gemma3/modeling_gemma3.py
@@ -187,27 +187,6 @@ def __init__(self, config: Gemma3Config, layer_idx: Optional[int] = None):
         # Set the init in the module mapping pytorch transforms
         self.__qeff_init__()
 
-    def __qeff_init__(self):
-        self.rotary_emb = QEffGemma3RotaryEmbedding(
-            self.head_dim,
-            self.config,
-            max_position_embeddings=self.config.max_position_embeddings,
-            base=self.config.rope_theta,
-        )
-
-        config = copy.deepcopy(self.config)
-        config.rope_theta = config.rope_local_base_freq
-        config.rope_scaling = {"rope_type": "default", "factor": 1.0}
-        self.is_local = _is_local(self.layer_idx, self.config.sliding_window_pattern)
-        self.window = self.config.sliding_window if self.is_local else None
-
-        self.rotary_emb_local = QEffGemma3RotaryEmbedding(
-            self.head_dim,
-            config,
-            max_position_embeddings=config.max_position_embeddings,
-            base=config.rope_theta,
-        )
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -218,6 +197,8 @@ def forward(
         comp_ctx_lengths: Optional[torch.LongTensor] = None,
         batch_index: Optional[torch.LongTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        rotary_emb: Optional[object] = None,
+        rotary_emb_local: Optional[object] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()
@@ -240,9 +221,9 @@ def forward(
                 "with a layer index."
             )
         if self.is_sliding:
-            cos, sin = self.rotary_emb_local(value_states, seq_len=self.config.max_position_embeddings)
+            cos, sin = rotary_emb_local(value_states, seq_len=self.config.max_position_embeddings)
         else:
-            cos, sin = self.rotary_emb(value_states, seq_len=self.config.max_position_embeddings)
+            cos, sin = rotary_emb(value_states, seq_len=self.config.max_position_embeddings)
         query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         if past_key_value is not None:
@@ -308,6 +289,8 @@ def forward(
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         last_cache_position: int = 0,
+        rotary_emb=None,
+        rotary_emb_local=None,
         **kwargs,
     ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
@@ -336,6 +319,8 @@ def forward(
             output_attentions=output_attentions,
             use_cache=use_cache,
             cache_position=cache_position,
+            rotary_emb=rotary_emb,
+            rotary_emb_local=rotary_emb_local,
             **kwargs,
         )
 
@@ -432,6 +417,25 @@ def forward(
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
+
+        rotary_emb = QEffGemma3RotaryEmbedding(
+            self.layers[0].self_attn.head_dim,
+            self.config,
+            max_position_embeddings=self.config.max_position_embeddings,
+            base=self.config.rope_theta,
+        )
+
+        config = copy.deepcopy(self.config)
+        config.rope_theta = config.rope_local_base_freq
+        config.rope_scaling = {"rope_type": "default", "factor": 1.0}
+
+        rotary_emb_local = QEffGemma3RotaryEmbedding(
+            self.layers[0].self_attn.head_dim,
+            config,
+            max_position_embeddings=config.max_position_embeddings,
+            base=config.rope_theta,
+        )
+
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
@@ -447,6 +451,8 @@ def forward(
                 use_cache=use_cache,
                 cache_position=cache_position,
                 last_cache_position=last_cache_position,
+                rotary_emb=rotary_emb,
+                rotary_emb_local=rotary_emb_local,
                 **flash_attn_kwargs,
            )
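Gemma3 is the one model here that needs two rotary tables: a global one built from `rope_theta`, and a local one for sliding-window layers built from `rope_local_base_freq`. The hunk above derives the local config by deep-copying the model config; a runnable sketch of just that derivation (the attribute values are illustrative, not authoritative Gemma3 defaults):

```python
import copy
from types import SimpleNamespace

config = SimpleNamespace(
    rope_theta=1_000_000.0,         # base frequency for global-attention layers
    rope_local_base_freq=10_000.0,  # base frequency for sliding-window layers
    rope_scaling={"rope_type": "linear", "factor": 8.0},
    max_position_embeddings=131_072,
)

local_config = copy.deepcopy(config)
local_config.rope_theta = local_config.rope_local_base_freq
local_config.rope_scaling = {"rope_type": "default", "factor": 1.0}  # no scaling for local layers

# rotary_emb       <- built from `config`       (used when not self.is_sliding)
# rotary_emb_local <- built from `local_config` (used when self.is_sliding)
```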
diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py
index e8f5fa89b..fd8393713 100644
--- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py
+++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py
@@ -737,9 +737,6 @@ def opt_eager_attention_forward_blocked(
 class QEffPrefillOnlyChunkedGptOssAttention(GptOssAttention):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
-    def __qeff_init__(self):
-        self.rotary_emb = QEffGptOssRotaryEmbedding(config=self.config)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -751,6 +748,7 @@ def forward(
         batch_index: Optional[torch.LongTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
         sliding_mask=None,
+        rotary_emb: Optional[object] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, torch.Tensor]:
         input_shape = hidden_states.shape[:-1]
@@ -761,7 +759,7 @@ def forward(
         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
         if not (max_seq_len_cached := getattr(self.config, "max_seq_len_cached")):
             max_seq_len_cached = 32 * 1024
-        cos, sin = self.rotary_emb(value_states, seq_len=max_seq_len_cached)
+        cos, sin = rotary_emb(value_states, seq_len=max_seq_len_cached)
         query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         if past_key_value is not None:
@@ -823,9 +821,6 @@ def forward(
 class QEffPrefillOnlyGptOssAttention(GptOssAttention):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
-    def __qeff_init__(self):
-        self.rotary_emb = QEffGptOssRotaryEmbedding(config=self.config)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -837,6 +832,7 @@ def forward(
         batch_index: Optional[torch.LongTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
         sliding_mask=None,
+        rotary_emb: Optional[object] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, torch.Tensor]:
         input_shape = hidden_states.shape[:-1]
@@ -847,7 +843,7 @@ def forward(
         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
         if not (max_seq_len_cached := getattr(self.config, "max_seq_len_cached")):
             max_seq_len_cached = 32 * 1024
-        cos, sin = self.rotary_emb(value_states, seq_len=max_seq_len_cached)
+        cos, sin = rotary_emb(value_states, seq_len=max_seq_len_cached)
         query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         if past_key_value is not None:
@@ -905,9 +901,6 @@ def forward(
 class QEffGptOssAttention(GptOssAttention):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
-    def __qeff_init__(self):
-        self.rotary_emb = QEffGptOssRotaryEmbedding(config=self.config)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -919,6 +912,7 @@ def forward(
         batch_index: Optional[torch.LongTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
         sliding_mask=None,
+        rotary_emb: Optional[object] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, torch.Tensor]:
         input_shape = hidden_states.shape[:-1]
@@ -929,7 +923,7 @@ def forward(
         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
         if not (max_seq_len_cached := getattr(self.config, "max_seq_len_cached")):
             max_seq_len_cached = 32 * 1024
-        cos, sin = self.rotary_emb(value_states, seq_len=max_seq_len_cached)
+        cos, sin = rotary_emb(value_states, seq_len=max_seq_len_cached)
         query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         if past_key_value is not None:
@@ -986,6 +980,7 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         sliding_mask=None,
+        rotary_emb=None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor]:
         residual = hidden_states
@@ -1002,6 +997,7 @@ def forward(
             cache_position=cache_position,
             position_embeddings=position_embeddings,
             sliding_mask=sliding_mask,
+            rotary_emb=rotary_emb,
             **kwargs,
         )
         hidden_states = residual + hidden_states
@@ -1079,6 +1075,7 @@ def forward(
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
 
+        rotary_emb = QEffGptOssRotaryEmbedding(config=self.config)
         for decoder_layer in self.layers:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
@@ -1093,6 +1090,7 @@ def forward(
                 output_attentions=output_attentions,
                 cache_position=cache_position,
                 sliding_mask=sliding_mask,
+                rotary_emb=rotary_emb,
                 **kwargs,
             )
             hidden_states = layer_outputs[0]
@@ -1172,6 +1170,7 @@ def forward(
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
 
+        rotary_emb = QEffGptOssRotaryEmbedding(config=self.config)
         for decoder_layer in self.layers:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
@@ -1187,6 +1186,7 @@ def forward(
                 output_attentions=output_attentions,
                 cache_position=cache_position,
                 sliding_mask=sliding_mask,
+                rotary_emb=rotary_emb,
                 **kwargs,
             )
             hidden_states = layer_outputs[0]
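All three GptOss attention variants size their (cos, sin) tables from `config.max_seq_len_cached`, falling back to 32K when the attribute is falsy; that lookup is a context line here, unchanged by the diff. A small reproduction of its semantics, with a stand-in config:

```python
from types import SimpleNamespace

config = SimpleNamespace(max_seq_len_cached=None)  # attribute present but falsy

if not (max_seq_len_cached := getattr(config, "max_seq_len_cached")):
    max_seq_len_cached = 32 * 1024
assert max_seq_len_cached == 32768

# Caveat worth knowing: getattr() with no default raises AttributeError when
# the attribute is missing entirely, so this fallback only covers falsy values
# (None, 0); getattr(config, "max_seq_len_cached", None) would cover both cases.
```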
diff --git a/QEfficient/transformers/models/granite/modeling_granite.py b/QEfficient/transformers/models/granite/modeling_granite.py
index 8a32c52ef..bd042f145 100644
--- a/QEfficient/transformers/models/granite/modeling_granite.py
+++ b/QEfficient/transformers/models/granite/modeling_granite.py
@@ -121,9 +121,6 @@ def eager_attention_forward(
 class QEffGraniteAttention(GraniteAttention):
 
-    def __qeff_init__(self):
-        self.rotary_emb = QEffGraniteRotaryEmbedding(config=self.config)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -133,6 +130,7 @@ def forward(
         comp_ctx_lengths: Optional[torch.LongTensor] = None,
         batch_index: Optional[torch.LongTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        rotary_emb: Optional[object] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
@@ -143,7 +141,7 @@ def forward(
         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
 
         kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         if past_key_value is not None:
@@ -192,6 +190,7 @@ def forward(
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+        rotary_emb=None,
         **kwargs,
     ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
@@ -230,6 +229,7 @@ def forward(
             use_cache=use_cache,
             cache_position=cache_position,
             position_embeddings=position_embeddings,
+            rotary_emb=rotary_emb,
             **kwargs,
         )
         hidden_states = residual + hidden_states * self.residual_multiplier
@@ -302,6 +302,8 @@ def forward(
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
 
+        rotary_emb = QEffGraniteRotaryEmbedding(config=self.config)
+
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
@@ -316,6 +318,7 @@ def forward(
                 output_attentions=output_attentions,
                 use_cache=use_cache,
                 cache_position=cache_position,
+                rotary_emb=rotary_emb,
                 **kwargs,
             )
diff --git a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py
index 935df7c2d..cc1e4449d 100644
--- a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py
+++ b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py
@@ -111,9 +111,6 @@ def qeff_apply_rotary_pos_emb(
 class QEffGraniteMoeAttention(GraniteMoeAttention):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
-    def __qeff_init__(self):
-        self.rotary_emb = QEffGraniteMoeRotaryEmbedding(config=self.config)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -126,6 +123,7 @@ def forward(
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
+        rotary_emb: Optional[object] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()
 
@@ -138,7 +136,7 @@ def forward(
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
         kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
@@ -214,6 +212,7 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         output_router_logits: Optional[bool] = False,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+        rotary_emb=None,
         **kwargs,
     ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
@@ -255,6 +254,7 @@ def forward(
             use_cache=use_cache,
             cache_position=cache_position,
             position_embeddings=position_embeddings,
+            rotary_emb=rotary_emb,
             **kwargs,
         )
 
@@ -341,6 +341,8 @@ def forward(
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
 
+        rotary_emb = QEffGraniteMoeRotaryEmbedding(config=self.config)
+
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
@@ -356,6 +358,7 @@ def forward(
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     cache_position=cache_position,
+                    rotary_emb=rotary_emb,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -368,6 +371,7 @@ def forward(
                     use_cache=use_cache,
                     cache_position=cache_position,
                     position_embeddings=position_embeddings,
+                    rotary_emb=rotary_emb,
                 )
 
             hidden_states = layer_outputs[0]
diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py
index 57bccdb1b..11f72a435 100644
--- a/QEfficient/transformers/models/llama/modeling_llama.py
+++ b/QEfficient/transformers/models/llama/modeling_llama.py
@@ -198,9 +198,6 @@ def eager_attention_forward_blockedKV(
 class QEffLlamaAttention(LlamaAttention):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
-    def __qeff_init__(self):
-        self.rotary_emb = QEffLlamaRotaryEmbedding(config=self.config)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -212,6 +209,7 @@ def forward(
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
         num_kv_blocks: Optional[torch.Tensor] = None,
+        rotary_emb: Optional[object] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
@@ -228,7 +226,7 @@ def forward(
         kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
         past_seen_tokens = past_key_value.get_seq_length() if past_key_value is not None else 0
 
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         if past_key_value is not None:
@@ -287,6 +285,7 @@ def forward(
         batch_index: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
+        rotary_emb=None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
@@ -303,6 +302,7 @@ def forward(
             batch_index=batch_index,
             use_cache=use_cache,
             cache_position=cache_position,
+            rotary_emb=rotary_emb,
             **kwargs,
         )
         hidden_states = residual + hidden_states
@@ -367,6 +367,8 @@ def forward(
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
 
+        rotary_emb = QEffLlamaRotaryEmbedding(config=self.config)
+
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
@@ -380,6 +382,7 @@ def forward(
                 batch_index=batch_index,
                 use_cache=use_cache,
                 cache_position=cache_position,
+                rotary_emb=rotary_emb,
                 **kwargs,
             )
diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py
index e219d5e03..30ab93d45 100644
--- a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py
+++ b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py
@@ -82,8 +82,6 @@ def __init__(self, config: QEffLlamaSwiftKVConfig, layer_idx) -> None:
         )
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
 
-        self.rotary_emb = QEffLlamaRotaryEmbedding(config=config)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -92,6 +90,7 @@ def forward(
         comp_ctx_lengths: Optional[torch.LongTensor] = None,
         attention_mask: torch.Tensor = None,
         batch_index: Optional[torch.LongTensor] = None,
+        rotary_emb: Optional[object] = None,
    ) -> torch.Tensor:
         bsz, q_len, _ = hidden_states.size()
         q_len = 1  # as we always run this for single token
@@ -113,7 +112,7 @@ def forward(
 
         kv_seq_len = past_key_value.get_seq_length(self.layer_idx)
         key_states, value_states = past_key_value.read_only(self.layer_idx, cache_kwargs=cache_kwargs)
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
         position_ids = position_ids[torch.arange(bsz), position_ids.to(torch.int32).argmax(1)].unsqueeze(1)
         query_states, _ = qeff_apply_rotary_pos_emb(
             query_states, torch.empty_like(query_states), cos, sin, position_ids
@@ -162,6 +161,7 @@ def forward(
         comp_ctx_lengths,
         causal_mask,
         batch_index: Optional[torch.LongTensor] = None,
+        rotary_emb=None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
         residual = hidden_states
@@ -174,6 +174,7 @@ def forward(
             comp_ctx_lengths=comp_ctx_lengths,
             attention_mask=causal_mask,
             batch_index=batch_index,
+            rotary_emb=rotary_emb,
         )
         hidden_states = residual + hidden_states
@@ -335,6 +336,7 @@ def forward(
         hidden_states = inputs_embeds
         next_decoder_cache = None
 
+        rotary_emb = QEffLlamaRotaryEmbedding(config=self.config)
         for layer_idx in range(self.config.num_key_value_layers):
             layer = self.layers[layer_idx]
@@ -347,6 +349,7 @@ def forward(
                 batch_index=batch_index,
                 output_attentions=False,
                 use_cache=True,
+                rotary_emb=rotary_emb,
             )
 
         bsz, q_len, _ = hidden_states.size()
@@ -372,7 +375,7 @@ def forward(
             )
 
             kv_seq_len = past_key_values.get_seq_length(self_attn.layer_idx)
-            cos, sin = self_attn.rotary_emb(value_states, seq_len=kv_seq_len)
+            cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
             _, key_states = qeff_apply_rotary_pos_emb(torch.empty_like(key_states), key_states, cos, sin, position_ids)
             cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index}
             past_key_values.write_only(key_states, value_states, self_attn.layer_idx, cache_kwargs)
diff --git a/QEfficient/transformers/models/mistral/modeling_mistral.py b/QEfficient/transformers/models/mistral/modeling_mistral.py
index 47107384e..450aa96e4 100644
--- a/QEfficient/transformers/models/mistral/modeling_mistral.py
+++ b/QEfficient/transformers/models/mistral/modeling_mistral.py
@@ -131,9 +131,6 @@ class QEffMistralAttention(MistralAttention):
     - add new args cache idx for the kv retention
     """
 
-    def __qeff_init__(self):
-        self.rotary_emb = QEffMistralRotaryEmbedding(config=self.config)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -146,6 +143,7 @@ def forward(
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: tuple[torch.Tensor, torch.Tensor] = None,  # kept here for BC
+        rotary_emb: Optional[object] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
@@ -160,7 +158,7 @@ def forward(
         value_states = value_states.view(hidden_shape).transpose(1, 2)
 
         kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         if past_key_value is not None:
@@ -205,6 +203,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
+        rotary_emb=None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
@@ -236,6 +235,7 @@ def forward(
             output_attentions=output_attentions,
             use_cache=use_cache,
             cache_position=cache_position,
+            rotary_emb=rotary_emb,
             **kwargs,
         )
         hidden_states = residual + hidden_states
@@ -314,6 +314,8 @@ def forward(
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
 
+        rotary_emb = QEffMistralRotaryEmbedding(config=self.config)
+
         for decoder_layer in self.layers:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
@@ -328,6 +330,7 @@ def forward(
                 output_attentions=output_attentions,
                 use_cache=use_cache,
                 cache_position=cache_position,
+                rotary_emb=rotary_emb,
                 **kwargs,
             )
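The SwiftKV file is the one place where the rotary module is used outside an attention class: the model `forward` precomputes keys for the skipped layers and rotates only the key half of the (query, key) pair, feeding a scratch tensor into the unused slot. A simplified stand-in for `qeff_apply_rotary_pos_emb` makes the trick concrete:

```python
import torch


def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
    # Simplified: the real helper also gathers cos/sin rows by position_ids.
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin


k = torch.randn(1, 4, 8, 16)                       # [batch, kv_heads, seq, head_dim]
cos, sin = torch.randn(8, 16), torch.randn(8, 16)  # per-position tables

# Only rotated keys are needed; a throwaway tensor fills the query slot and
# its rotated result is discarded via the underscore.
_, k_rot = apply_rotary_pos_emb(torch.empty_like(k), k, cos, sin)
```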
diff --git a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py
index 680c839ae..abe3f4c5a 100644
--- a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py
+++ b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py
@@ -128,9 +128,6 @@ def eager_attention_forward(
 class QEffMixtralAttention(MixtralAttention):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
-    def __qeff_init__(self):
-        self.rotary_emb = QEffMixtralRotaryEmbedding(config=self.config)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -139,6 +136,7 @@ def forward(
         past_key_value: Optional[Cache] = None,
         comp_ctx_lengths: Optional[torch.LongTensor] = None,
         batch_index: Optional[torch.LongTensor] = None,
+        rotary_emb: Optional[object] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
@@ -156,7 +154,7 @@ def forward(
                 "with a layer index."
             )
         kv_seq_len = past_key_value.get_seq_length(self.layer_idx)
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         if past_key_value is not None:
@@ -265,6 +263,7 @@ def forward(
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+        rotary_emb=None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
@@ -301,6 +300,7 @@ def forward(
             batch_index=batch_index,
             use_cache=use_cache,
             cache_position=cache_position,
+            rotary_emb=rotary_emb,
             **kwargs,
         )
         hidden_states = residual + hidden_states
@@ -382,6 +382,8 @@ def forward(
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
 
+        rotary_emb = QEffMixtralRotaryEmbedding(config=self.config)
+
         for decoder_layer in self.layers:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
@@ -397,6 +399,7 @@ def forward(
                 use_cache=use_cache,
                 cache_position=cache_position,
                 position_embeddings=position_embeddings,
+                rotary_emb=rotary_emb,
                 **kwargs,
             )
diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py
index 3cba022b4..de6996882 100644
--- a/QEfficient/transformers/models/mllama/modeling_mllama.py
+++ b/QEfficient/transformers/models/mllama/modeling_mllama.py
@@ -241,9 +241,6 @@ class QEffMllamaTextSelfAttention(MllamaTextSelfAttention):
     - add new args cache idx for the kv retention
     """
 
-    def __qeff_init__(self):
-        self.rotary_emb = QEffMllamaRotaryEmbedding(config=self.config)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -255,6 +252,7 @@ def forward(
         position_embeddings: torch.Tensor = None,
         use_cache: bool = False,
         cache_position=None,
+        rotary_emb: Optional[object] = None,
         **kwargs,
     ):
         bsz, q_len, _ = hidden_states.size()
@@ -276,7 +274,7 @@ def forward(
             )
 
         kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         if past_key_value is not None:
@@ -326,6 +324,7 @@ def forward(
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+        rotary_emb: Optional[object] = None,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
@@ -361,6 +360,7 @@ def forward(
             use_cache=use_cache,
             cache_position=cache_position,
             position_embeddings=position_embeddings,
+            rotary_emb=rotary_emb,
         )
         hidden_states = residual + hidden_states
@@ -465,6 +465,7 @@ def forward(
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[torch.Tensor] = None,
+        rotary_emb: Optional[object] = None,
     ) -> Tuple[torch.Tensor]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
@@ -646,6 +647,7 @@ def forward(
         # embed positions
         hidden_states = inputs_embeds
 
+        rotary_emb = QEffMllamaRotaryEmbedding(config=self.config)
         for idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
             # For text-only path we should skip cross attention layers.
             # Let's check if the layer is cross attention layer and if we have cross attention states
@@ -676,6 +678,7 @@ def forward(
                 comp_ctx_lengths=comp_ctx_lengths,
                 use_cache=use_cache,
                 cache_position=cache_position,
+                rotary_emb=rotary_emb,
             )
 
         hidden_states = self.norm(hidden_states)
diff --git a/QEfficient/transformers/models/olmo2/modeling_olmo2.py b/QEfficient/transformers/models/olmo2/modeling_olmo2.py
index c79ad7fae..889de474e 100644
--- a/QEfficient/transformers/models/olmo2/modeling_olmo2.py
+++ b/QEfficient/transformers/models/olmo2/modeling_olmo2.py
@@ -123,9 +123,6 @@ def eager_attention_forward(
 class QEffOlmo2Attention(Olmo2Attention):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
-    def __qeff_init__(self):
-        self.rotary_emb = QEffOlmo2RotaryEmbedding(config=self.config)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -135,6 +132,7 @@ def forward(
         comp_ctx_lengths: Optional[torch.LongTensor] = None,
         batch_index: Optional[torch.LongTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        rotary_emb: Optional[object] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
@@ -151,7 +149,7 @@ def forward(
         kv_seq_len = key_states.shape[-2]
 
         kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         if past_key_value is not None:
@@ -198,6 +196,7 @@ def forward(
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+        rotary_emb=None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
@@ -213,6 +212,7 @@ def forward(
             use_cache=use_cache,
             cache_position=cache_position,
             position_embeddings=position_embeddings,
+            rotary_emb=rotary_emb,
             **kwargs,
         )
         hidden_states = self.post_attention_layernorm(hidden_states)
@@ -284,6 +284,8 @@ def forward(
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
 
+        rotary_emb = QEffOlmo2RotaryEmbedding(config=self.config)
+
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
@@ -297,6 +299,7 @@ def forward(
                 batch_index=batch_index,
                 use_cache=use_cache,
                 cache_position=cache_position,
+                rotary_emb=rotary_emb,
                 **kwargs,
             )
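Worth noting in the mllama hunks: `QEffMllamaCrossAttentionDecoderLayer.forward` also gains a `rotary_emb` parameter even though cross-attention applies no rotary position embedding. Keeping one call signature across both layer kinds lets the decoding loop invoke every layer uniformly. A hypothetical mini-example of that shape (not the mllama code):

```python
import torch

# Stand-in rotary module returning identity-like (cos, sin) tables.
rotary_emb = lambda x, seq_len: (torch.ones(seq_len, x.shape[-1]), torch.zeros(seq_len, x.shape[-1]))


class SelfAttnLayer:
    def forward(self, hidden_states, rotary_emb=None):
        cos, sin = rotary_emb(hidden_states, seq_len=hidden_states.shape[1])
        return hidden_states * cos  # rotary actually consumed here


class CrossAttnLayer:
    def forward(self, hidden_states, rotary_emb=None):  # accepted but unused
        return hidden_states


hidden_states = torch.randn(1, 4, 8)
for layer in (SelfAttnLayer(), CrossAttnLayer()):
    hidden_states = layer.forward(hidden_states, rotary_emb=rotary_emb)  # one uniform call site
```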
diff --git a/QEfficient/transformers/models/phi3/modeling_phi3.py b/QEfficient/transformers/models/phi3/modeling_phi3.py
index b48ab2897..a470d6e79 100644
--- a/QEfficient/transformers/models/phi3/modeling_phi3.py
+++ b/QEfficient/transformers/models/phi3/modeling_phi3.py
@@ -129,9 +129,6 @@ class QEffPhi3Attention(Phi3Attention):
     - add new args position idx for the cache_kwargs for kv retention
     """
 
-    def __qeff_init__(self):
-        self.rotary_emb = QEffPhi3RotaryEmbedding(config=self.config)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -142,6 +139,7 @@ def forward(
         past_key_value: Optional[Cache] = None,
         comp_ctx_lengths: Optional[torch.LongTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        rotary_emb: Optional[object] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
@@ -158,7 +156,7 @@ def forward(
         value_states = value_states.view(hidden_shape).transpose(1, 2)
 
         kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
 
         query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
@@ -207,6 +205,7 @@ def forward(
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+        rotary_emb=None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
@@ -244,6 +243,7 @@ def forward(
             use_cache=use_cache,
             cache_position=cache_position,
             position_embeddings=position_embeddings,
+            rotary_emb=rotary_emb,
             **kwargs,
         )
 
@@ -311,6 +311,8 @@ def forward(
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
 
+        rotary_emb = QEffPhi3RotaryEmbedding(config=self.config)
+
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
@@ -324,6 +326,7 @@ def forward(
                 comp_ctx_lengths=comp_ctx_lengths,
                 use_cache=use_cache,
                 cache_position=cache_position,
+                rotary_emb=rotary_emb,
                 **kwargs,
             )
diff --git a/QEfficient/transformers/models/qwen2/modeling_qwen2.py b/QEfficient/transformers/models/qwen2/modeling_qwen2.py
index 841df6526..eadc53b3f 100644
--- a/QEfficient/transformers/models/qwen2/modeling_qwen2.py
+++ b/QEfficient/transformers/models/qwen2/modeling_qwen2.py
@@ -141,9 +141,6 @@ class QEffQwen2Attention(Qwen2Attention):
     - add new args position idx for the cache_kwargs for kv retention
     """
 
-    def __qeff_init__(self):
-        self.rotary_emb = QEffQwen2RotaryEmbedding(config=self.config)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -153,6 +150,7 @@ def forward(
         comp_ctx_lengths: Optional[torch.LongTensor] = None,
         batch_index: Optional[torch.LongTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        rotary_emb: Optional[object] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
@@ -163,7 +161,7 @@ def forward(
         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
 
         kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         if past_key_value is not None:
@@ -208,6 +206,7 @@ def forward(
         batch_index: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
+        rotary_emb=None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
@@ -240,6 +239,7 @@ def forward(
             batch_index=batch_index,
             use_cache=use_cache,
             cache_position=cache_position,
+            rotary_emb=rotary_emb,
             **kwargs,
         )
         hidden_states = residual + hidden_states
@@ -311,6 +311,8 @@ def forward(
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
 
+        rotary_emb = QEffQwen2RotaryEmbedding(config=self.config)
+
         for decoder_layer in self.layers:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
@@ -324,6 +326,7 @@ def forward(
                 batch_index=batch_index,
                 use_cache=use_cache,
                 cache_position=cache_position,
+                rotary_emb=rotary_emb,
             )
 
         hidden_states = self.norm(hidden_states)
diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
index d6bfbda81..6e375598b 100644
--- a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
+++ b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
@@ -564,9 +564,6 @@ class QEffQwen2_5_VLAttention(Qwen2_5_VLAttention):
     and "Generating Long Sequences with Sparse Transformers".
     """
 
-    def __qeff_init__(self):
-        self.rotary_emb = QEffQwen2_5_VLRotaryEmbedding(config=self.config)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -579,6 +576,7 @@ def forward(
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
         num_kv_blocks: Optional[torch.Tensor] = None,
+        rotary_emb: Optional[object] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()
@@ -595,7 +593,7 @@ def forward(
         kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
         past_seen_tokens = past_key_value.get_seq_length() if past_key_value is not None else 0
 
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
 
         query_states, key_states = qeff_apply_rotary_pos_emb(
             query_states, key_states, cos, sin, position_ids[1:], self.rope_scaling["mrope_section"]
@@ -661,6 +659,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
+        rotary_emb: Optional[object] = None,
         # position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
@@ -701,6 +700,7 @@ def forward(
             output_attentions=output_attentions,
             use_cache=use_cache,
             cache_position=cache_position,
+            rotary_emb=rotary_emb,
             **kwargs,
         )
         hidden_states = residual + hidden_states
@@ -771,6 +771,8 @@ def forward(
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
 
+        rotary_emb = QEffQwen2_5_VLRotaryEmbedding(config=self.config)
+
         for decoder_layer in self.layers:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
@@ -785,6 +787,7 @@ def forward(
                 output_attentions=output_attentions,
                 use_cache=use_cache,
                 cache_position=cache_position,
+                rotary_emb=rotary_emb,
                 **kwargs,
             )
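Qwen2.5-VL is the multimodal-RoPE case: `position_ids` carries a leading axis of temporal/height/width positions (hence the `position_ids[1:]` slice), and `rope_scaling["mrope_section"]` assigns chunks of the rotary dimensions to each axis. A hedged sketch of the table selection, mirroring the upstream transformers `apply_multimodal_rotary_pos_emb` logic (the QEff helper's exact signature may differ; numbers are illustrative):

```python
import torch

mrope_section = [2, 3, 3]             # per-axis share of head_dim // 2
head_dim = 2 * sum(mrope_section)
cos = torch.randn(3, 1, 5, head_dim)  # [axis (t/h/w), batch, seq, head_dim]

sections = mrope_section * 2          # the split pattern repeats over both rotary halves
merged = torch.cat(
    [chunk[i % 3] for i, chunk in enumerate(cos.split(sections, dim=-1))],
    dim=-1,
)                                     # each dim chunk is taken from the axis that owns it
assert merged.shape == (1, 5, head_dim)
```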
diff --git a/QEfficient/transformers/models/qwen3/modeling_qwen3.py b/QEfficient/transformers/models/qwen3/modeling_qwen3.py
index ccc4bbac2..1531149db 100644
--- a/QEfficient/transformers/models/qwen3/modeling_qwen3.py
+++ b/QEfficient/transformers/models/qwen3/modeling_qwen3.py
@@ -142,9 +142,6 @@ class QEffQwen3Attention(Qwen3Attention):
     - add new args position idx for the cache_kwargs for kv retention
     """
 
-    def __qeff_init__(self):
-        self.rotary_emb = QEffQwen3RotaryEmbedding(config=self.config)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -154,6 +151,7 @@ def forward(
         comp_ctx_lengths: Optional[torch.LongTensor] = None,
         batch_index: Optional[torch.LongTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        rotary_emb: Optional[object] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
@@ -164,7 +162,7 @@ def forward(
         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
 
         kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         if past_key_value is not None:
@@ -209,6 +207,7 @@ def forward(
         batch_index: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
+        rotary_emb=None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
@@ -241,6 +240,7 @@ def forward(
             batch_index=batch_index,
             use_cache=use_cache,
             cache_position=cache_position,
+            rotary_emb=rotary_emb,
             **kwargs,
         )
         hidden_states = residual + hidden_states
@@ -312,6 +312,8 @@ def forward(
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
 
+        rotary_emb = QEffQwen3RotaryEmbedding(config=self.config)
+
         for decoder_layer in self.layers:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
@@ -325,6 +327,7 @@ def forward(
                 batch_index=batch_index,
                 use_cache=use_cache,
                 cache_position=cache_position,
+                rotary_emb=rotary_emb,
             )
 
         hidden_states = self.norm(hidden_states)
diff --git a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py
index 6bdd5e243..f52107115 100644
--- a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py
+++ b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py
@@ -188,9 +188,6 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens
 class QEffQwen3MoeAttention(Qwen3MoeAttention):
 
-    def __qeff_init__(self):
-        self.rotary_emb = QEffQwen3MoeRotaryEmbedding(config=self.config)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -200,6 +197,7 @@ def forward(
         comp_ctx_lengths: Optional[torch.LongTensor] = None,
         batch_index: Optional[torch.LongTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        rotary_emb: Optional[object] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
@@ -210,7 +208,7 @@ def forward(
         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
 
         kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         if past_key_value is not None:
@@ -247,6 +245,7 @@ def forward(
         batch_index: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
+        rotary_emb=None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
@@ -279,6 +278,7 @@ def forward(
             batch_index=batch_index,
             use_cache=use_cache,
             cache_position=cache_position,
+            rotary_emb=rotary_emb,
         )
         hidden_states = residual + hidden_states
@@ -336,6 +336,8 @@ def forward(
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
 
+        rotary_emb = QEffQwen3MoeRotaryEmbedding(config=self.config)
+
         for decoder_layer in self.layers:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
@@ -349,6 +351,7 @@ def forward(
                 batch_index=batch_index,
                 use_cache=use_cache,
                 cache_position=cache_position,
+                rotary_emb=rotary_emb,
             )
 
         hidden_states = self.norm(hidden_states)
diff --git a/tests/transformers/models/test_single_subfunction.py b/tests/transformers/models/test_single_subfunction.py
new file mode 100644
index 000000000..f17edab65
--- /dev/null
+++ b/tests/transformers/models/test_single_subfunction.py
@@ -0,0 +1,94 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+
+import onnx
+import pytest
+import torch
+from transformers import AutoConfig, AutoModelForCausalLM
+
+from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM
+from QEfficient.utils.device_utils import get_available_device_id
+
+torch.manual_seed(42)
+
+configs = [
+    ("gpt2", 256, 2, 4, 128, 512, 127, {}),
+    # ("codegen", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}),
+    ("falcon", 256, 2, 4, 128, 512, 127, {}),
+    ("gptj", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}),
+    ("llama", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
+    ("mistral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
+    # ("mixtral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
+    ("mpt", 256, 2, 4, 128, 512, 127, {}),
+    # ("phi", 256, 2, 4, 128, 512, 127, {}),
+    ("phi3", 256, 2, 4, 128, 512, 127, {"pad_token_id": 0}),
+    ("qwen2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
+    ("qwen3", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
+    # ("starcoder2", 256, 2, 4, 128, 512, 127, {}),
+    ("granite", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
+    ("olmo2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
+    # ("gpt_oss", 256, 3, 4, 128, 512, 127, {"num_key_value_heads": 2}),
+    ("qwen3_moe", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
+    # ("granitemoe", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
+    ("gemma", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
+    ("gemma2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
+]
+
+configs = [
+    AutoConfig.for_model(
+        model_name,
+        max_position_embeddings=max_position_embeddings,
+        num_hidden_layers=num_hidden_layers,
+        num_attention_heads=num_attention_heads,
+        hidden_size=hidden_size,
+        intermediate_size=intermediate_size,
+        vocab_size=vocab_size,
+        **additional_params,
+    )
+    for (
+        model_name,
+        max_position_embeddings,
+        num_hidden_layers,
+        num_attention_heads,
+        hidden_size,
+        intermediate_size,
+        vocab_size,
+        additional_params,
+    ) in configs
+]
+
+model_kwargs = {"attn_implementation": "eager"}
+config_ids = [x.model_type for x in configs]
+
+
+def get_function(onnx_path):
+    """Return the names of all function definitions in the exported ONNX model."""
+    model = onnx.load(onnx_path, load_external_data=False)
+    function_names = [f.name for f in model.functions]
+    return function_names
+
+
+@pytest.mark.on_qaic
+@pytest.mark.feature
+@pytest.mark.parametrize("config", configs, ids=config_ids)
+def test_subfunction_vs_nonsubfunction(config, tmp_path):
+    model_0_0 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(config, **model_kwargs), cb=False)
+    with_sub_func_onnx = model_0_0.export(tmp_path, use_onnx_subfunctions=True, offload_pt_weights=False)
+
+    functions_names = get_function(with_sub_func_onnx)
+    print(functions_names)
+
+    keywords = ["DecoderLayer", "Block", "Layer"]
+    filtered = [name for name in functions_names if any(key in name for key in keywords)]
+
+    if len(filtered) > 1:
+        raise AssertionError(f"Expected at most one decoder-layer function definition, but found {len(filtered)}: {filtered}")
+
+    if not get_available_device_id():
+        pytest.skip("No available devices to run model on Cloud AI 100")
+    compile_params = {"prefill_seq_len": 8, "ctx_len": 16}
+    model_0_0.compile(onnx_path=with_sub_func_onnx, **compile_params, use_onnx_subfunctions=True)
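For reference, one parametrized case of the new test plays out roughly like this when run by hand (the calls mirror the test body; the output directory is arbitrary):

```python
import onnx
from transformers import AutoConfig, AutoModelForCausalLM

from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM

config = AutoConfig.for_model(
    "llama",
    max_position_embeddings=256,
    num_hidden_layers=2,
    num_attention_heads=4,
    hidden_size=128,
    intermediate_size=512,
    vocab_size=127,
    num_key_value_heads=2,
)
qeff_model = QEFFAutoModelForCausalLM(
    AutoModelForCausalLM.from_config(config, attn_implementation="eager"), cb=False
)
onnx_path = qeff_model.export("/tmp/qeff_subfn", use_onnx_subfunctions=True, offload_pt_weights=False)

onnx_model = onnx.load(onnx_path, load_external_data=False)
print([f.name for f in onnx_model.functions])  # expect a single shared decoder-layer function
```

The suite itself can be run with something like `pytest tests/transformers/models/test_single_subfunction.py -k llama`; the `on_qaic`/`feature` marks and the device-id check gate the compile step to hosts with a Cloud AI 100 attached.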