Commit 9dcb936

Fixes access rope_parameters for transformers>=5 (#330)
* Fix access rope_parameters for transformers>=5
* doc
1 parent: 9360a96

File tree: 2 files changed (+16, -5 lines)

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ Change Logs
 0.8.3
 +++++
 
+* :pr:`330`: fixes access rope_parameters for ``transformers>=5``
 * :pr:`326`: use ConcatFromSequence in LoopMHA with the loop
 * :pr:`325`: adds plug for LoopMHA, extends the unit tests to measure the discrepancies
 * :pr:`324`: supports FunctionProto with arguments in OnnxruntimeEvaluator

onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py

Lines changed: 15 additions & 5 deletions
@@ -48,16 +48,26 @@ def patched__compute_dynamic_ntk_parameters(
         max_position_embeddings = rope_kwargs["max_position_embeddings"]
         factor = rope_kwargs["factor"]
     elif config is not None:
-        base = config.rope_theta
-        partial_rotary_factor = (
-            config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-        )
+        if hasattr(config, "rope_theta"):
+            # transformers<5
+            base = config.rope_theta
+            partial_rotary_factor = (
+                config.partial_rotary_factor
+                if hasattr(config, "partial_rotary_factor")
+                else 1.0
+            )
+            factor = config.rope_scaling["factor"]
+        else:
+            print("-----")
+            print(config)
+            base = config.rope_parameters["rope_theta"]
+            partial_rotary_factor = config.rope_parameters["partial_rotary_factor"]
+            factor = config.rope_parameters["factor"]
         head_dim = getattr(
             config, "head_dim", config.hidden_size // config.num_attention_heads
         )
         dim = int(head_dim * partial_rotary_factor)
         max_position_embeddings = config.max_position_embeddings
-        factor = config.rope_scaling["factor"]
 
     attention_factor = 1.0  # Unused in this type of RoPE
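The patch stays compatible with both API generations by branching on the presence of rope_theta: configurations from transformers<5 expose rope_theta, partial_rotary_factor and rope_scaling as separate attributes, while transformers>=5 groups the same values under a rope_parameters dict. A minimal standalone sketch of that lookup follows; the helper name read_rope_parameters is illustrative only and not part of the commit.

# Sketch of a version-compatible lookup of the RoPE values read by the patch.
# Assumes config exposes either rope_theta/rope_scaling (transformers<5) or a
# rope_parameters dict (transformers>=5); the helper name is hypothetical.
def read_rope_parameters(config):
    if hasattr(config, "rope_theta"):
        # transformers<5: values are stored as separate config attributes.
        base = config.rope_theta
        partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
        factor = config.rope_scaling["factor"]
    else:
        # transformers>=5: the same values are grouped under rope_parameters.
        base = config.rope_parameters["rope_theta"]
        partial_rotary_factor = config.rope_parameters["partial_rotary_factor"]
        factor = config.rope_parameters["factor"]
    return base, partial_rotary_factor, factor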
