
Commit a656ff1

Fix access rope_parameters for transformers>=5
1 parent 9360a96 commit a656ff1

File tree

1 file changed (+15, -5 lines)


onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py

Lines changed: 15 additions & 5 deletions
@@ -48,16 +48,26 @@ def patched__compute_dynamic_ntk_parameters(
         max_position_embeddings = rope_kwargs["max_position_embeddings"]
         factor = rope_kwargs["factor"]
     elif config is not None:
-        base = config.rope_theta
-        partial_rotary_factor = (
-            config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-        )
+        if hasattr(config, "rope_theta"):
+            # transformers<5
+            base = config.rope_theta
+            partial_rotary_factor = (
+                config.partial_rotary_factor
+                if hasattr(config, "partial_rotary_factor")
+                else 1.0
+            )
+            factor = config.rope_scaling["factor"]
+        else:
+            print("-----")
+            print(config)
+            base = config.rope_parameters["rope_theta"]
+            partial_rotary_factor = config.rope_parameters["partial_rotary_factor"]
+            factor = config.rope_parameters["factor"]
         head_dim = getattr(
             config, "head_dim", config.hidden_size // config.num_attention_heads
         )
         dim = int(head_dim * partial_rotary_factor)
         max_position_embeddings = config.max_position_embeddings
-        factor = config.rope_scaling["factor"]

         attention_factor = 1.0  # Unused in this type of RoPE
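For reference, a minimal sketch of the version dispatch this commit implements, using the same config fields the diff reads. The helper name get_dynamic_ntk_rope_inputs is hypothetical and not part of the patched file; it assumes, as the diff does, that transformers<5 exposes rope_theta and rope_scaling directly on the config while transformers>=5 groups them under config.rope_parameters.

def get_dynamic_ntk_rope_inputs(config):
    # Hypothetical helper mirroring the patched branch above.
    if hasattr(config, "rope_theta"):
        # transformers<5: flat attributes on the config.
        base = config.rope_theta
        partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
        factor = config.rope_scaling["factor"]
    else:
        # transformers>=5: values grouped under config.rope_parameters.
        base = config.rope_parameters["rope_theta"]
        partial_rotary_factor = config.rope_parameters["partial_rotary_factor"]
        factor = config.rope_parameters["factor"]
    return base, partial_rotary_factor, factor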
