Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
a374683
Added changes for single subfunction signature
abhishek-singh591 Feb 26, 2026
795c21e
Added changes for single subfunction signature
abhishek-singh591 Feb 26, 2026
3345a16
Merge branch 'quic:main' into rope_fix
abhishek-singh591 Feb 26, 2026
ce1fe98
Fixed lint error
abhishek-singh591 Feb 26, 2026
1f74ad5
Added VLMs changes
abhishek-singh591 Feb 27, 2026
ab993fb
Merge branch 'quic:main' into rope_fix
abhishek-singh591 Feb 28, 2026
d2a4211
Update modeling_gpt_oss.py
abhishek-singh591 Feb 28, 2026
e9560d4
Fixed lint error
abhishek-singh591 Mar 6, 2026
41804f7
Merge branch 'quic:main' into rope_fix
abhishek-singh591 Mar 6, 2026
cfbea89
Made minor fix
abhishek-singh591 Mar 6, 2026
419f450
made minor fixes
abhishek-singh591 Mar 6, 2026
111fdaa
lint
abhishek-singh591 Mar 6, 2026
e491d26
Merge remote-tracking branch 'origin/main' into rope_fix
abhishek-singh591 Mar 18, 2026
11acaf3
rebased
abhishek-singh591 Mar 18, 2026
6320cb0
made minor fixes
abhishek-singh591 Mar 18, 2026
6a8647d
Merge branch 'quic:main' into rope_fix
abhishek-singh591 Mar 18, 2026
8217229
made minor fixes
abhishek-singh591 Mar 22, 2026
fdd545f
lint
abhishek-singh591 Mar 22, 2026
10b0ead
Update Jenkinsfile
abhishek-singh591 Mar 23, 2026
333e196
added few fix
abhishek-singh591 Mar 23, 2026
5333644
updated jenkins file
abhishek-singh591 Mar 24, 2026
9b00cf3
Merge branch 'quic:main' into rope_fix
abhishek-singh591 Mar 24, 2026
270b3cd
changed test file
abhishek-singh591 Mar 25, 2026
ae1bf24
Merge branch 'quic:main' into rope_fix
abhishek-singh591 Mar 25, 2026
82928bb
lint
abhishek-singh591 Mar 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions QEfficient/transformers/models/falcon/modeling_falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,6 @@ class QEffFalconAttention(FalconAttention):
- add new args position idx for the cache_kwargs for kv retention
"""

def __qeff_init__(self):
self.rotary_emb = QEffFalconRotaryEmbedding(config=self.config)

def forward(
self,
hidden_states: torch.Tensor,
Expand All @@ -125,6 +122,7 @@ def forward(
use_cache: bool = False,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
rotary_emb: Optional[object] = None,
):
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
Expand All @@ -138,7 +136,7 @@ def forward(
value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)

kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len)
cos, sin = rotary_emb(value_layer, seq_len=kv_seq_len)
query_layer, key_layer = qeff_apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids)

if layer_past is not None:
Expand Down Expand Up @@ -184,6 +182,7 @@ def forward(
use_cache: bool = False,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
rotary_emb=None,
**kwargs,
):
residual = hidden_states
Expand All @@ -208,6 +207,7 @@ def forward(
use_cache=use_cache,
output_attentions=output_attentions,
cache_position=cache_position,
rotary_emb=rotary_emb,
)

if not self.config.new_decoder_architecture:
Expand Down Expand Up @@ -305,6 +305,8 @@ def forward(
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None

rotary_emb = QEffFalconRotaryEmbedding(config=self.config)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we move this out of the forward method and put it in the `__qeff_init__` of the QEffFalconModel class?


for i, block in enumerate(self.h):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
Expand All @@ -322,6 +324,7 @@ def forward(
output_attentions=output_attentions,
alibi=alibi,
cache_position=cache_position,
rotary_emb=rotary_emb,
)

hidden_states = outputs[0]
Expand Down
10 changes: 6 additions & 4 deletions QEfficient/transformers/models/gemma/modeling_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,6 @@ class QEffGemmaAttention(GemmaAttention):
- add new args cache idx for the kv retention
"""

def __qeff_init__(self):
self.rotary_emb = QEffGemmaRotaryEmbedding(config=self.config)

