
Commit d19c66b

fix rotary patch
1 parent deb39ec commit d19c66b

File tree

1 file changed (+60 -19 lines)

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 60 additions & 19 deletions
@@ -1019,7 +1019,7 @@ def patched__compute_dynamic_ntk_parameters(
     return inv_freq, attention_factor
 
 
-def _get_rope_init_fn(self) -> Callable:
+def _get_rope_init_fn(self, layer_type=None) -> Callable:
     if hasattr(self, "rope_init_fn"):
         # transformers<=5.0
         rope_init_fn = (
@@ -1030,8 +1030,9 @@ def _get_rope_init_fn(self) -> Callable:
         )
         return rope_init_fn
 
+    rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
     rope_init_fn = self.compute_default_rope_parameters
-    if self.rope_type != "default":
+    if rope_type != "default":
         rope_init_fn = transformers.modeling_rope_utils.ROPE_INIT_FUNCTIONS[self.rope_type]
     if rope_init_fn is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters:
         return patched__compute_dynamic_ntk_parameters
@@ -1101,17 +1102,27 @@ def wrapper(self, x, position_ids):
 
     """
 
-    def longrope_frequency_update(self, position_ids, device):
+    def longrope_frequency_update(self, position_ids, device, layer_type=None):
         # It is no use to patch the function after the model is created
         # as rope_init_fn is an attribute set to one function when the model
         # is created and when no patch is applied yet.
         # So we select the patched version here.
-        rope_init_fn = _get_rope_init_fn(self)
+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
         seq_len = torch.max(position_ids) + 1
         if hasattr(self.config, "original_max_position_embeddings"):
             original_max_position_embeddings = self.config.original_max_position_embeddings
         else:
             original_max_position_embeddings = self.config.max_position_embeddings
+
+        if layer_type is None:
+            # rope_type = self.rope_type
+            original_inv_freq = self.original_inv_freq
+            prefix = ""
+        else:
+            # rope_type = self.rope_type[layer_type]
+            original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
+            prefix = f"{layer_type}_"
+
         # At export time, seq_len is unknown.
         long_inv_freq, _ = rope_init_fn(
             self.config, device, seq_len=original_max_position_embeddings + 1
@@ -1126,13 +1137,13 @@ def longrope_frequency_update(self, position_ids, device):
             (lambda x, y: y.clone()),
             [long_inv_freq, original_inv_freq],
         )
-        self.inv_freq = inv_freq
+        setattr(self, f"{prefix}inv_freq", inv_freq)
         # if seq_len > original_max_position_embeddings:
         #     self.inv_freq = self.long_inv_freq
         # else:
         #     self.inv_freq = self.original_inv_freq
 
-    def dynamic_frequency_update(self, position_ids, device):
+    def dynamic_frequency_update(self, position_ids, device, layer_type=None):
         # constructor:
         # - self.max_seq_len_cached = config.max_position_embeddings
         # - self.original_max_seq_len = config.max_position_embeddings
@@ -1142,7 +1153,7 @@ def dynamic_frequency_update(self, position_ids, device):
         # as rope_init_fn is an attribute set to one function when the model
         # is created and when no patch is applied yet.
         # So we select the patched version here.
-        rope_init_fn = _get_rope_init_fn(self)
+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
 
         # This behaviour is difficult to translate.
         # The sequence always grows.
@@ -1171,6 +1182,19 @@ def dynamic_frequency_update(self, position_ids, device):
             self.config, device, seq_len=seq_len
         )
 
+        if layer_type is None:
+            # rope_type = self.rope_type
+            # max_seq_len_cached = self.max_seq_len_cached
+            original_inv_freq = self.original_inv_freq
+            prefix = ""
+        else:
+            # rope_type = self.rope_type[layer_type]
+            # max_seq_len_cached = getattr(
+            #     self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
+            # )
+            original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
+            prefix = f"{layer_type}_"
+
         # Second test to translate.
         # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
         # But in that case the following condition is a way to restore the original cache.
@@ -1192,15 +1216,26 @@ def dynamic_frequency_update(self, position_ids, device):
             (lambda x, y: y.clone()),
             [long_inv_freq, original_inv_freq],
         )
-        self.inv_freq = inv_freq
+        setattr(self, f"{prefix}inv_freq", inv_freq)
 
     @wraps(rope_forward)
-    def wrapper(self, x, position_ids):
+    def wrapper(self, x, position_ids, layer_type=None):
+        if layer_type is None:
+            if "dynamic" in self.rope_type:
+                dynamic_frequency_update(self, position_ids, device=x.device)
+            elif self.rope_type == "longrope":
+                longrope_frequency_update(self, position_ids, device=x.device)
+            return rope_forward(self, x, position_ids)
+
         if "dynamic" in self.rope_type:
-            dynamic_frequency_update(self, position_ids, device=x.device)
+            dynamic_frequency_update(
+                self, position_ids, device=x.device, layer_type=layer_type
+            )
         elif self.rope_type == "longrope":
-            longrope_frequency_update(self, position_ids, device=x.device)
-        return rope_forward(self, x, position_ids)
+            longrope_frequency_update(
+                self, position_ids, device=x.device, layer_type=layer_type
+            )
+        return rope_forward(self, x, position_ids, layer_type=layer_type)
 
     return wrapper
 
@@ -1296,12 +1331,18 @@ class common_RotaryEmbedding(torch.nn.Module):
     # @torch.no_grad()
     # PATCHED: the decorator
    @patched_dynamic_rope_update
-    def forward(self, x, position_ids):
+    def forward(self, x, position_ids, layer_type=None):
+        if layer_type is not None:
+            # transformers>=5.0
+            inv_freq = getattr(self, f"{layer_type}_inv_freq")
+            attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
+        else:
+            # transformers<5.0
+            inv_freq = self.inv_freq
+            attention_scaling = self.attention_scaling
+
         inv_freq_expanded = (
-            self.inv_freq[None, :, None]
-            .float()
-            .expand(position_ids.shape[0], -1, 1)
-            .to(x.device)
+            inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
         )
         position_ids_expanded = position_ids[:, None, :].float()
 
@@ -1313,8 +1354,8 @@ def forward(self, x, position_ids):
         with torch.autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
-            cos = emb.cos() * self.attention_scaling
-            sin = emb.sin() * self.attention_scaling
+            cos = emb.cos() * attention_scaling
+            sin = emb.sin() * attention_scaling
 
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
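For context, a minimal sketch (not part of this commit) of the layer_type-aware attribute lookup the patched forward relies on: in newer transformers the rotary module keeps one set of buffers per layer type under names like f"{layer_type}_inv_freq", and the patch builds a prefix from layer_type to reach them. The ToyRotary class and the "sliding_attention" value are illustrative assumptions, not names taken from the commit.

import torch


class ToyRotary(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # One inv_freq buffer per layer type, mirroring the f"{layer_type}_inv_freq"
        # naming convention used by the patched forward for transformers>=5.0.
        self.register_buffer("sliding_attention_inv_freq", torch.rand(4))
        self.sliding_attention_attention_scaling = 1.0

    def forward(self, x, position_ids, layer_type=None):
        # Same dispatch as the patched forward: pick the per-layer-type buffers
        # when layer_type is given, fall back to the plain attributes otherwise.
        prefix = "" if layer_type is None else f"{layer_type}_"
        inv_freq = getattr(self, f"{prefix}inv_freq")
        attention_scaling = getattr(self, f"{prefix}attention_scaling")
        inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        freqs = (inv_freq_expanded @ position_ids[:, None, :].float()).transpose(1, 2)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos() * attention_scaling, emb.sin() * attention_scaling


cos, sin = ToyRotary()(
    torch.randn(1, 3, 8), torch.arange(3)[None, :], layer_type="sliding_attention"
)
print(cos.shape, sin.shape)  # torch.Size([1, 3, 8]) torch.Size([1, 3, 8])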
