Skip to content

Commit ecec56d

Browse files
committed
fix version
1 parent 053725e commit ecec56d

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -846,14 +846,16 @@ class patched_GemmaRotaryEmbedding(common_RotaryEmbedding):
846846
_PATCHED_CLASS_ = transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding
847847

848848

849-
class patched_Gemma2RotaryEmbedding(common_RotaryEmbedding):
850-
_PATCHES_ = ["forward"]
851-
_PATCHED_CLASS_ = transformers.models.gemma2.modeling_gemma2.Gemma2RotaryEmbedding
849+
if pv.Version(transformers.__version__) >= pv.Version("4.53"):
850+
851+
class patched_Gemma2RotaryEmbedding(common_RotaryEmbedding):
852+
_PATCHES_ = ["forward"]
853+
_PATCHED_CLASS_ = transformers.models.gemma2.modeling_gemma2.Gemma2RotaryEmbedding
852854

853855

854-
class patched_Gemma3RotaryEmbedding(common_RotaryEmbedding):
855-
_PATCHES_ = ["forward"]
856-
_PATCHED_CLASS_ = transformers.models.gemma3.modeling_gemma3.Gemma3RotaryEmbedding
856+
class patched_Gemma3RotaryEmbedding(common_RotaryEmbedding):
857+
_PATCHES_ = ["forward"]
858+
_PATCHED_CLASS_ = transformers.models.gemma3.modeling_gemma3.Gemma3RotaryEmbedding
857859

858860

859861
class patched_LlamaRotaryEmbedding(common_RotaryEmbedding):

0 commit comments

Comments
 (0)