def forward(
self,
hidden_states: torch.Tensor,
Expand All @@ -140,6 +137,7 @@ def forward(
comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
rotary_emb: Optional[object] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
Expand All @@ -150,7 +148,7 @@ def forward(
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

if past_key_value is not None:
Expand Down Expand Up @@ -194,6 +192,7 @@ def forward(
batch_index: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
rotary_emb=None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Expand Down Expand Up @@ -223,6 +222,7 @@ def forward(
batch_index=batch_index,
use_cache=use_cache,
cache_position=cache_position,
rotary_emb=rotary_emb,
**kwargs,
)
hidden_states = residual + hidden_states
Expand Down Expand Up @@ -297,6 +297,7 @@ def forward(
# decoder layers
all_hidden_states = () if output_hidden_states else None

rotary_emb = QEffGemmaRotaryEmbedding(config=self.config)
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)
Expand All @@ -310,6 +311,7 @@ def forward(
batch_index=batch_index,
use_cache=use_cache,
cache_position=cache_position,
rotary_emb=rotary_emb,
**kwargs,
)

Expand Down
10 changes: 6 additions & 4 deletions QEfficient/transformers/models/gemma2/modeling_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,6 @@ class QEffGemma2Attention(Gemma2Attention):
- add new args cache idx for the kv retention
"""

def __qeff_init__(self):
self.rotary_emb = QEffGemma2RotaryEmbedding(config=self.config)

def forward(
self,
hidden_states: torch.Tensor,
Expand All @@ -147,6 +144,7 @@ def forward(
comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
rotary_emb: Optional[object] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
Expand All @@ -157,7 +155,7 @@ def forward(
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

if past_key_value is not None:
Expand Down Expand Up @@ -208,6 +206,7 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
rotary_emb=None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Expand Down Expand Up @@ -241,6 +240,7 @@ def forward(
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
rotary_emb=rotary_emb,
**kwargs,
)
hidden_states = self.post_attention_layernorm(hidden_states)
Expand Down Expand Up @@ -341,6 +341,7 @@ def forward(
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None

rotary_emb = QEffGemma2RotaryEmbedding(config=self.config)
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
Expand All @@ -355,6 +356,7 @@ def forward(
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
rotary_emb=rotary_emb,
**kwargs,
)

Expand Down
52 changes: 29 additions & 23 deletions QEfficient/transformers/models/gemma3/modeling_gemma3.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,27 +187,6 @@ def __init__(self, config: Gemma3Config, layer_idx: Optional[int] = None):
# Set the init in the module mapping pytorch transforms
self.__qeff_init__()

def __qeff_init__(self):
self.rotary_emb = QEffGemma3RotaryEmbedding(
self.head_dim,
self.config,
max_position_embeddings=self.config.max_position_embeddings,
base=self.config.rope_theta,
)

config = copy.deepcopy(self.config)
config.rope_theta = config.rope_local_base_freq
config.rope_scaling = {"rope_type": "default", "factor": 1.0}
self.is_local = _is_local(self.layer_idx, self.config.sliding_window_pattern)
self.window = self.config.sliding_window if self.is_local else None

self.rotary_emb_local = QEffGemma3RotaryEmbedding(
self.head_dim,
config,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
)

def forward(
self,
hidden_states: torch.Tensor,
Expand All @@ -218,6 +197,8 @@ def forward(
comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
rotary_emb: Optional[object] = None,
rotary_emb_local: Optional[object] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
Expand All @@ -240,9 +221,9 @@ def forward(
"with a layer index."
)
if self.is_sliding:
cos, sin = self.rotary_emb_local(value_states, seq_len=self.config.max_position_embeddings)
cos, sin = rotary_emb_local(value_states, seq_len=self.config.max_position_embeddings)
else:
cos, sin = self.rotary_emb(value_states, seq_len=self.config.max_position_embeddings)
cos, sin = rotary_emb(value_states, seq_len=self.config.max_position_embeddings)

query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
Expand Down Expand Up @@ -308,6 +289,8 @@ def forward(
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
last_cache_position: int = 0,
rotary_emb=None,
rotary_emb_local=None,
**kwargs,
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
Expand Down Expand Up @@ -336,6 +319,8 @@ def forward(
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
rotary_emb=rotary_emb,
rotary_emb_local=rotary_emb_local,
**kwargs,
)

Expand Down Expand Up @@ -432,6 +417,25 @@ def forward(
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None

rotary_emb = QEffGemma3RotaryEmbedding(
self.layers[0].self_attn.head_dim,
self.config,
max_position_embeddings=self.config.max_position_embeddings,
base=self.config.rope_theta,
)

config = copy.deepcopy(self.config)
config.rope_theta = config.rope_local_base_freq
config.rope_scaling = {"rope_type": "default", "factor": 1.0}

rotary_emb_local = QEffGemma3RotaryEmbedding(
self.layers[0].self_attn.head_dim,
config,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
)

for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)
Expand All @@ -447,6 +451,8 @@ def forward(
use_cache=use_cache,
cache_position=cache_position,
last_cache_position=last_cache_position,
rotary_emb=rotary_emb,
rotary_emb_local=rotary_emb_local,
**flash_attn_kwargs,
)

Expand Down
Loading
Loading