Skip to content

Commit 895d987

Browse files
RoPE Fix for a single subfunction signature (#880)
## Summary This PR introduces the **Rotary Position Embedding (RoPE) fix**, ensuring that models generate a **single unified subfunction signature** during ONNX export. ## Model Status After Applying the Fix ### Models now producing a single subfunction signature _All causal LMs tested in the associated test file are functioning correctly, except those listed below._ ### Models still producing **two different subfunction signatures** The following models continue to emit multiple subfunction signatures and require additional investigation: - [ ] Phi-1 - [ ] StarCoder2 - [ ] CodeGen ### Models with issues **unrelated** to the RoPE fix These models have separate problems that need to be addressed independently: - [ ] Granite-MoE - [ ] GPT-OSS - [ ] Mixtral --------- Signed-off-by: abhishek-singh591 <sabhis@qti.qualcomm.com> Signed-off-by: Onkar Chougule <ochougul@qti.qualcomm.com> Co-authored-by: Onkar Chougule <ochougul@qti.qualcomm.com>
1 parent 32a2235 commit 895d987

File tree

18 files changed

+437
-302
lines changed

18 files changed

+437
-302
lines changed

QEfficient/transformers/models/falcon/modeling_falcon.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,6 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
5959
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
6060
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
6161

62-
def forward(self, x, seq_len=None):
63-
# x: [bs, num_attention_heads, seq_len, head_size]
64-
if seq_len > self.max_seq_len_cached:
65-
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
66-
67-
return (
68-
self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling,
69-
self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling,
70-
)
71-
7262

7363
def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
7464
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -108,9 +98,6 @@ class QEffFalconAttention(FalconAttention):
10898
- add new args position idx for the cache_kwargs for kv retention
10999
"""
110100

111-
def __qeff_init__(self):
112-
self.rotary_emb = QEffFalconRotaryEmbedding(config=self.config)
113-
114101
def forward(
115102
self,
116103
hidden_states: torch.Tensor,
@@ -125,6 +112,8 @@ def forward(
125112
use_cache: bool = False,
126113
output_attentions: bool = False,
127114
cache_position: Optional[torch.LongTensor] = None,
115+
cos_cached: Optional[torch.Tensor] = None,
116+
sin_cached: Optional[torch.Tensor] = None,
128117
):
129118
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
130119
num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
@@ -137,9 +126,8 @@ def forward(
137126
key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
138127
value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
139128

140-
kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
141-
cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len)
142-
query_layer, key_layer = qeff_apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids)
129+
# kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
130+
query_layer, key_layer = qeff_apply_rotary_pos_emb(query_layer, key_layer, cos_cached, sin_cached, position_ids)
143131

144132
if layer_past is not None:
145133
cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids}
@@ -184,6 +172,8 @@ def forward(
184172
use_cache: bool = False,
185173
output_attentions: bool = False,
186174
cache_position: Optional[torch.LongTensor] = None,
175+
sin_cached=None,
176+
cos_cached=None,
187177
**kwargs,
188178
):
189179
residual = hidden_states
@@ -208,6 +198,8 @@ def forward(
208198
use_cache=use_cache,
209199
output_attentions=output_attentions,
210200
cache_position=cache_position,
201+
sin_cached=sin_cached,
202+
cos_cached=cos_cached,
211203
)
212204

213205
if not self.config.new_decoder_architecture:
@@ -245,6 +237,11 @@ class QEffFalconModel(FalconModel):
245237
- update causal attention mask
246238
"""
247239

240+
def __qeff_init__(self):
241+
self.rotary_emb = QEffFalconRotaryEmbedding(config=self.config)
242+
self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling)
243+
self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * self.rotary_emb.attention_scaling)
244+
248245
def forward(
249246
self,
250247
input_ids: torch.LongTensor = None,
@@ -322,6 +319,8 @@ def forward(
322319
output_attentions=output_attentions,
323320
alibi=alibi,
324321
cache_position=cache_position,
322+
sin_cached=self.sin_cached,
323+
cos_cached=self.cos_cached,
325324
)
326325

327326
hidden_states = outputs[0]

QEfficient/transformers/models/gemma/modeling_gemma.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -55,16 +55,6 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
5555
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
5656
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
5757

58-
def forward(self, x, seq_len=None):
59-
# x: [bs, num_attention_heads, seq_len, head_size]
60-
if seq_len > self.max_seq_len_cached:
61-
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
62-
63-
return (
64-
self.cos_cached[:seq_len].to(dtype=x.dtype),
65-
self.sin_cached[:seq_len].to(dtype=x.dtype),
66-
)
67-
6858

6959
def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
7060
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -128,9 +118,6 @@ class QEffGemmaAttention(GemmaAttention):
128118
- add new args cache idx for the kv retention
129119
"""
130120

131-
def __qeff_init__(self):
132-
self.rotary_emb = QEffGemmaRotaryEmbedding(config=self.config)
133-
134121
def forward(
135122
self,
136123
hidden_states: torch.Tensor,
@@ -140,6 +127,8 @@ def forward(
140127
comp_ctx_lengths: Optional[torch.LongTensor] = None,
141128
batch_index: Optional[torch.LongTensor] = None,
142129
cache_position: Optional[torch.LongTensor] = None,
130+
cos_cached: Optional[torch.Tensor] = None,
131+
sin_cached: Optional[torch.Tensor] = None,
143132
**kwargs,
144133
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
145134
input_shape = hidden_states.shape[:-1]
@@ -149,9 +138,10 @@ def forward(
149138
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
150139
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
151140

152-
kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
153-
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
154-
query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
141+
# kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
142+
query_states, key_states = qeff_apply_rotary_pos_emb(
143+
query_states, key_states, cos_cached, sin_cached, position_ids
144+
)
155145

156146
if past_key_value is not None:
157147
cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids}
@@ -194,6 +184,8 @@ def forward(
194184
batch_index: Optional[torch.LongTensor] = None,
195185
use_cache: Optional[bool] = False,
196186
cache_position: Optional[torch.LongTensor] = None,
187+
sin_cached=None,
188+
cos_cached=None,
197189
**kwargs,
198190
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
199191
"""
@@ -223,6 +215,8 @@ def forward(
223215
batch_index=batch_index,
224216
use_cache=use_cache,
225217
cache_position=cache_position,
218+
sin_cached=sin_cached,
219+
cos_cached=cos_cached,
226220
**kwargs,
227221
)
228222
hidden_states = residual + hidden_states
@@ -243,6 +237,11 @@ class QEffGemmaModel(GemmaModel):
243237
- add new args cache idx for the kv retention
244238
"""
245239

240+
def __qeff_init__(self):
241+
self.rotary_emb = QEffGemmaRotaryEmbedding(config=self.config)
242+
self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached)
243+
self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached)
244+
246245
def forward(
247246
self,
248247
input_ids: torch.LongTensor = None,
@@ -310,6 +309,8 @@ def forward(
310309
batch_index=batch_index,
311310
use_cache=use_cache,
312311
cache_position=cache_position,
312+
sin_cached=self.sin_cached,
313+
cos_cached=self.cos_cached,
313314
**kwargs,
314315
)
315316

QEfficient/transformers/models/gemma2/modeling_gemma2.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -58,16 +58,6 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
5858
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
5959
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
6060

61-
def forward(self, x, seq_len=None):
62-
# x: [bs, num_attention_heads, seq_len, head_size]
63-
if seq_len > self.max_seq_len_cached:
64-
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
65-
66-
return (
67-
self.cos_cached[:seq_len].to(dtype=x.dtype),
68-
self.sin_cached[:seq_len].to(dtype=x.dtype),
69-
)
70-
7161

7262
def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
7363
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -135,9 +125,6 @@ class QEffGemma2Attention(Gemma2Attention):
135125
- add new args cache idx for the kv retention
136126
"""
137127

138-
def __qeff_init__(self):
139-
self.rotary_emb = QEffGemma2RotaryEmbedding(config=self.config)
140-
141128
def forward(
142129
self,
143130
hidden_states: torch.Tensor,
@@ -147,6 +134,8 @@ def forward(
147134
comp_ctx_lengths: Optional[torch.LongTensor] = None,
148135
batch_index: Optional[torch.LongTensor] = None,
149136
cache_position: Optional[torch.LongTensor] = None,
137+
cos_cached: Optional[torch.Tensor] = None,
138+
sin_cached: Optional[torch.Tensor] = None,
150139
**kwargs,
151140
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
152141
input_shape = hidden_states.shape[:-1]
@@ -156,15 +145,16 @@ def forward(
156145
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
157146
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
158147

159-
kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
160-
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
161-
query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
148+
# kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
149+
query_states, key_states = qeff_apply_rotary_pos_emb(
150+
query_states, key_states, cos_cached, sin_cached, position_ids
151+
)
162152

163153
if past_key_value is not None:
164154
# sin and cos are specific to RoPE models; cache_position needed for the static cache
165155
cache_kwargs = {
166-
"sin": sin,
167-
"cos": cos,
156+
"sin": sin_cached,
157+
"cos": cos_cached,
168158
"batch_index": batch_index,
169159
"position_ids": position_ids,
170160
}
@@ -208,6 +198,8 @@ def forward(
208198
output_attentions: Optional[bool] = False,
209199
use_cache: Optional[bool] = False,
210200
cache_position: Optional[torch.LongTensor] = None,
201+
sin_cached=None,
202+
cos_cached=None,
211203
**kwargs,
212204
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
213205
"""
@@ -241,6 +233,8 @@ def forward(
241233
output_attentions=output_attentions,
242234
use_cache=use_cache,
243235
cache_position=cache_position,
236+
sin_cached=sin_cached,
237+
cos_cached=cos_cached,
244238
**kwargs,
245239
)
246240
hidden_states = self.post_attention_layernorm(hidden_states)
@@ -271,6 +265,11 @@ class QEffGemma2Model(Gemma2Model):
271265
- add new args cache idx for the kv retention
272266
"""
273267

268+
def __qeff_init__(self):
269+
self.rotary_emb = QEffGemma2RotaryEmbedding(config=self.config)
270+
self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached)
271+
self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached)
272+
274273
def forward(
275274
self,
276275
input_ids: torch.LongTensor = None,
@@ -355,6 +354,8 @@ def forward(
355354
output_attentions=output_attentions,
356355
use_cache=use_cache,
357356
cache_position=cache_position,
357+
sin_cached=self.sin_cached,
358+
cos_cached=self.cos_cached,
358359
**kwargs,
359360
)
360361

0 commit comments

Comments (0)