From a374683b0a19b83fcb923e662206c6fa80e43cde Mon Sep 17 00:00:00 2001 From: abhishek-singh591 Date: Thu, 26 Feb 2026 19:32:42 +0000 Subject: [PATCH 01/18] Added changes for single subfunction signature Signed-off-by: abhishek-singh591 --- .../models/falcon/modeling_falcon.py | 12 ++- .../models/gemma/modeling_gemma.py | 11 ++- .../models/gemma2/modeling_gemma2.py | 10 +- .../models/gemma3/modeling_gemma3.py | 39 +++++--- .../transformers/models/gpt2/modeling_gpt2.py | 1 + .../models/gpt_oss/modeling_gpt_oss.py | 23 +++-- .../models/granite/modeling_granite.py | 10 +- .../models/granitemoe/modeling_granitemoe.py | 12 ++- .../models/llama/modeling_llama.py | 11 ++- .../llama_swiftkv/modeling_llama_swiftkv.py | 9 +- .../models/mistral/modeling_mistral.py | 12 ++- .../models/mixtral_moe/modeling_mixtral.py | 11 ++- .../models/olmo2/modeling_olmo2.py | 11 ++- .../transformers/models/phi3/modeling_phi3.py | 11 ++- .../models/qwen2/modeling_qwen2.py | 11 ++- .../models/qwen3/modeling_qwen3.py | 11 ++- .../models/qwen3_moe/modeling_qwen3_moe.py | 10 +- .../models/test_single_subfunction.py | 95 +++++++++++++++++++ 18 files changed, 230 insertions(+), 80 deletions(-) create mode 100644 tests/transformers/models/test_single_subfunction.py diff --git a/QEfficient/transformers/models/falcon/modeling_falcon.py b/QEfficient/transformers/models/falcon/modeling_falcon.py index 4ebb2fb96..a0319a144 100644 --- a/QEfficient/transformers/models/falcon/modeling_falcon.py +++ b/QEfficient/transformers/models/falcon/modeling_falcon.py @@ -108,9 +108,6 @@ class QEffFalconAttention(FalconAttention): - add new args position idx for the cache_kwargs for kv retention """ - def __qeff_init__(self): - self.rotary_emb = QEffFalconRotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -125,6 +122,7 @@ def forward( use_cache: bool = False, output_attentions: bool = False, cache_position: Optional[torch.LongTensor] = None, + rotary_emb: Optional[object] = None, ): fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads @@ -138,7 +136,7 @@ def forward( value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim) kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len) + cos, sin = rotary_emb(value_layer, seq_len=kv_seq_len) query_layer, key_layer = qeff_apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids) if layer_past is not None: @@ -184,6 +182,7 @@ def forward( use_cache: bool = False, output_attentions: bool = False, cache_position: Optional[torch.LongTensor] = None, + rotary_emb=None, **kwargs, ): residual = hidden_states @@ -208,6 +207,7 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, cache_position=cache_position, + rotary_emb=rotary_emb, ) if not self.config.new_decoder_architecture: @@ -304,6 +304,8 @@ def forward( all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None + + rotary_emb = QEffFalconRotaryEmbedding(config=self.config) for i, block in enumerate(self.h): if output_hidden_states: @@ -322,6 +324,7 @@ def forward( output_attentions=output_attentions, alibi=alibi, cache_position=cache_position, + rotary_emb=rotary_emb, ) hidden_states = outputs[0] @@ -338,6 +341,7 @@ def forward( if return_legacy_cache: past_key_values = 
past_key_values.to_legacy_cache() + return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, diff --git a/QEfficient/transformers/models/gemma/modeling_gemma.py b/QEfficient/transformers/models/gemma/modeling_gemma.py index 260d1857a..d1e25c9e2 100644 --- a/QEfficient/transformers/models/gemma/modeling_gemma.py +++ b/QEfficient/transformers/models/gemma/modeling_gemma.py @@ -128,9 +128,6 @@ class QEffGemmaAttention(GemmaAttention): - add new args cache idx for the kv retention """ - def __qeff_init__(self): - self.rotary_emb = QEffGemmaRotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -140,6 +137,7 @@ def forward( comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, + rotary_emb: Optional[object] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] @@ -150,7 +148,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos, sin = rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -194,6 +192,7 @@ def forward( batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + rotary_emb=None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -223,6 +222,7 @@ def forward( batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, + rotary_emb=rotary_emb, **kwargs, ) hidden_states = residual + hidden_states @@ -322,10 +322,13 @@ def forward( if return_legacy_cache: past_key_values = past_key_values.to_legacy_cache() + rotary_emb = QEffGemmaRotaryEmbedding(config=self.config) + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, + rotary_emb=rotary_emb, ) diff --git a/QEfficient/transformers/models/gemma2/modeling_gemma2.py b/QEfficient/transformers/models/gemma2/modeling_gemma2.py index 6dee8c85d..640065587 100644 --- a/QEfficient/transformers/models/gemma2/modeling_gemma2.py +++ b/QEfficient/transformers/models/gemma2/modeling_gemma2.py @@ -135,9 +135,6 @@ class QEffGemma2Attention(Gemma2Attention): - add new args cache idx for the kv retention """ - def __qeff_init__(self): - self.rotary_emb = QEffGemma2RotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -147,6 +144,7 @@ def forward( comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, + rotary_emb: Optional[object] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] @@ -157,7 +155,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos, sin = rotary_emb(value_states, seq_len=kv_seq_len) query_states, 
key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -208,6 +206,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + rotary_emb=None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -241,6 +240,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + rotary_emb=rotary_emb, **kwargs, ) hidden_states = self.post_attention_layernorm(hidden_states) @@ -341,6 +341,7 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None + rotary_emb = QEffGemma2RotaryEmbedding(config=self.config) for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) @@ -355,6 +356,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + rotary_emb=rotary_emb, **kwargs, ) diff --git a/QEfficient/transformers/models/gemma3/modeling_gemma3.py b/QEfficient/transformers/models/gemma3/modeling_gemma3.py index f98bae225..3fb4a8e60 100644 --- a/QEfficient/transformers/models/gemma3/modeling_gemma3.py +++ b/QEfficient/transformers/models/gemma3/modeling_gemma3.py @@ -188,12 +188,12 @@ def __init__(self, config: Gemma3Config, layer_idx: Optional[int] = None): self.__qeff_init__() def __qeff_init__(self): - self.rotary_emb = QEffGemma3RotaryEmbedding( - self.head_dim, - self.config, - max_position_embeddings=self.config.max_position_embeddings, - base=self.config.rope_theta, - ) + # self.rotary_emb = QEffGemma3RotaryEmbedding( + # self.head_dim, + # self.config, + # max_position_embeddings=self.config.max_position_embeddings, + # base=self.config.rope_theta, + # ) config = copy.deepcopy(self.config) config.rope_theta = config.rope_local_base_freq @@ -201,12 +201,12 @@ def __qeff_init__(self): self.is_local = _is_local(self.layer_idx, self.config.sliding_window_pattern) self.window = self.config.sliding_window if self.is_local else None - self.rotary_emb_local = QEffGemma3RotaryEmbedding( - self.head_dim, - config, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - ) + # self.rotary_emb_local = QEffGemma3RotaryEmbedding( + # self.head_dim, + # config, + # max_position_embeddings=config.max_position_embeddings, + # base=config.rope_theta, + # ) def forward( self, @@ -218,6 +218,7 @@ def forward( comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, + rotary_emb: Optional[object] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -240,9 +241,9 @@ def forward( "with a layer index." 
) if self.is_sliding: - cos, sin = self.rotary_emb_local(value_states, seq_len=self.config.max_position_embeddings) + cos, sin = rotary_emb_local(value_states, seq_len=self.config.max_position_embeddings) else: - cos, sin = self.rotary_emb(value_states, seq_len=self.config.max_position_embeddings) + cos, sin = rotary_emb(value_states, seq_len=self.config.max_position_embeddings) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -308,6 +309,7 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, last_cache_position: int = 0, + rotary_emb=None, **kwargs, ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states @@ -336,6 +338,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + rotary_emb=rotary_emb, **kwargs, ) @@ -432,6 +435,13 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None + + rotary_emb = QEffGemma3RotaryEmbedding( + self.head_dim, + self.config, + max_position_embeddings=self.config.max_position_embeddings, + base=self.config.rope_theta, + ) for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) @@ -447,6 +457,7 @@ def forward( use_cache=use_cache, cache_position=cache_position, last_cache_position=last_cache_position, + rotary_emb=rotary_emb, **flash_attn_kwargs, ) diff --git a/QEfficient/transformers/models/gpt2/modeling_gpt2.py b/QEfficient/transformers/models/gpt2/modeling_gpt2.py index 7de674cce..74f1894ce 100644 --- a/QEfficient/transformers/models/gpt2/modeling_gpt2.py +++ b/QEfficient/transformers/models/gpt2/modeling_gpt2.py @@ -329,6 +329,7 @@ def forward( # Fix position_ids to not use -1 position_embeds = self.wpe(torch.where(position_ids == -1, 0, position_ids)) hidden_states = inputs_embeds + position_embeds + if token_type_ids is not None: token_type_embeds = self.wte(token_type_ids) diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index 96ea8055c..758271a25 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -737,8 +737,6 @@ def opt_eager_attention_forward_blocked( class QEffPrefillOnlyChunkedGptOssAttention(GptOssAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __qeff_init__(self): - self.rotary_emb = QEffGptOssRotaryEmbedding(config=self.config) def forward( self, @@ -750,6 +748,7 @@ def forward( batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, sliding_mask=None, + rotary_emb: Optional[object] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor]: input_shape = hidden_states.shape[:-1] @@ -760,7 +759,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) if not (max_seq_len_cached := getattr(self.config, "max_seq_len_cached")): max_seq_len_cached = 32 * 1024 - cos, sin = self.rotary_emb(value_states, seq_len=max_seq_len_cached) + cos, sin = rotary_emb(value_states, seq_len=max_seq_len_cached) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -819,9 +818,6 
@@ def forward( class QEffPrefillOnlyGptOssAttention(GptOssAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __qeff_init__(self): - self.rotary_emb = QEffGptOssRotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -832,6 +828,7 @@ def forward( batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, sliding_mask=None, + rotary_emb: Optional[object] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor]: input_shape = hidden_states.shape[:-1] @@ -842,7 +839,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) if not (max_seq_len_cached := getattr(self.config, "max_seq_len_cached")): max_seq_len_cached = 32 * 1024 - cos, sin = self.rotary_emb(value_states, seq_len=max_seq_len_cached) + cos, sin = rotary_emb(value_states, seq_len=max_seq_len_cached) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -900,9 +897,6 @@ def forward( class QEffGptOssAttention(GptOssAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __qeff_init__(self): - self.rotary_emb = QEffGptOssRotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -914,6 +908,7 @@ def forward( batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, sliding_mask=None, + rotary_emb: Optional[object] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor]: input_shape = hidden_states.shape[:-1] @@ -924,7 +919,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) if not (max_seq_len_cached := getattr(self.config, "max_seq_len_cached")): max_seq_len_cached = 32 * 1024 - cos, sin = self.rotary_emb(value_states, seq_len=max_seq_len_cached) + cos, sin = rotary_emb(value_states, seq_len=max_seq_len_cached) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -981,6 +976,7 @@ def forward( cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC sliding_mask=None, + rotary_emb=None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor]: residual = hidden_states @@ -997,6 +993,7 @@ def forward( cache_position=cache_position, position_embeddings=position_embeddings, sliding_mask=sliding_mask, + rotary_emb=rotary_emb, **kwargs, ) hidden_states = residual + hidden_states @@ -1166,7 +1163,8 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - + + rotary_emb = QEffGptOssRotaryEmbedding(config=self.config) for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) @@ -1182,6 +1180,7 @@ def forward( output_attentions=output_attentions, cache_position=cache_position, sliding_mask=sliding_mask, + rotary_emb=rotary_emb, **kwargs, ) hidden_states = layer_outputs[0] diff --git a/QEfficient/transformers/models/granite/modeling_granite.py b/QEfficient/transformers/models/granite/modeling_granite.py index 8a32c52ef..5832de1f0 100644 --- a/QEfficient/transformers/models/granite/modeling_granite.py +++ b/QEfficient/transformers/models/granite/modeling_granite.py @@ -121,8 +121,6 @@ def eager_attention_forward( 
class QEffGraniteAttention(GraniteAttention): - def __qeff_init__(self): - self.rotary_emb = QEffGraniteRotaryEmbedding(config=self.config) def forward( self, @@ -133,6 +131,7 @@ def forward( comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, + rotary_emb: Optional[object] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] @@ -143,7 +142,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos, sin = rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -192,6 +191,7 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + rotary_emb=None, **kwargs, ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -230,6 +230,7 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + rotary_emb=rotary_emb, **kwargs, ) hidden_states = residual + hidden_states * self.residual_multiplier @@ -301,6 +302,8 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None + + rotary_emb = QEffGraniteRotaryEmbedding(config=self.config) for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: @@ -316,6 +319,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + rotary_emb=rotary_emb, **kwargs, ) diff --git a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py index 2fa7305c0..6c0fad39e 100644 --- a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py +++ b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py @@ -111,9 +111,6 @@ def qeff_apply_rotary_pos_emb( class QEffGraniteMoeAttention(GraniteMoeAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __qeff_init__(self): - self.rotary_emb = QEffGraniteMoeRotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -126,6 +123,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + rotary_emb: Optional[object] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -138,7 +136,7 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos, sin = rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -214,6 +212,7 @@ def forward( cache_position: Optional[torch.LongTensor] = None, 
output_router_logits: Optional[bool] = False, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + rotary_emb=None, **kwargs, ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -255,6 +254,7 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + rotary_emb=rotary_emb, **kwargs, ) @@ -340,6 +340,8 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None + + rotary_emb = QEffGraniteMoeRotaryEmbedding(config=self.config) for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: @@ -356,6 +358,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + rotary_emb=rotary_emb, ) else: layer_outputs = decoder_layer( @@ -368,6 +371,7 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + rotary_emb=rotary_emb, ) hidden_states = layer_outputs[0] diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py index 57bccdb1b..11f72a435 100644 --- a/QEfficient/transformers/models/llama/modeling_llama.py +++ b/QEfficient/transformers/models/llama/modeling_llama.py @@ -198,9 +198,6 @@ def eager_attention_forward_blockedKV( class QEffLlamaAttention(LlamaAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __qeff_init__(self): - self.rotary_emb = QEffLlamaRotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -212,6 +209,7 @@ def forward( use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, num_kv_blocks: Optional[torch.Tensor] = None, + rotary_emb: Optional[object] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] @@ -228,7 +226,7 @@ def forward( kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) past_seen_tokens = past_key_value.get_seq_length() if past_key_value is not None else 0 - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos, sin = rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -287,6 +285,7 @@ def forward( batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + rotary_emb=None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states @@ -303,6 +302,7 @@ def forward( batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, + rotary_emb=rotary_emb, **kwargs, ) hidden_states = residual + hidden_states @@ -367,6 +367,8 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None + rotary_emb = QEffLlamaRotaryEmbedding(config=self.config) + for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) @@ -380,6 +382,7 @@ def forward( batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, + rotary_emb=rotary_emb, **kwargs, ) diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py 
b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py index e219d5e03..b2301ef76 100644 --- a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py +++ b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py @@ -82,8 +82,6 @@ def __init__(self, config: QEffLlamaSwiftKVConfig, layer_idx) -> None: ) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) - self.rotary_emb = QEffLlamaRotaryEmbedding(config=config) - def forward( self, hidden_states: torch.Tensor, @@ -92,6 +90,7 @@ def forward( comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: torch.Tensor = None, batch_index: Optional[torch.LongTensor] = None, + rotary_emb: Optional[object] = None, ) -> torch.Tensor: bsz, q_len, _ = hidden_states.size() q_len = 1 # as we always run this for single token @@ -113,7 +112,7 @@ def forward( kv_seq_len = past_key_value.get_seq_length(self.layer_idx) key_states, value_states = past_key_value.read_only(self.layer_idx, cache_kwargs=cache_kwargs) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos, sin = rotary_emb(value_states, seq_len=kv_seq_len) position_ids = position_ids[torch.arange(bsz), position_ids.to(torch.int32).argmax(1)].unsqueeze(1) query_states, _ = qeff_apply_rotary_pos_emb( query_states, torch.empty_like(query_states), cos, sin, position_ids @@ -162,6 +161,7 @@ def forward( comp_ctx_lengths, causal_mask, batch_index: Optional[torch.LongTensor] = None, + rotary_emb=None, ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention residual = hidden_states @@ -174,6 +174,7 @@ def forward( comp_ctx_lengths=comp_ctx_lengths, attention_mask=causal_mask, batch_index=batch_index, + rotary_emb=rotary_emb, ) hidden_states = residual + hidden_states @@ -335,6 +336,7 @@ def forward( hidden_states = inputs_embeds next_decoder_cache = None + rotary_emb = QEffLlamaRotaryEmbedding(config=config) for layer_idx in range(self.config.num_key_value_layers): layer = self.layers[layer_idx] @@ -347,6 +349,7 @@ def forward( batch_index=batch_index, output_attentions=False, use_cache=True, + rotary_emb=rotary_emb, ) bsz, q_len, _ = hidden_states.size() diff --git a/QEfficient/transformers/models/mistral/modeling_mistral.py b/QEfficient/transformers/models/mistral/modeling_mistral.py index 47107384e..17e66c414 100644 --- a/QEfficient/transformers/models/mistral/modeling_mistral.py +++ b/QEfficient/transformers/models/mistral/modeling_mistral.py @@ -131,9 +131,6 @@ class QEffMistralAttention(MistralAttention): - add new args cache idx for the kv retention """ - def __qeff_init__(self): - self.rotary_emb = QEffMistralRotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -146,6 +143,7 @@ def forward( use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: tuple[torch.Tensor, torch.Tensor] = None, # kept here for BC + rotary_emb: Optional[object] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] @@ -160,7 +158,7 @@ def forward( value_states = value_states.view(hidden_shape).transpose(1, 2) kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos, sin = rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -205,6 
+203,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + rotary_emb=None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -236,6 +235,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + rotary_emb=rotary_emb, **kwargs, ) hidden_states = residual + hidden_states @@ -313,7 +313,8 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - + + rotary_emb = QEffMistralRotaryEmbedding(config=self.config) for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) @@ -328,6 +329,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + rotary_emb=rotary_emb, **kwargs, ) diff --git a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py index 680c839ae..9eed10c82 100644 --- a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py +++ b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py @@ -128,9 +128,6 @@ def eager_attention_forward( class QEffMixtralAttention(MixtralAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __qeff_init__(self): - self.rotary_emb = QEffMixtralRotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -139,6 +136,7 @@ def forward( past_key_value: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, + rotary_emb: Optional[object] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] @@ -156,7 +154,7 @@ def forward( "with a layer index." 
) kv_seq_len = past_key_value.get_seq_length(self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos, sin = rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -265,6 +263,7 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + rotary_emb=None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -301,6 +300,7 @@ def forward( batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, + rotary_emb=rotary_emb, **kwargs, ) hidden_states = residual + hidden_states @@ -381,6 +381,8 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None + + rotary_emb = QEffMixtralRotaryEmbedding(config=self.config) for decoder_layer in self.layers: if output_hidden_states: @@ -397,6 +399,7 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + rotary_emb=rotary_emb, **kwargs, ) diff --git a/QEfficient/transformers/models/olmo2/modeling_olmo2.py b/QEfficient/transformers/models/olmo2/modeling_olmo2.py index c79ad7fae..ce8b0fb03 100644 --- a/QEfficient/transformers/models/olmo2/modeling_olmo2.py +++ b/QEfficient/transformers/models/olmo2/modeling_olmo2.py @@ -123,9 +123,6 @@ def eager_attention_forward( class QEffOlmo2Attention(Olmo2Attention): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __qeff_init__(self): - self.rotary_emb = QEffOlmo2RotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -135,6 +132,7 @@ def forward( comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, + rotary_emb: Optional[object] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] @@ -151,7 +149,7 @@ def forward( kv_seq_len = key_states.shape[-2] kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos, sin = rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -198,6 +196,7 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + rotary_emb=None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states @@ -213,6 +212,7 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + rotary_emb=rotary_emb, **kwargs, ) hidden_states = self.post_attention_layernorm(hidden_states) @@ -283,6 +283,8 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None + + rotary_emb = QEffOlmo2RotaryEmbedding(config=self.config) for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: @@ -297,6 +299,7 @@ def forward( batch_index=batch_index, 
use_cache=use_cache, cache_position=cache_position, + rotary_emb=rotary_emb, **kwargs, ) diff --git a/QEfficient/transformers/models/phi3/modeling_phi3.py b/QEfficient/transformers/models/phi3/modeling_phi3.py index b48ab2897..3566782aa 100644 --- a/QEfficient/transformers/models/phi3/modeling_phi3.py +++ b/QEfficient/transformers/models/phi3/modeling_phi3.py @@ -129,9 +129,6 @@ class QEffPhi3Attention(Phi3Attention): - add new args position idx for the cache_kwargs for kv retention """ - def __qeff_init__(self): - self.rotary_emb = QEffPhi3RotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -142,6 +139,7 @@ def forward( past_key_value: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, + rotary_emb: Optional[object] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] @@ -158,7 +156,7 @@ def forward( value_states = value_states.view(hidden_shape).transpose(1, 2) kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos, sin = rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) @@ -207,6 +205,7 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + rotary_emb=None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -244,6 +243,7 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + rotary_emb=rotary_emb, **kwargs, ) @@ -310,6 +310,8 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None + + rotary_emb = QEffPhi3RotaryEmbedding(config=self.config) for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: @@ -324,6 +326,7 @@ def forward( comp_ctx_lengths=comp_ctx_lengths, use_cache=use_cache, cache_position=cache_position, + rotary_emb=rotary_emb, **kwargs, ) diff --git a/QEfficient/transformers/models/qwen2/modeling_qwen2.py b/QEfficient/transformers/models/qwen2/modeling_qwen2.py index 841df6526..8eeeea936 100644 --- a/QEfficient/transformers/models/qwen2/modeling_qwen2.py +++ b/QEfficient/transformers/models/qwen2/modeling_qwen2.py @@ -141,9 +141,6 @@ class QEffQwen2Attention(Qwen2Attention): - add new args position idx for the cache_kwargs for kv retention """ - def __qeff_init__(self): - self.rotary_emb = QEffQwen2RotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -153,6 +150,7 @@ def forward( comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, + rotary_emb: Optional[object] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] @@ -163,7 +161,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos, sin = rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = 
qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -208,6 +206,7 @@ def forward( batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + rotary_emb=None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -240,6 +239,7 @@ def forward( batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, + rotary_emb=rotary_emb, **kwargs, ) hidden_states = residual + hidden_states @@ -310,6 +310,8 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None + + rotary_emb = QEffQwen2RotaryEmbedding(config=self.config) for decoder_layer in self.layers: if output_hidden_states: @@ -324,6 +326,7 @@ def forward( batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, + rotary_emb=rotary_emb, ) hidden_states = self.norm(hidden_states) diff --git a/QEfficient/transformers/models/qwen3/modeling_qwen3.py b/QEfficient/transformers/models/qwen3/modeling_qwen3.py index ccc4bbac2..636b0b2b3 100644 --- a/QEfficient/transformers/models/qwen3/modeling_qwen3.py +++ b/QEfficient/transformers/models/qwen3/modeling_qwen3.py @@ -142,9 +142,6 @@ class QEffQwen3Attention(Qwen3Attention): - add new args position idx for the cache_kwargs for kv retention """ - def __qeff_init__(self): - self.rotary_emb = QEffQwen3RotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -154,6 +151,7 @@ def forward( comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, + rotary_emb: Optional[object] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] @@ -164,7 +162,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos, sin = rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -209,6 +207,7 @@ def forward( batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + rotary_emb=None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -241,6 +240,7 @@ def forward( batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, + rotary_emb=rotary_emb, **kwargs, ) hidden_states = residual + hidden_states @@ -311,6 +311,8 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None + + rotary_emb = QEffQwen3RotaryEmbedding(config=self.config) for decoder_layer in self.layers: if output_hidden_states: @@ -325,6 +327,7 @@ def forward( batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, + rotary_emb=rotary_emb, ) hidden_states = self.norm(hidden_states) diff --git a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py index d44668c56..6e3cfc0a6 100644 --- a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -192,8 +192,6 @@ def 
forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens class QEffQwen3MoeAttention(Qwen3MoeAttention): - def __qeff_init__(self): - self.rotary_emb = QEffQwen3MoeRotaryEmbedding(config=self.config) def forward( self, @@ -204,6 +202,7 @@ def forward( comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, + rotary_emb: Optional[object] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] @@ -214,7 +213,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos, sin = rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -251,6 +250,7 @@ def forward( batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + rotary_emb=None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -283,6 +283,7 @@ def forward( batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, + rotary_emb=rotary_emb, ) hidden_states = residual + hidden_states @@ -340,6 +341,8 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None + rotary_emb = QEffQwen3MoeRotaryEmbedding(config=self.config) + for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) @@ -353,6 +356,7 @@ def forward( batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, + rotary_emb=rotary_emb, ) hidden_states = self.norm(hidden_states) diff --git a/tests/transformers/models/test_single_subfunction.py b/tests/transformers/models/test_single_subfunction.py new file mode 100644 index 000000000..527c54b28 --- /dev/null +++ b/tests/transformers/models/test_single_subfunction.py @@ -0,0 +1,95 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- +from collections import Counter + +import onnx +import pytest +import torch +from transformers import AutoConfig, AutoModelForCausalLM + +from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM +from QEfficient.utils.device_utils import get_available_device_id + +torch.manual_seed(42) + +configs = [ + ("gpt2", 256, 2, 4, 128, 512, 127, {}), + # ("codegen", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}), + # ("falcon", 256, 2, 4, 128, 512, 127, {}), + # ("gptj", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}), + # ("llama", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # ("mistral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # ("mixtral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # ("mpt", 256, 2, 4, 128, 512, 127, {}), + # ("phi", 256, 2, 4, 128, 512, 127, {}), + # ("phi3", 256, 2, 4, 128, 512, 127, {"pad_token_id": 0}), + # ("qwen2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # ("qwen3", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # ("starcoder2", 256, 2, 4, 128, 512, 127, {}), + # ("granite", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # ("olmo2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # ("gpt_oss", 256, 3, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # ("qwen3_moe", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # ("granitemoe", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), +] + +configs = [ + AutoConfig.for_model( + model_name, + max_position_embeddings=max_position_embeddings, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + vocab_size=vocab_size, + **additional_params, + ) + for ( + model_name, + max_position_embeddings, + num_hidden_layers, + num_attention_heads, + hidden_size, + intermediate_size, + vocab_size, + additional_params, + ) in configs +] + +model_kwargs = {"attn_implementation": "eager"} +config_ids = [x.model_type for x in configs] + + +def get_function(onnx_path): + """Check if ONNX model contains QEffGPT2Block function definition.""" + model = onnx.load(onnx_path, load_external_data=False) + function_names = [f.name for f in model.functions] + return function_names + +@pytest.mark.on_qaic +@pytest.mark.feature +@pytest.mark.parametrize("config", configs, ids=config_ids) +def test_subfunction_vs_nonsubfunction(config, tmp_path): + # tokenizer = AutoTokenizer.from_pretrained(config.model_type) + model_0_0 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(config, **model_kwargs), cb=False) + tmp_path = "/home/abhishek/.cache/qeff_models/temp_onnx" + # Export with subfunctions enabled + with_sub_func_onnx = model_0_0.export(tmp_path, use_onnx_subfunctions=True, offload_pt_weights=False) + + print(f"{config.model_type} is going on...") + # Verify that the model with subfunctions has QEffGPT2Block function definition + functions_names = get_function(with_sub_func_onnx) + import pdb; pdb.set_trace() + if len(functions_names) != 12: + raise AssertionError( + f"function definition, but found {len(functions_names)} functions: {functions_names}" + ) + + if not get_available_device_id(): + pytest.skip("No available devices to run model on Cloud AI 100") + compile_params = {"prefill_seq_len": 8, "ctx_len": 16} + model_0_0.compile(onnx_path=with_sub_func_onnx, **compile_params, use_onnx_subfunctions=True) \ 
No newline at end of file From 795c21e690dd023b29bc396fe1b4f5418c83275e Mon Sep 17 00:00:00 2001 From: abhishek-singh591 Date: Thu, 26 Feb 2026 19:33:03 +0000 Subject: [PATCH 02/18] Added changes for single subfunction signature Signed-off-by: abhishek-singh591 --- tests/transformers/models/test_single_subfunction.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/transformers/models/test_single_subfunction.py b/tests/transformers/models/test_single_subfunction.py index 527c54b28..54431c8b7 100644 --- a/tests/transformers/models/test_single_subfunction.py +++ b/tests/transformers/models/test_single_subfunction.py @@ -83,7 +83,6 @@ def test_subfunction_vs_nonsubfunction(config, tmp_path): print(f"{config.model_type} is going on...") # Verify that the model with subfunctions has QEffGPT2Block function definition functions_names = get_function(with_sub_func_onnx) - import pdb; pdb.set_trace() if len(functions_names) != 12: raise AssertionError( f"function definition, but found {len(functions_names)} functions: {functions_names}" From ce1fe98c22db53863fa7e4068475e70dfdb3266a Mon Sep 17 00:00:00 2001 From: Abhishek Kumar Singh Date: Thu, 26 Feb 2026 19:44:29 +0000 Subject: [PATCH 03/18] Fixed lint error Signed-off-by: Abhishek Kumar Singh --- .../models/falcon/modeling_falcon.py | 3 +- .../models/gemma3/modeling_gemma3.py | 39 +++++++++---------- .../transformers/models/gpt2/modeling_gpt2.py | 1 - .../models/gpt_oss/modeling_gpt_oss.py | 3 +- .../models/granite/modeling_granite.py | 3 +- .../models/granitemoe/modeling_granitemoe.py | 2 +- .../llama_swiftkv/modeling_llama_swiftkv.py | 2 +- .../models/mistral/modeling_mistral.py | 2 +- .../models/mixtral_moe/modeling_mixtral.py | 2 +- .../models/olmo2/modeling_olmo2.py | 2 +- .../transformers/models/phi3/modeling_phi3.py | 2 +- .../models/qwen2/modeling_qwen2.py | 2 +- .../models/qwen3/modeling_qwen3.py | 2 +- .../models/qwen3_moe/modeling_qwen3_moe.py | 3 +- .../models/test_single_subfunction.py | 10 ++--- 15 files changed, 34 insertions(+), 44 deletions(-) diff --git a/QEfficient/transformers/models/falcon/modeling_falcon.py b/QEfficient/transformers/models/falcon/modeling_falcon.py index a0319a144..9647d5afe 100644 --- a/QEfficient/transformers/models/falcon/modeling_falcon.py +++ b/QEfficient/transformers/models/falcon/modeling_falcon.py @@ -304,7 +304,7 @@ def forward( all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None - + rotary_emb = QEffFalconRotaryEmbedding(config=self.config) for i, block in enumerate(self.h): @@ -341,7 +341,6 @@ def forward( if return_legacy_cache: past_key_values = past_key_values.to_legacy_cache() - return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, diff --git a/QEfficient/transformers/models/gemma3/modeling_gemma3.py b/QEfficient/transformers/models/gemma3/modeling_gemma3.py index 3fb4a8e60..e9c19d45d 100644 --- a/QEfficient/transformers/models/gemma3/modeling_gemma3.py +++ b/QEfficient/transformers/models/gemma3/modeling_gemma3.py @@ -187,27 +187,6 @@ def __init__(self, config: Gemma3Config, layer_idx: Optional[int] = None): # Set the init in the module mapping pytorch transforms self.__qeff_init__() - def __qeff_init__(self): - # self.rotary_emb = QEffGemma3RotaryEmbedding( - # self.head_dim, - # self.config, - # max_position_embeddings=self.config.max_position_embeddings, - # base=self.config.rope_theta, - # ) - - config = copy.deepcopy(self.config) - 
config.rope_theta = config.rope_local_base_freq - config.rope_scaling = {"rope_type": "default", "factor": 1.0} - self.is_local = _is_local(self.layer_idx, self.config.sliding_window_pattern) - self.window = self.config.sliding_window if self.is_local else None - - # self.rotary_emb_local = QEffGemma3RotaryEmbedding( - # self.head_dim, - # config, - # max_position_embeddings=config.max_position_embeddings, - # base=config.rope_theta, - # ) - def forward( self, hidden_states: torch.Tensor, @@ -219,6 +198,7 @@ def forward( batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, rotary_emb: Optional[object] = None, + rotary_emb_local: Optional[object] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -310,6 +290,7 @@ def forward( cache_position: Optional[torch.LongTensor] = None, last_cache_position: int = 0, rotary_emb=None, + rotary_emb_local=None, **kwargs, ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states @@ -339,6 +320,7 @@ def forward( use_cache=use_cache, cache_position=cache_position, rotary_emb=rotary_emb, + rotary_emb_local=rotary_emb_local, **kwargs, ) @@ -442,6 +424,20 @@ def forward( max_position_embeddings=self.config.max_position_embeddings, base=self.config.rope_theta, ) + + config = copy.deepcopy(self.config) + config.rope_theta = config.rope_local_base_freq + config.rope_scaling = {"rope_type": "default", "factor": 1.0} + self.is_local = _is_local(self.layer_idx, self.config.sliding_window_pattern) + self.window = self.config.sliding_window if self.is_local else None + + rotary_emb_local = QEffGemma3RotaryEmbedding( + self.head_dim, + config, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + ) + for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) @@ -458,6 +454,7 @@ def forward( cache_position=cache_position, last_cache_position=last_cache_position, rotary_emb=rotary_emb, + rotary_emb_local=rotary_emb_local, **flash_attn_kwargs, ) diff --git a/QEfficient/transformers/models/gpt2/modeling_gpt2.py b/QEfficient/transformers/models/gpt2/modeling_gpt2.py index 74f1894ce..7de674cce 100644 --- a/QEfficient/transformers/models/gpt2/modeling_gpt2.py +++ b/QEfficient/transformers/models/gpt2/modeling_gpt2.py @@ -329,7 +329,6 @@ def forward( # Fix position_ids to not use -1 position_embeds = self.wpe(torch.where(position_ids == -1, 0, position_ids)) hidden_states = inputs_embeds + position_embeds - if token_type_ids is not None: token_type_embeds = self.wte(token_type_ids) diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index 6439b9221..efaffd1a9 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -737,7 +737,6 @@ def opt_eager_attention_forward_blocked( class QEffPrefillOnlyChunkedGptOssAttention(GptOssAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" - def forward( self, hidden_states: torch.Tensor, @@ -1168,7 +1167,7 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - + rotary_emb = QEffGptOssRotaryEmbedding(config=self.config) for decoder_layer in self.layers: if output_hidden_states: diff 
--git a/QEfficient/transformers/models/granite/modeling_granite.py b/QEfficient/transformers/models/granite/modeling_granite.py index 5832de1f0..bd042f145 100644 --- a/QEfficient/transformers/models/granite/modeling_granite.py +++ b/QEfficient/transformers/models/granite/modeling_granite.py @@ -121,7 +121,6 @@ def eager_attention_forward( class QEffGraniteAttention(GraniteAttention): - def forward( self, hidden_states: torch.Tensor, @@ -302,7 +301,7 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - + rotary_emb = QEffGraniteRotaryEmbedding(config=self.config) for decoder_layer in self.layers[: self.config.num_hidden_layers]: diff --git a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py index c12fa980c..cc1e4449d 100644 --- a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py +++ b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py @@ -340,7 +340,7 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - + rotary_emb = QEffGraniteMoeRotaryEmbedding(config=self.config) for decoder_layer in self.layers[: self.config.num_hidden_layers]: diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py index b2301ef76..f6df495a0 100644 --- a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py +++ b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py @@ -336,7 +336,7 @@ def forward( hidden_states = inputs_embeds next_decoder_cache = None - rotary_emb = QEffLlamaRotaryEmbedding(config=config) + rotary_emb = QEffLlamaRotaryEmbedding(config=self.config) for layer_idx in range(self.config.num_key_value_layers): layer = self.layers[layer_idx] diff --git a/QEfficient/transformers/models/mistral/modeling_mistral.py b/QEfficient/transformers/models/mistral/modeling_mistral.py index 17e66c414..d8e48fbf7 100644 --- a/QEfficient/transformers/models/mistral/modeling_mistral.py +++ b/QEfficient/transformers/models/mistral/modeling_mistral.py @@ -313,7 +313,7 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - + rotary_emb = QEffMistralRotaryEmbedding(config=self.config) for decoder_layer in self.layers: if output_hidden_states: diff --git a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py index 9eed10c82..abe3f4c5a 100644 --- a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py +++ b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py @@ -381,7 +381,7 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None - + rotary_emb = QEffMixtralRotaryEmbedding(config=self.config) for decoder_layer in self.layers: diff --git a/QEfficient/transformers/models/olmo2/modeling_olmo2.py b/QEfficient/transformers/models/olmo2/modeling_olmo2.py index ce8b0fb03..889de474e 100644 --- a/QEfficient/transformers/models/olmo2/modeling_olmo2.py +++ b/QEfficient/transformers/models/olmo2/modeling_olmo2.py @@ -283,7 +283,7 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - + rotary_emb = 
QEffOlmo2RotaryEmbedding(config=self.config) for decoder_layer in self.layers[: self.config.num_hidden_layers]: diff --git a/QEfficient/transformers/models/phi3/modeling_phi3.py b/QEfficient/transformers/models/phi3/modeling_phi3.py index 3566782aa..a470d6e79 100644 --- a/QEfficient/transformers/models/phi3/modeling_phi3.py +++ b/QEfficient/transformers/models/phi3/modeling_phi3.py @@ -310,7 +310,7 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None - + rotary_emb = QEffPhi3RotaryEmbedding(config=self.config) for decoder_layer in self.layers[: self.config.num_hidden_layers]: diff --git a/QEfficient/transformers/models/qwen2/modeling_qwen2.py b/QEfficient/transformers/models/qwen2/modeling_qwen2.py index 8eeeea936..eadc53b3f 100644 --- a/QEfficient/transformers/models/qwen2/modeling_qwen2.py +++ b/QEfficient/transformers/models/qwen2/modeling_qwen2.py @@ -310,7 +310,7 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None - + rotary_emb = QEffQwen2RotaryEmbedding(config=self.config) for decoder_layer in self.layers: diff --git a/QEfficient/transformers/models/qwen3/modeling_qwen3.py b/QEfficient/transformers/models/qwen3/modeling_qwen3.py index 636b0b2b3..1531149db 100644 --- a/QEfficient/transformers/models/qwen3/modeling_qwen3.py +++ b/QEfficient/transformers/models/qwen3/modeling_qwen3.py @@ -311,7 +311,7 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None - + rotary_emb = QEffQwen3RotaryEmbedding(config=self.config) for decoder_layer in self.layers: diff --git a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 6e3cfc0a6..3c2ca0f2f 100644 --- a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -192,7 +192,6 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens class QEffQwen3MoeAttention(Qwen3MoeAttention): - def forward( self, hidden_states: torch.Tensor, @@ -283,7 +282,7 @@ def forward( batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, - rotary_emb=rotary_emb, + rotary_emb=rotary_emb, ) hidden_states = residual + hidden_states diff --git a/tests/transformers/models/test_single_subfunction.py b/tests/transformers/models/test_single_subfunction.py index 54431c8b7..66eee49a7 100644 --- a/tests/transformers/models/test_single_subfunction.py +++ b/tests/transformers/models/test_single_subfunction.py @@ -4,7 +4,6 @@ # SPDX-License-Identifier: BSD-3-Clause # # ---------------------------------------------------------------------------- -from collections import Counter import onnx import pytest @@ -70,6 +69,7 @@ def get_function(onnx_path): function_names = [f.name for f in model.functions] return function_names + @pytest.mark.on_qaic @pytest.mark.feature @pytest.mark.parametrize("config", configs, ids=config_ids) @@ -84,11 +84,9 @@ def test_subfunction_vs_nonsubfunction(config, tmp_path): # Verify that the model with subfunctions has QEffGPT2Block function definition functions_names = get_function(with_sub_func_onnx) if len(functions_names) != 12: - raise AssertionError( - f"function definition, but found {len(functions_names)} functions: {functions_names}" - ) - + raise AssertionError(f"function definition, but found {len(functions_names)} functions: {functions_names}") + if not get_available_device_id(): pytest.skip("No available devices to run model on Cloud 
AI 100") compile_params = {"prefill_seq_len": 8, "ctx_len": 16} - model_0_0.compile(onnx_path=with_sub_func_onnx, **compile_params, use_onnx_subfunctions=True) \ No newline at end of file + model_0_0.compile(onnx_path=with_sub_func_onnx, **compile_params, use_onnx_subfunctions=True) From 1f74ad59c096a58d715e1f468e15352c6650ce50 Mon Sep 17 00:00:00 2001 From: Abhishek Kumar Singh Date: Fri, 27 Feb 2026 07:49:20 +0000 Subject: [PATCH 04/18] Added VLMs changes Signed-off-by: Abhishek Kumar Singh --- .../models/mllama/modeling_mllama.py | 12 ++++-- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 11 ++++-- .../models/test_single_subfunction.py | 38 ++++++++++--------- 3 files changed, 35 insertions(+), 26 deletions(-) diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py index 3cba022b4..1a9f67d32 100644 --- a/QEfficient/transformers/models/mllama/modeling_mllama.py +++ b/QEfficient/transformers/models/mllama/modeling_mllama.py @@ -241,9 +241,6 @@ class QEffMllamaTextSelfAttention(MllamaTextSelfAttention): - add new args cache idx for the kv retention """ - def __qeff_init__(self): - self.rotary_emb = QEffMllamaRotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -255,6 +252,7 @@ def forward( position_embeddings: torch.Tensor = None, use_cache: bool = False, cache_position=None, + rotary_emb: Optional[object] = None, **kwargs, ): bsz, q_len, _ = hidden_states.size() @@ -276,7 +274,7 @@ def forward( ) kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos, sin = rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -326,6 +324,7 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + rotary_emb: Optional[object] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -361,6 +360,7 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + rotary_emb=rotary_emb, ) hidden_states = residual + hidden_states @@ -465,6 +465,7 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[torch.Tensor] = None, + rotary_emb: Optional[object] = None, ) -> Tuple[torch.Tensor]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -477,6 +478,7 @@ def forward( comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, cache_position=cache_position, + rotary_emb=rotary_emb, ) hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states @@ -646,6 +648,7 @@ def forward( # embed positions hidden_states = inputs_embeds + rotary_emb = QEffMllamaRotaryEmbedding(config=self.config) for idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): # For text-only path we should skip cross attention layers. 
# Let's check if the layer is cross attention layer and if we have cross attention states @@ -676,6 +679,7 @@ def forward( comp_ctx_lengths=comp_ctx_lengths, use_cache=use_cache, cache_position=cache_position, + rotary_emb=rotary_emb, ) hidden_states = self.norm(hidden_states) diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index d6bfbda81..6e375598b 100644 --- a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -564,9 +564,6 @@ class QEffQwen2_5_VLAttention(Qwen2_5_VLAttention): and "Generating Long Sequences with Sparse Transformers". """ - def __qeff_init__(self): - self.rotary_emb = QEffQwen2_5_VLRotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -579,6 +576,7 @@ def forward( use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, num_kv_blocks: Optional[torch.Tensor] = None, + rotary_emb: Optional[object] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -595,7 +593,7 @@ def forward( kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) past_seen_tokens = past_key_value.get_seq_length() if past_key_value is not None else 0 - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos, sin = rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb( query_states, key_states, cos, sin, position_ids[1:], self.rope_scaling["mrope_section"] @@ -661,6 +659,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + rotary_emb: Optional[object] = None, # position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: @@ -701,6 +700,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + rotary_emb=rotary_emb, **kwargs, ) hidden_states = residual + hidden_states @@ -771,6 +771,8 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None + rotary_emb = QEffQwen2_5_VLRotaryEmbedding(config=self.config) + for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) @@ -785,6 +787,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + rotary_emb=rotary_emb, **kwargs, ) diff --git a/tests/transformers/models/test_single_subfunction.py b/tests/transformers/models/test_single_subfunction.py index 66eee49a7..9205d0ba9 100644 --- a/tests/transformers/models/test_single_subfunction.py +++ b/tests/transformers/models/test_single_subfunction.py @@ -17,23 +17,23 @@ configs = [ ("gpt2", 256, 2, 4, 128, 512, 127, {}), - # ("codegen", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}), - # ("falcon", 256, 2, 4, 128, 512, 127, {}), - # ("gptj", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}), - # ("llama", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - # ("mistral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - # ("mixtral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - # ("mpt", 256, 2, 4, 128, 512, 127, {}), - # ("phi", 256, 2, 4, 128, 512, 127, {}), - # 
("phi3", 256, 2, 4, 128, 512, 127, {"pad_token_id": 0}), - # ("qwen2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - # ("qwen3", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - # ("starcoder2", 256, 2, 4, 128, 512, 127, {}), - # ("granite", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - # ("olmo2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - # ("gpt_oss", 256, 3, 4, 128, 512, 127, {"num_key_value_heads": 2}), - # ("qwen3_moe", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - # ("granitemoe", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("codegen", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}), + ("falcon", 256, 2, 4, 128, 512, 127, {}), + ("gptj", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}), + ("llama", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("mistral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("mixtral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("mpt", 256, 2, 4, 128, 512, 127, {}), + ("phi", 256, 2, 4, 128, 512, 127, {}), + ("phi3", 256, 2, 4, 128, 512, 127, {"pad_token_id": 0}), + ("qwen2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("qwen3", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("starcoder2", 256, 2, 4, 128, 512, 127, {}), + ("granite", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("olmo2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("gpt_oss", 256, 3, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("qwen3_moe", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("granitemoe", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), ] configs = [ @@ -76,13 +76,15 @@ def get_function(onnx_path): def test_subfunction_vs_nonsubfunction(config, tmp_path): # tokenizer = AutoTokenizer.from_pretrained(config.model_type) model_0_0 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(config, **model_kwargs), cb=False) - tmp_path = "/home/abhishek/.cache/qeff_models/temp_onnx" + tmp_path = "/home/abhishek/rope_fix/graph_with_change" # Export with subfunctions enabled with_sub_func_onnx = model_0_0.export(tmp_path, use_onnx_subfunctions=True, offload_pt_weights=False) print(f"{config.model_type} is going on...") # Verify that the model with subfunctions has QEffGPT2Block function definition + functions_names = get_function(with_sub_func_onnx) + print(functions_names) if len(functions_names) != 12: raise AssertionError(f"function definition, but found {len(functions_names)} functions: {functions_names}") From d2a4211510a26b6813b022f8b34a168534401fe6 Mon Sep 17 00:00:00 2001 From: Abhishek Kumar Singh Date: Sat, 28 Feb 2026 17:38:30 +0530 Subject: [PATCH 05/18] Update modeling_gpt_oss.py Signed-off-by: Abhishek Kumar Singh --- QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index efaffd1a9..0cb8e51fd 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -1074,7 +1074,8 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - + + rotary_emb = QEffGptOssRotaryEmbedding(config=self.config) for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) @@ -1089,6 +1090,7 @@ def forward( 
output_attentions=output_attentions, cache_position=cache_position, sliding_mask=sliding_mask, + rotary_emb=rotary_emb, **kwargs, ) hidden_states = layer_outputs[0] From e9560d47c008b9e38e253abeb41dfdb9650430a2 Mon Sep 17 00:00:00 2001 From: Abhishek Kumar Singh Date: Fri, 6 Mar 2026 04:57:06 +0000 Subject: [PATCH 06/18] Fixed lint error Signed-off-by: Abhishek Kumar Singh --- QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index 0cb8e51fd..fd8393713 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -1074,7 +1074,7 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - + rotary_emb = QEffGptOssRotaryEmbedding(config=self.config) for decoder_layer in self.layers: if output_hidden_states: From cfbea8915d2059fc5068e73a16f58ce85553e674 Mon Sep 17 00:00:00 2001 From: Abhishek Kumar Singh Date: Fri, 6 Mar 2026 11:22:20 +0530 Subject: [PATCH 07/18] Made minor fix Signed-off-by: Abhishek Kumar Singh From 419f450eb727c30b729686e097483eda804243d3 Mon Sep 17 00:00:00 2001 From: Abhishek Kumar Singh Date: Fri, 6 Mar 2026 07:13:08 +0000 Subject: [PATCH 08/18] made minor fixes Signed-off-by: Abhishek Kumar Singh --- .../transformers/models/mistral/modeling_mistral.py | 1 + .../transformers/models/test_single_subfunction.py | 13 ++++++------- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/QEfficient/transformers/models/mistral/modeling_mistral.py b/QEfficient/transformers/models/mistral/modeling_mistral.py index d8e48fbf7..450aa96e4 100644 --- a/QEfficient/transformers/models/mistral/modeling_mistral.py +++ b/QEfficient/transformers/models/mistral/modeling_mistral.py @@ -315,6 +315,7 @@ def forward( all_self_attns = () if output_attentions else None rotary_emb = QEffMistralRotaryEmbedding(config=self.config) + for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) diff --git a/tests/transformers/models/test_single_subfunction.py b/tests/transformers/models/test_single_subfunction.py index 9205d0ba9..2c50d7c04 100644 --- a/tests/transformers/models/test_single_subfunction.py +++ b/tests/transformers/models/test_single_subfunction.py @@ -17,23 +17,23 @@ configs = [ ("gpt2", 256, 2, 4, 128, 512, 127, {}), - ("codegen", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}), + # ("codegen", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}), ("falcon", 256, 2, 4, 128, 512, 127, {}), ("gptj", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}), ("llama", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), ("mistral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - ("mixtral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # ("mixtral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), ("mpt", 256, 2, 4, 128, 512, 127, {}), - ("phi", 256, 2, 4, 128, 512, 127, {}), + # ("phi", 256, 2, 4, 128, 512, 127, {}), ("phi3", 256, 2, 4, 128, 512, 127, {"pad_token_id": 0}), ("qwen2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), ("qwen3", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - ("starcoder2", 256, 2, 4, 128, 512, 127, {}), + # ("starcoder2", 256, 2, 4, 128, 512, 127, {}), ("granite", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), ("olmo2", 256, 2, 4, 128, 
512, 127, {"num_key_value_heads": 2}), - ("gpt_oss", 256, 3, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # ("gpt_oss", 256, 3, 4, 128, 512, 127, {"num_key_value_heads": 2}), ("qwen3_moe", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - ("granitemoe", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # ("granitemoe", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), ] configs = [ @@ -76,7 +76,6 @@ def get_function(onnx_path): def test_subfunction_vs_nonsubfunction(config, tmp_path): # tokenizer = AutoTokenizer.from_pretrained(config.model_type) model_0_0 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(config, **model_kwargs), cb=False) - tmp_path = "/home/abhishek/rope_fix/graph_with_change" # Export with subfunctions enabled with_sub_func_onnx = model_0_0.export(tmp_path, use_onnx_subfunctions=True, offload_pt_weights=False) From 111fdaa2f0ec9a70f0a475413c95ade11deae7e4 Mon Sep 17 00:00:00 2001 From: Abhishek Kumar Singh Date: Fri, 6 Mar 2026 07:15:56 +0000 Subject: [PATCH 09/18] lint Signed-off-by: Abhishek Kumar Singh --- scripts/git_workflow/pr_report.py | 55 +++++++++++++++++-------------- 1 file changed, 31 insertions(+), 24 deletions(-) diff --git a/scripts/git_workflow/pr_report.py b/scripts/git_workflow/pr_report.py index 388d776e2..14a5d3aea 100644 --- a/scripts/git_workflow/pr_report.py +++ b/scripts/git_workflow/pr_report.py @@ -265,24 +265,36 @@ def generate_pie_chart_svg(author_counts): # 15-colour palette; cycles if there are more authors colors = [ - "#4a90d9", "#e74c3c", "#2ecc71", "#f39c12", "#9b59b6", - "#1abc9c", "#e67e22", "#3498db", "#e91e63", "#00bcd4", - "#ff5722", "#607d8b", "#795548", "#9c27b0", "#4caf50", + "#4a90d9", + "#e74c3c", + "#2ecc71", + "#f39c12", + "#9b59b6", + "#1abc9c", + "#e67e22", + "#3498db", + "#e91e63", + "#00bcd4", + "#ff5722", + "#607d8b", + "#795548", + "#9c27b0", + "#4caf50", ] - cx, cy, r = 190, 190, 160 # pie centre and radius - legend_x = cx * 2 + 30 # legend column starts here - row_h = 22 # legend row height - svg_w = legend_x + 260 # total SVG width - svg_h = max(cy * 2, len(items) * row_h + 50) # total SVG height + cx, cy, r = 190, 190, 160 # pie centre and radius + legend_x = cx * 2 + 30 # legend column starts here + row_h = 22 # legend row height + svg_w = legend_x + 260 # total SVG width + svg_h = max(cy * 2, len(items) * row_h + 50) # total SVG height # ── Build slice paths ──────────────────────────────────────────────────── paths_svg = "" legend_svg = "" - start_angle = -math.pi / 2 # begin at 12 o'clock + start_angle = -math.pi / 2 # begin at 12 o'clock for i, (author, count) in enumerate(items): - angle = 2 * math.pi * count / total + angle = 2 * math.pi * count / total end_angle = start_angle + angle x1 = cx + r * math.cos(start_angle) @@ -291,20 +303,16 @@ def generate_pie_chart_svg(author_counts): y2 = cy + r * math.sin(end_angle) large_arc = 1 if angle > math.pi else 0 - color = colors[i % len(colors)] - pct = count / total * 100 + color = colors[i % len(colors)] + pct = count / total * 100 # SVG arc path: move to centre → line to arc start → arc → close - path = ( - f"M {cx},{cy} " - f"L {x1:.2f},{y1:.2f} " - f"A {r},{r} 0 {large_arc},1 {x2:.2f},{y2:.2f} Z" - ) + path = f"M {cx},{cy} L {x1:.2f},{y1:.2f} A {r},{r} 0 {large_arc},1 {x2:.2f},{y2:.2f} Z" paths_svg += ( f' \n' - f' {author}: {count} PR{"s" if count != 1 else ""} ({pct:.1f}%)\n' - f' \n' + f" {author}: {count} PR{'s' if count != 1 else ''} ({pct:.1f}%)\n" + f" \n" ) # Legend row @@ -314,8 +322,8 @@ def 
generate_pie_chart_svg(author_counts): f'fill="{color}" rx="2"/>\n' f' ' - f'{author} {count} PR{"s" if count != 1 else ""} ({pct:.1f}%)' - f'\n' + f"{author} {count} PR{'s' if count != 1 else ''} ({pct:.1f}%)" + f"\n" ) start_angle = end_angle @@ -329,15 +337,14 @@ def generate_pie_chart_svg(author_counts): # Chart title f' ' - f'PR Distribution by Author (Total: {total})\n' + f"PR Distribution by Author (Total: {total})\n" # Slices + paths_svg # Legend header + f' Author\n' # Legend rows - + legend_svg - + '\n\n' + + legend_svg + "\n\n" ) return svg From 11acaf35c989b373880d0ffa290f59c911d37cf7 Mon Sep 17 00:00:00 2001 From: Abhishek Kumar Singh Date: Wed, 18 Mar 2026 10:49:48 +0000 Subject: [PATCH 10/18] rebased Signed-off-by: Abhishek Kumar Singh --- .../models/test_single_subfunction.py | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/tests/transformers/models/test_single_subfunction.py b/tests/transformers/models/test_single_subfunction.py index 2c50d7c04..0c57f8a0a 100644 --- a/tests/transformers/models/test_single_subfunction.py +++ b/tests/transformers/models/test_single_subfunction.py @@ -18,21 +18,21 @@ configs = [ ("gpt2", 256, 2, 4, 128, 512, 127, {}), # ("codegen", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}), - ("falcon", 256, 2, 4, 128, 512, 127, {}), - ("gptj", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}), - ("llama", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - ("mistral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # ("falcon", 256, 2, 4, 128, 512, 127, {}), + # ("gptj", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}), + # ("llama", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # ("mistral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), # ("mixtral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - ("mpt", 256, 2, 4, 128, 512, 127, {}), + # ("mpt", 256, 2, 4, 128, 512, 127, {}), # ("phi", 256, 2, 4, 128, 512, 127, {}), - ("phi3", 256, 2, 4, 128, 512, 127, {"pad_token_id": 0}), - ("qwen2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - ("qwen3", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # ("phi3", 256, 2, 4, 128, 512, 127, {"pad_token_id": 0}), + # ("qwen2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # ("qwen3", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), # ("starcoder2", 256, 2, 4, 128, 512, 127, {}), - ("granite", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - ("olmo2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # ("granite", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # ("olmo2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), # ("gpt_oss", 256, 3, 4, 128, 512, 127, {"num_key_value_heads": 2}), - ("qwen3_moe", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # ("qwen3_moe", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), # ("granitemoe", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), ] @@ -77,6 +77,7 @@ def test_subfunction_vs_nonsubfunction(config, tmp_path): # tokenizer = AutoTokenizer.from_pretrained(config.model_type) model_0_0 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(config, **model_kwargs), cb=False) # Export with subfunctions enabled + tmp_path = "/home/abhishek/rope_fix/graph_with_change" with_sub_func_onnx = model_0_0.export(tmp_path, use_onnx_subfunctions=True, offload_pt_weights=False) print(f"{config.model_type} is going on...") From 6320cb0da265785c1f241c781cae2cb0275470ae Mon Sep 17 00:00:00 2001 From: Abhishek Kumar Singh Date: 
Wed, 18 Mar 2026 18:38:21 +0000 Subject: [PATCH 11/18] made minor fixes Signed-off-by: Abhishek Kumar Singh --- .../models/test_single_subfunction.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/transformers/models/test_single_subfunction.py b/tests/transformers/models/test_single_subfunction.py index 0c57f8a0a..2f79ae469 100644 --- a/tests/transformers/models/test_single_subfunction.py +++ b/tests/transformers/models/test_single_subfunction.py @@ -18,21 +18,21 @@ configs = [ ("gpt2", 256, 2, 4, 128, 512, 127, {}), # ("codegen", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}), - # ("falcon", 256, 2, 4, 128, 512, 127, {}), - # ("gptj", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}), - # ("llama", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - # ("mistral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("falcon", 256, 2, 4, 128, 512, 127, {}), + ("gptj", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}), + ("llama", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("mistral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), # ("mixtral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - # ("mpt", 256, 2, 4, 128, 512, 127, {}), + ("mpt", 256, 2, 4, 128, 512, 127, {}), # ("phi", 256, 2, 4, 128, 512, 127, {}), - # ("phi3", 256, 2, 4, 128, 512, 127, {"pad_token_id": 0}), - # ("qwen2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - # ("qwen3", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("phi3", 256, 2, 4, 128, 512, 127, {"pad_token_id": 0}), + ("qwen2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("qwen3", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), # ("starcoder2", 256, 2, 4, 128, 512, 127, {}), - # ("granite", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - # ("olmo2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("granite", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("olmo2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), # ("gpt_oss", 256, 3, 4, 128, 512, 127, {"num_key_value_heads": 2}), - # ("qwen3_moe", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("qwen3_moe", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), # ("granitemoe", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), ] From 8217229855852165f095e549c90ea82e12251592 Mon Sep 17 00:00:00 2001 From: abhishek-singh591 Date: Sun, 22 Mar 2026 04:47:00 +0000 Subject: [PATCH 12/18] made minor fixes Signed-off-by: abhishek-singh591 --- .../models/gemma/modeling_gemma.py | 5 +- .../models/gemma3/modeling_gemma3.py | 54 ++++++++----------- .../causallm/example_pytorch_transforms.py | 12 ++--- .../models/test_single_subfunction.py | 8 +-- 4 files changed, 33 insertions(+), 46 deletions(-) diff --git a/QEfficient/transformers/models/gemma/modeling_gemma.py b/QEfficient/transformers/models/gemma/modeling_gemma.py index d1e25c9e2..f554e7d8c 100644 --- a/QEfficient/transformers/models/gemma/modeling_gemma.py +++ b/QEfficient/transformers/models/gemma/modeling_gemma.py @@ -297,6 +297,7 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None + rotary_emb = QEffGemmaRotaryEmbedding(config=self.config) for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) @@ -310,6 +311,7 @@ def forward( batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, + rotary_emb=rotary_emb, **kwargs, ) @@ -322,13 +324,10 @@ def forward( if 
return_legacy_cache: past_key_values = past_key_values.to_legacy_cache() - rotary_emb = QEffGemmaRotaryEmbedding(config=self.config) - return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, - rotary_emb=rotary_emb, ) diff --git a/QEfficient/transformers/models/gemma3/modeling_gemma3.py b/QEfficient/transformers/models/gemma3/modeling_gemma3.py index e9c19d45d..f98bae225 100644 --- a/QEfficient/transformers/models/gemma3/modeling_gemma3.py +++ b/QEfficient/transformers/models/gemma3/modeling_gemma3.py @@ -187,6 +187,27 @@ def __init__(self, config: Gemma3Config, layer_idx: Optional[int] = None): # Set the init in the module mapping pytorch transforms self.__qeff_init__() + def __qeff_init__(self): + self.rotary_emb = QEffGemma3RotaryEmbedding( + self.head_dim, + self.config, + max_position_embeddings=self.config.max_position_embeddings, + base=self.config.rope_theta, + ) + + config = copy.deepcopy(self.config) + config.rope_theta = config.rope_local_base_freq + config.rope_scaling = {"rope_type": "default", "factor": 1.0} + self.is_local = _is_local(self.layer_idx, self.config.sliding_window_pattern) + self.window = self.config.sliding_window if self.is_local else None + + self.rotary_emb_local = QEffGemma3RotaryEmbedding( + self.head_dim, + config, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + ) + def forward( self, hidden_states: torch.Tensor, @@ -197,8 +218,6 @@ def forward( comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, - rotary_emb: Optional[object] = None, - rotary_emb_local: Optional[object] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -221,9 +240,9 @@ def forward( "with a layer index." 
) if self.is_sliding: - cos, sin = rotary_emb_local(value_states, seq_len=self.config.max_position_embeddings) + cos, sin = self.rotary_emb_local(value_states, seq_len=self.config.max_position_embeddings) else: - cos, sin = rotary_emb(value_states, seq_len=self.config.max_position_embeddings) + cos, sin = self.rotary_emb(value_states, seq_len=self.config.max_position_embeddings) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -289,8 +308,6 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, last_cache_position: int = 0, - rotary_emb=None, - rotary_emb_local=None, **kwargs, ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states @@ -319,8 +336,6 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - rotary_emb=rotary_emb, - rotary_emb_local=rotary_emb_local, **kwargs, ) @@ -417,27 +432,6 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - - rotary_emb = QEffGemma3RotaryEmbedding( - self.head_dim, - self.config, - max_position_embeddings=self.config.max_position_embeddings, - base=self.config.rope_theta, - ) - - config = copy.deepcopy(self.config) - config.rope_theta = config.rope_local_base_freq - config.rope_scaling = {"rope_type": "default", "factor": 1.0} - self.is_local = _is_local(self.layer_idx, self.config.sliding_window_pattern) - self.window = self.config.sliding_window if self.is_local else None - - rotary_emb_local = QEffGemma3RotaryEmbedding( - self.head_dim, - config, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - ) - for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) @@ -453,8 +447,6 @@ def forward( use_cache=use_cache, cache_position=cache_position, last_cache_position=last_cache_position, - rotary_emb=rotary_emb, - rotary_emb_local=rotary_emb_local, **flash_attn_kwargs, ) diff --git a/examples/onboarding_guide/causallm/example_pytorch_transforms.py b/examples/onboarding_guide/causallm/example_pytorch_transforms.py index ff62588f9..503efc12d 100644 --- a/examples/onboarding_guide/causallm/example_pytorch_transforms.py +++ b/examples/onboarding_guide/causallm/example_pytorch_transforms.py @@ -27,12 +27,6 @@ from types import MethodType from typing import Callable, Optional, Tuple, Union -from QEfficient.transformers.models.blueprint.modeling_blueprint import ( - QEffBlueprintAttention, - QEffBlueprintDecoderLayer, - QEffBlueprintForCausalLM, - QEffBlueprintModel, -) from torch import nn # Example imports for three representative models @@ -62,6 +56,12 @@ from QEfficient.base.pytorch_transforms import ExternalModuleMapperTransform, ModuleMappingTransform from QEfficient.customop import CustomRMSNormAIC from QEfficient.transformers.embeddings.embedding_utils import POOLING_MAP, PooledModel, validate_user_pooling_function +from QEfficient.transformers.models.blueprint.modeling_blueprint import ( + QEffBlueprintAttention, + QEffBlueprintDecoderLayer, + QEffBlueprintForCausalLM, + QEffBlueprintModel, +) from QEfficient.transformers.models.llama.modeling_llama import ( QEffLlamaAttention, QEffLlamaDecoderLayer, diff --git a/tests/transformers/models/test_single_subfunction.py b/tests/transformers/models/test_single_subfunction.py index 
2f79ae469..968800918 100644 --- a/tests/transformers/models/test_single_subfunction.py +++ b/tests/transformers/models/test_single_subfunction.py @@ -34,6 +34,8 @@ # ("gpt_oss", 256, 3, 4, 128, 512, 127, {"num_key_value_heads": 2}), ("qwen3_moe", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), # ("granitemoe", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("gemma", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("gemma2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), ] configs = [ @@ -74,15 +76,9 @@ def get_function(onnx_path): @pytest.mark.feature @pytest.mark.parametrize("config", configs, ids=config_ids) def test_subfunction_vs_nonsubfunction(config, tmp_path): - # tokenizer = AutoTokenizer.from_pretrained(config.model_type) model_0_0 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(config, **model_kwargs), cb=False) - # Export with subfunctions enabled - tmp_path = "/home/abhishek/rope_fix/graph_with_change" with_sub_func_onnx = model_0_0.export(tmp_path, use_onnx_subfunctions=True, offload_pt_weights=False) - print(f"{config.model_type} is going on...") - # Verify that the model with subfunctions has QEffGPT2Block function definition - functions_names = get_function(with_sub_func_onnx) print(functions_names) if len(functions_names) != 12: From fdd545ffe5921e120d2db8e934d1392218a959a8 Mon Sep 17 00:00:00 2001 From: abhishek-singh591 Date: Sun, 22 Mar 2026 04:55:21 +0000 Subject: [PATCH 13/18] lint Signed-off-by: abhishek-singh591 --- .../causallm/example_pytorch_transforms.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/onboarding_guide/causallm/example_pytorch_transforms.py b/examples/onboarding_guide/causallm/example_pytorch_transforms.py index 503efc12d..ff62588f9 100644 --- a/examples/onboarding_guide/causallm/example_pytorch_transforms.py +++ b/examples/onboarding_guide/causallm/example_pytorch_transforms.py @@ -27,6 +27,12 @@ from types import MethodType from typing import Callable, Optional, Tuple, Union +from QEfficient.transformers.models.blueprint.modeling_blueprint import ( + QEffBlueprintAttention, + QEffBlueprintDecoderLayer, + QEffBlueprintForCausalLM, + QEffBlueprintModel, +) from torch import nn # Example imports for three representative models @@ -56,12 +62,6 @@ from QEfficient.base.pytorch_transforms import ExternalModuleMapperTransform, ModuleMappingTransform from QEfficient.customop import CustomRMSNormAIC from QEfficient.transformers.embeddings.embedding_utils import POOLING_MAP, PooledModel, validate_user_pooling_function -from QEfficient.transformers.models.blueprint.modeling_blueprint import ( - QEffBlueprintAttention, - QEffBlueprintDecoderLayer, - QEffBlueprintForCausalLM, - QEffBlueprintModel, -) from QEfficient.transformers.models.llama.modeling_llama import ( QEffLlamaAttention, QEffLlamaDecoderLayer, From 10b0ead5c1ed426d251d6cba399567a152204fe2 Mon Sep 17 00:00:00 2001 From: Abhishek Kumar Singh Date: Mon, 23 Mar 2026 13:28:43 +0530 Subject: [PATCH 14/18] Update Jenkinsfile Signed-off-by: Abhishek Kumar Singh --- scripts/Jenkinsfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/Jenkinsfile b/scripts/Jenkinsfile index b791f3a31..476991003 100644 --- a/scripts/Jenkinsfile +++ b/scripts/Jenkinsfile @@ -41,7 +41,7 @@ pipeline { mkdir -p $PWD/Non_cli_qaic && export TOKENIZERS_PARALLELISM=false && export QEFF_HOME=$PWD/Non_cli_qaic && - pytest tests -m '(not cli) and (not on_qaic) and (not finetune)' --ignore tests/vllm --ignore 
tests/transformers/models/image_text_to_text -n 4 --junitxml=tests/tests_log1.xml --durations=10 && + pytest tests -m '(not cli) and (not on_qaic) and (not finetune)' --ignore tests/vllm --ignore tests/transformers/models/image_text_to_text --ignore tests/unit_test -n 4 --junitxml=tests/tests_log1.xml --durations=10 && junitparser merge tests/tests_log1.xml tests/tests_log.xml && deactivate" ''' From 333e196dcb7cf66a59b382c76b2cb673f0f58d1f Mon Sep 17 00:00:00 2001 From: abhishek-singh591 Date: Mon, 23 Mar 2026 11:54:57 +0000 Subject: [PATCH 15/18] added few fix Signed-off-by: abhishek-singh591 --- .../models/gemma3/modeling_gemma3.py | 52 +++++++++++-------- .../llama_swiftkv/modeling_llama_swiftkv.py | 2 +- .../models/mllama/modeling_mllama.py | 1 - 3 files changed, 30 insertions(+), 25 deletions(-) diff --git a/QEfficient/transformers/models/gemma3/modeling_gemma3.py b/QEfficient/transformers/models/gemma3/modeling_gemma3.py index f98bae225..4f8380738 100644 --- a/QEfficient/transformers/models/gemma3/modeling_gemma3.py +++ b/QEfficient/transformers/models/gemma3/modeling_gemma3.py @@ -187,27 +187,6 @@ def __init__(self, config: Gemma3Config, layer_idx: Optional[int] = None): # Set the init in the module mapping pytorch transforms self.__qeff_init__() - def __qeff_init__(self): - self.rotary_emb = QEffGemma3RotaryEmbedding( - self.head_dim, - self.config, - max_position_embeddings=self.config.max_position_embeddings, - base=self.config.rope_theta, - ) - - config = copy.deepcopy(self.config) - config.rope_theta = config.rope_local_base_freq - config.rope_scaling = {"rope_type": "default", "factor": 1.0} - self.is_local = _is_local(self.layer_idx, self.config.sliding_window_pattern) - self.window = self.config.sliding_window if self.is_local else None - - self.rotary_emb_local = QEffGemma3RotaryEmbedding( - self.head_dim, - config, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - ) - def forward( self, hidden_states: torch.Tensor, @@ -218,6 +197,8 @@ def forward( comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, + rotary_emb: Optional[object] = None, + rotary_emb_local: Optional[object] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -240,9 +221,9 @@ def forward( "with a layer index." 
) if self.is_sliding: - cos, sin = self.rotary_emb_local(value_states, seq_len=self.config.max_position_embeddings) + cos, sin = rotary_emb_local(value_states, seq_len=self.config.max_position_embeddings) else: - cos, sin = self.rotary_emb(value_states, seq_len=self.config.max_position_embeddings) + cos, sin = rotary_emb(value_states, seq_len=self.config.max_position_embeddings) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -308,6 +289,8 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, last_cache_position: int = 0, + rotary_emb=None, + rotary_emb_local=None, **kwargs, ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states @@ -336,6 +319,8 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + rotary_emb=rotary_emb, + rotary_emb_local=rotary_emb_local, **kwargs, ) @@ -432,6 +417,25 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None + + rotary_emb = QEffGemma3RotaryEmbedding( + self.layers[0].self_attn.head_dim, + self.config, + max_position_embeddings=self.config.max_position_embeddings, + base=self.config.rope_theta, + ) + + config = copy.deepcopy(self.config) + config.rope_theta = config.rope_local_base_freq + config.rope_scaling = {"rope_type": "default", "factor": 1.0} + + rotary_emb_local = QEffGemma3RotaryEmbedding( + self.layers[0].self_attn.head_dim, + config, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + ) + for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) @@ -447,6 +451,8 @@ def forward( use_cache=use_cache, cache_position=cache_position, last_cache_position=last_cache_position, + rotary_emb=rotary_emb, + rotary_emb_local=rotary_emb_local, **flash_attn_kwargs, ) diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py index f6df495a0..30ab93d45 100644 --- a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py +++ b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py @@ -375,7 +375,7 @@ def forward( ) kv_seq_len = past_key_values.get_seq_length(self_attn.layer_idx) - cos, sin = self_attn.rotary_emb(value_states, seq_len=kv_seq_len) + cos, sin = rotary_emb(value_states, seq_len=kv_seq_len) _, key_states = qeff_apply_rotary_pos_emb(torch.empty_like(key_states), key_states, cos, sin, position_ids) cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} past_key_values.write_only(key_states, value_states, self_attn.layer_idx, cache_kwargs) diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py index 1a9f67d32..de6996882 100644 --- a/QEfficient/transformers/models/mllama/modeling_mllama.py +++ b/QEfficient/transformers/models/mllama/modeling_mllama.py @@ -478,7 +478,6 @@ def forward( comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, cache_position=cache_position, - rotary_emb=rotary_emb, ) hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states From 53336446755d98619a7adc2e6a5557adf51ca01e Mon Sep 17 00:00:00 2001 From: abhishek-singh591 Date: Tue, 24 Mar 2026 03:14:08 +0000 
Subject: [PATCH 16/18] updated jenkins file Signed-off-by: abhishek-singh591 --- scripts/Jenkinsfile | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/scripts/Jenkinsfile b/scripts/Jenkinsfile index 476991003..7ac1f53cc 100644 --- a/scripts/Jenkinsfile +++ b/scripts/Jenkinsfile @@ -50,7 +50,7 @@ pipeline { } stage('QAIC LLM Tests') { steps { - timeout(time: 120, unit: 'MINUTES') { + timeout(time: 180, unit: 'MINUTES') { sh ''' sudo docker exec ${BUILD_TAG} bash -c " cd /efficient-transformers && @@ -58,7 +58,7 @@ pipeline { mkdir -p $PWD/Non_qaic_llm && export TOKENIZERS_PARALLELISM=false && export QEFF_HOME=$PWD/Non_qaic_llm && - pytest tests -m '(not cli) and (on_qaic) and (llm_model) and (not nightly) and (not multimodal) and (not qnn) and (not finetune) and (not diffusion_models)' --ignore tests/vllm --junitxml=tests/tests_log2.xml --durations=10 && + pytest tests -m '(not cli) and (on_qaic) and (llm_model) and (not nightly) and (not multimodal) and (not qnn) and (not finetune) and (not diffusion_models)' --ignore tests/vllm --ignore tests/unit_test --junitxml=tests/tests_log2.xml --durations=10 && junitparser merge tests/tests_log2.xml tests/tests_log.xml && deactivate" ''' @@ -75,7 +75,7 @@ pipeline { mkdir -p $PWD/Non_qaic_feature && export TOKENIZERS_PARALLELISM=false && export QEFF_HOME=$PWD/Non_qaic_feature && - pytest tests -m '(not cli) and (on_qaic) and (feature) and (not nightly) and (not multimodal) and (not qnn) and (not finetune) and (not diffusion_models)' --ignore tests/vllm --junitxml=tests/tests_log2_feature.xml --durations=10 && + pytest tests -m '(not cli) and (on_qaic) and (feature) and (not nightly) and (not multimodal) and (not qnn) and (not finetune) and (not diffusion_models)' --ignore tests/vllm --ignore tests/unit_test --junitxml=tests/tests_log2_feature.xml --durations=10 && junitparser merge tests/tests_log2_feature.xml tests/tests_log.xml && deactivate" ''' @@ -94,7 +94,7 @@ pipeline { mkdir -p $PWD/Non_cli_qaic_multimodal && export TOKENIZERS_PARALLELISM=false && export QEFF_HOME=$PWD/Non_cli_qaic_multimodal && - pytest tests -m '(not cli) and (on_qaic) and (multimodal) and (not qnn) and (not finetune) and (not diffusion_models)' --ignore tests/vllm --junitxml=tests/tests_log6.xml --durations=10 && + pytest tests -m '(not cli) and (on_qaic) and (multimodal) and (not qnn) and (not finetune) and (not diffusion_models)' --ignore tests/vllm --ignore tests/unit_test --junitxml=tests/tests_log6.xml --durations=10 && junitparser merge tests/tests_log6.xml tests/tests_log.xml && deactivate" ''' @@ -112,7 +112,7 @@ pipeline { export TOKENIZERS_PARALLELISM=false && export QEFF_HOME=$PWD/Non_cli_qaic_diffusion && export HF_HUB_CACHE=/huggingface_hub && - pytest tests -m '(not cli) and (on_qaic) and (diffusion_models) and (not wan) and (not qnn) and (not finetune)' --ignore tests/vllm --junitxml=tests/tests_log_diffusion.xml --durations=10 && + pytest tests -m '(not cli) and (on_qaic) and (diffusion_models) and (not wan) and (not qnn) and (not finetune)' --ignore tests/vllm --ignore tests/unit_test --junitxml=tests/tests_log_diffusion.xml --durations=10 && junitparser merge tests/tests_log_diffusion.xml tests/tests_log.xml && deactivate" ''' @@ -131,7 +131,7 @@ pipeline { mkdir -p $PWD/cli && export TOKENIZERS_PARALLELISM=false && export QEFF_HOME=$PWD/cli && - pytest tests -m '(cli and not qnn) and (not finetune)' --ignore tests/vllm --junitxml=tests/tests_log3.xml --durations=10 && + pytest tests -m '(cli and not qnn) and (not 
finetune)' --ignore tests/vllm --ignore tests/unit_test --junitxml=tests/tests_log3.xml --durations=10 && junitparser merge tests/tests_log3.xml tests/tests_log.xml && deactivate" ''' @@ -209,7 +209,7 @@ pipeline { mkdir -p $PWD/cli_qaic_finetuning && export TOKENIZERS_PARALLELISM=false && export QEFF_HOME=$PWD/cli_qaic_finetuning && - pytest tests -m '(cli) and (on_qaic) and (not qnn) and (not multimodal) and (finetune)' --ignore tests/vllm --junitxml=tests/tests_log_finetune.xml --durations=10 && + pytest tests -m '(cli) and (on_qaic) and (not qnn) and (not multimodal) and (finetune)' --ignore tests/vllm --ignore tests/unit_test --junitxml=tests/tests_log_finetune.xml --durations=10 && junitparser merge tests/tests_log_finetune.xml tests/tests_log.xml && deactivate" ''' From 270b3cd53f44d5a64776d60dff33a1659cd28dbc Mon Sep 17 00:00:00 2001 From: abhishek-singh591 Date: Wed, 25 Mar 2026 02:58:05 +0000 Subject: [PATCH 17/18] changed test file Signed-off-by: abhishek-singh591 --- tests/transformers/models/test_single_subfunction.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/transformers/models/test_single_subfunction.py b/tests/transformers/models/test_single_subfunction.py index 968800918..73686bd72 100644 --- a/tests/transformers/models/test_single_subfunction.py +++ b/tests/transformers/models/test_single_subfunction.py @@ -11,7 +11,6 @@ from transformers import AutoConfig, AutoModelForCausalLM from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM -from QEfficient.utils.device_utils import get_available_device_id torch.manual_seed(42) @@ -81,7 +80,11 @@ def test_subfunction_vs_nonsubfunction(config, tmp_path): functions_names = get_function(with_sub_func_onnx) print(functions_names) - if len(functions_names) != 12: + + keywords = ["DecoderLayer", "Block", "Layer"] + filtered = [name for name in functions_names if any(key in name for key in keywords)] + + if len(filtered) > 1: raise AssertionError(f"function definition, but found {len(functions_names)} functions: {functions_names}") if not get_available_device_id(): From 82928bb15ae738e42e38e0f7f5ac24fe5263b6a4 Mon Sep 17 00:00:00 2001 From: abhishek-singh591 Date: Wed, 25 Mar 2026 03:05:46 +0000 Subject: [PATCH 18/18] lint Signed-off-by: abhishek-singh591 --- tests/transformers/models/test_single_subfunction.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/transformers/models/test_single_subfunction.py b/tests/transformers/models/test_single_subfunction.py index 73686bd72..f17edab65 100644 --- a/tests/transformers/models/test_single_subfunction.py +++ b/tests/transformers/models/test_single_subfunction.py @@ -11,6 +11,7 @@ from transformers import AutoConfig, AutoModelForCausalLM from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM +from QEfficient.utils.device_utils import get_available_device_id torch.manual_seed(42)
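The series above applies one recurring refactor across the decoder models: the per-attention __qeff_init__ that used to build self.rotary_emb is removed, a single rotary-embedding module is constructed inside the text model's forward (for Gemma3, a global and a sliding-window/local variant), and that module is threaded down through every decoder layer into the attention forward as an explicit rotary_emb argument. With no per-layer rotary submodule, every decoder layer presents the same forward signature, which is what the "single subfunction signature" goal in the commit subjects refers to. Below is a minimal, self-contained sketch of that calling pattern only; the class names, sizes, and the simplified rotary math are illustrative stand-ins, not the actual QEfficient modules.

import torch
from torch import nn


class RotaryEmbedding(nn.Module):
    # Illustrative stand-in for the QEff*RotaryEmbedding modules: returns (cos, sin) tables.
    def __init__(self, dim, base=10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, x, seq_len):
        t = torch.arange(seq_len, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos(), emb.sin()


class DecoderLayer(nn.Module):
    # The layer no longer owns a rotary module; it receives one per call,
    # so every layer exposes an identical forward signature.
    def __init__(self, hidden_size):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, position_ids, rotary_emb=None):
        cos, sin = rotary_emb(hidden_states, seq_len=hidden_states.shape[1])
        # Real models rotate q/k with cos/sin here; this sketch only shows the plumbing.
        return self.proj(hidden_states)


class TinyModel(nn.Module):
    def __init__(self, num_layers, hidden_size, head_dim):
        super().__init__()
        self.layers = nn.ModuleList(DecoderLayer(hidden_size) for _ in range(num_layers))
        self.head_dim = head_dim

    def forward(self, hidden_states, position_ids):
        # Built once per forward and passed to every layer, mirroring the
        # rotary_emb argument added throughout the patches above.
        rotary_emb = RotaryEmbedding(self.head_dim)
        for layer in self.layers:
            hidden_states = layer(hidden_states, position_ids, rotary_emb=rotary_emb)
        return hidden_states


if __name__ == "__main__":
    model = TinyModel(num_layers=2, hidden_size=64, head_dim=16)
    out = model(torch.randn(1, 8, 64), torch.arange(8).unsqueeze(0))
    print(out.shape)  # torch.Size([1, 8, 64])

Keeping the rotary module out of each layer's state means the layer bodies stay interchangeable for export; presumably that is what lets the ONNX exporter emit one shared function for all decoder layers instead of one per layer, which is exactly what the accompanying test asserts.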
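On the verification side, the test updated in the last two patches loads the exported ONNX model, lists its local function definitions, and asserts that at most one of them looks like a decoder layer (a name containing DecoderLayer, Block, or Layer), before compiling only when a Cloud AI 100 device is available. A rough sketch of that check in isolation follows; the file path is a placeholder and the keyword list simply mirrors the one introduced in PATCH 17.

import onnx


def decoder_layer_functions(onnx_path):
    # Local function definitions (FunctionProto) stored in the exported model.
    model = onnx.load(onnx_path)
    names = [f.name for f in model.functions]
    keywords = ("DecoderLayer", "Block", "Layer")
    return [name for name in names if any(key in name for key in keywords)]


layer_funcs = decoder_layer_functions("exported_model.onnx")  # placeholder path
if len(layer_funcs) > 1:
    raise AssertionError(f"expected a single shared layer function, found: {layer_funcs}")

Filtering by name keywords rather than asserting a fixed total (the earlier hard-coded count of 12) presumably keeps the check tolerant of models whose export produces additional helper functions alongside the shared layer body.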