Skip to content

Commit fbbf946

Browse files
committed
fix rope
1 parent 29feec4 commit fbbf946

File tree

1 file changed

+21
-12
lines changed

1 file changed

+21
-12
lines changed

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1019,6 +1019,25 @@ def patched__compute_dynamic_ntk_parameters(
10191019
return inv_freq, attention_factor
10201020

10211021

1022+
def _get_rope_init_fn(self) -> Callable:
    """Return the rope initialization function to use for ``self``.

    The function is taken from ``self.rope_init_fn`` when that attribute
    exists. Otherwise it is ``self.compute_default_rope_parameters`` for
    the ``"default"`` rope type, or the entry of
    ``transformers.modeling_rope_utils.ROPE_INIT_FUNCTIONS`` registered
    under ``self.rope_type`` for any other type.

    Whenever the selected function is
    ``transformers.modeling_rope_utils._compute_dynamic_ntk_parameters``,
    ``patched__compute_dynamic_ntk_parameters`` is returned in its place;
    any other function is returned unchanged.
    """
    if hasattr(self, "rope_init_fn"):
        # transformers<=5.0
        candidate = self.rope_init_fn
    else:
        # No ``rope_init_fn`` attribute: start from the object's own default
        # and switch to the registry entry for a non-default rope type.
        candidate = self.compute_default_rope_parameters
        if self.rope_type != "default":
            candidate = transformers.modeling_rope_utils.ROPE_INIT_FUNCTIONS[self.rope_type]

    # Single substitution point: swap the original dynamic NTK function
    # for its patched counterpart, leave everything else as is.
    if candidate is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters:
        return patched__compute_dynamic_ntk_parameters
    return candidate
1039+
1040+
10221041
def patched_dynamic_rope_update(rope_forward):
10231042
"""manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
10241043
@@ -1087,12 +1106,7 @@ def longrope_frequency_update(self, position_ids, device):
10871106
# as rope_init_fn is an attribute set to one function when the model
10881107
# is created and when no patch is applied yet.
10891108
# So we select the patched version here.
1090-
rope_init_fn = (
1091-
patched__compute_dynamic_ntk_parameters
1092-
if self.rope_init_fn
1093-
is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters
1094-
else self.rope_init_fn
1095-
)
1109+
rope_init_fn = _get_rope_init_fn(self)
10961110
seq_len = torch.max(position_ids) + 1
10971111
if hasattr(self.config, "original_max_position_embeddings"):
10981112
original_max_position_embeddings = self.config.original_max_position_embeddings
@@ -1128,12 +1142,7 @@ def dynamic_frequency_update(self, position_ids, device):
11281142
# as rope_init_fn is an attribute set to one function when the model
11291143
# is created and when no patch is applied yet.
11301144
# So we select the patched version here.
1131-
rope_init_fn = (
1132-
patched__compute_dynamic_ntk_parameters
1133-
if self.rope_init_fn
1134-
is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters
1135-
else self.rope_init_fn
1136-
)
1145+
rope_init_fn = _get_rope_init_fn(self)
11371146

11381147
# This behaviour is difficult to translate.
11391148
# The sequence always grows.

0 commit comments

Comments
 (0)