From 04686de0dccf3de7d5f937ffbd7503e20357534a Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 27 Jun 2025 13:42:05 +0200 Subject: [PATCH 1/9] Patch LlamaRotaryEmbedding --- .../patches/patch_transformers.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py index 017b1985..12701fab 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py @@ -814,10 +814,7 @@ def wrapper(self, x, position_ids): return wrapper -class patched_Phi3RotaryEmbedding(torch.nn.Module): - _PATCHES_ = ["forward"] - _PATCHED_CLASS_ = transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding - +class common_RotaryEmbedding(torch.nn.Module): @torch.no_grad() @patched_dynamic_rope_update def forward(self, x, position_ids): @@ -843,6 +840,16 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) +class patched_LlamaRotaryEmbedding(common_RotaryEmbedding): + _PATCHES_ = ["forward"] + _PATCHED_CLASS_ = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding + + +class patched_Phi3RotaryEmbedding(common_RotaryEmbedding): + _PATCHES_ = ["forward"] + _PATCHED_CLASS_ = transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding + + class patched_IdeficsEmbedding(torch.nn.Module): _PATCHES_ = ["forward"] _PATCHED_CLASS_ = transformers.models.idefics.modeling_idefics.IdeficsEmbedding From 87e9693f2af5687ab0d062af1b6ab52ad4aaf461 Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 27 Jun 2025 14:16:33 +0200 Subject: [PATCH 2/9] adding more patched rotary --- CHANGELOGS.rst | 1 + .../patches/patch_transformers.py | 42 +++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 12192a1b..6970b954 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.7.2 +++++ +* :pr:`170`: patches LlamaRotaryEmbedding * :pr:`168`, :pr:`169`: introduces patch_diffusers * :pr:`166`: improves handling of StaticCache * :pr:`165`: support for task text-to-image diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py index 12701fab..30f138fc 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py @@ -840,16 +840,58 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) +class patched_GemmaRotaryEmbedding(common_RotaryEmbedding): + _PATCHES_ = ["forward"] + _PATCHED_CLASS_ = transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding + + +class patched_Gemma2RotaryEmbedding(common_RotaryEmbedding): + _PATCHES_ = ["forward"] + _PATCHED_CLASS_ = transformers.models.gemma2.modeling_gemma2.Gemma2RotaryEmbedding + + +class patched_Gemma3RotaryEmbedding(common_RotaryEmbedding): + _PATCHES_ = ["forward"] + _PATCHED_CLASS_ = transformers.models.gemma3.modeling_gemma3.Gemma3RotaryEmbedding + + class patched_LlamaRotaryEmbedding(common_RotaryEmbedding): _PATCHES_ = ["forward"] _PATCHED_CLASS_ = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding +class patched_MistralRotaryEmbedding(common_RotaryEmbedding): + _PATCHES_ = ["forward"] + _PATCHED_CLASS_ = transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding + + +class patched_MixtralRotaryEmbedding(common_RotaryEmbedding): + _PATCHES_ = ["forward"] + _PATCHED_CLASS_ = transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding + + +class patched_PhiRotaryEmbedding(common_RotaryEmbedding): + _PATCHES_ = ["forward"] + _PATCHED_CLASS_ = transformers.models.phi.modeling_phi.PhiRotaryEmbedding + + class patched_Phi3RotaryEmbedding(common_RotaryEmbedding): _PATCHES_ = ["forward"] _PATCHED_CLASS_ = transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding +class patched_Phi4MultimodalRotaryEmbedding(common_RotaryEmbedding): + _PATCHES_ = ["forward"] + _PATCHED_CLASS_ = ( + transformers.models.phi4_multimodal.modeling_phi4_multimodal.Phi4MultimodalRotaryEmbedding + ) + + +class patched_SmolLM3RotaryEmbedding(common_RotaryEmbedding): + _PATCHES_ = ["forward"] + _PATCHED_CLASS_ = transformers.models.smallm3.modeling_smallm3.SmolLM3RotaryEmbedding + + class patched_IdeficsEmbedding(torch.nn.Module): _PATCHES_ = ["forward"] _PATCHED_CLASS_ = transformers.models.idefics.modeling_idefics.IdeficsEmbedding From 5723c43a9f6947847f514b8156cf1d2a8ac1ef90 Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 27 Jun 2025 14:24:22 +0200 Subject: [PATCH 3/9] fix typo --- .../torch_export_patches/patches/patch_transformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py index 30f138fc..51620691 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py @@ -889,7 +889,7 @@ class patched_Phi4MultimodalRotaryEmbedding(common_RotaryEmbedding): class patched_SmolLM3RotaryEmbedding(common_RotaryEmbedding): _PATCHES_ = ["forward"] - _PATCHED_CLASS_ = transformers.models.smallm3.modeling_smallm3.SmolLM3RotaryEmbedding + _PATCHED_CLASS_ = transformers.models.smollm3.modeling_smollm3.SmolLM3RotaryEmbedding class patched_IdeficsEmbedding(torch.nn.Module): From 053725e7255df9e66295c0c13695f0d8f8180cd6 Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 27 Jun 2025 14:31:45 +0200 Subject: [PATCH 4/9] use version --- .../torch_export_patches/patches/patch_transformers.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py index 51620691..86c35c4c 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from functools import wraps from typing import Any, Callable, Dict, List, Optional, Tuple +import packaging.version as pv import torch import transformers from transformers.modeling_attn_mask_utils import AttentionMaskConverter @@ -887,9 +888,11 @@ class patched_Phi4MultimodalRotaryEmbedding(common_RotaryEmbedding): ) -class patched_SmolLM3RotaryEmbedding(common_RotaryEmbedding): - _PATCHES_ = ["forward"] - _PATCHED_CLASS_ = transformers.models.smollm3.modeling_smollm3.SmolLM3RotaryEmbedding +if pv.Version(transformers.__version__) >= pv.Version("4.53"): + + class patched_SmolLM3RotaryEmbedding(common_RotaryEmbedding): + _PATCHES_ = ["forward"] + _PATCHED_CLASS_ = transformers.models.smollm3.modeling_smollm3.SmolLM3RotaryEmbedding class patched_IdeficsEmbedding(torch.nn.Module): From ecec56d5abfa79740d6b548e7bae9237666b5b36 Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 27 Jun 2025 14:45:03 +0200 Subject: [PATCH 5/9] fix version --- .../patches/patch_transformers.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py index 86c35c4c..b577a3c1 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py @@ -846,14 +846,16 @@ class patched_GemmaRotaryEmbedding(common_RotaryEmbedding): _PATCHED_CLASS_ = transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding -class patched_Gemma2RotaryEmbedding(common_RotaryEmbedding): - _PATCHES_ = ["forward"] - _PATCHED_CLASS_ = transformers.models.gemma2.modeling_gemma2.Gemma2RotaryEmbedding +if pv.Version(transformers.__version__) >= pv.Version("4.53"): + + class patched_Gemma2RotaryEmbedding(common_RotaryEmbedding): + _PATCHES_ = ["forward"] + _PATCHED_CLASS_ = transformers.models.gemma2.modeling_gemma2.Gemma2RotaryEmbedding -class patched_Gemma3RotaryEmbedding(common_RotaryEmbedding): - _PATCHES_ = ["forward"] - _PATCHED_CLASS_ = transformers.models.gemma3.modeling_gemma3.Gemma3RotaryEmbedding + class patched_Gemma3RotaryEmbedding(common_RotaryEmbedding): + _PATCHES_ = ["forward"] + _PATCHED_CLASS_ = transformers.models.gemma3.modeling_gemma3.Gemma3RotaryEmbedding class patched_LlamaRotaryEmbedding(common_RotaryEmbedding): From a4d323cfeace5823c6ce2c709e985fa2a3717b68 Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 27 Jun 2025 14:47:55 +0200 Subject: [PATCH 6/9] typo --- .../torch_export_patches/patches/patch_transformers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py index b577a3c1..1f2f25d3 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py @@ -852,7 +852,6 @@ class patched_Gemma2RotaryEmbedding(common_RotaryEmbedding): _PATCHES_ = ["forward"] _PATCHED_CLASS_ = transformers.models.gemma2.modeling_gemma2.Gemma2RotaryEmbedding - class patched_Gemma3RotaryEmbedding(common_RotaryEmbedding): _PATCHES_ = ["forward"] _PATCHED_CLASS_ = transformers.models.gemma3.modeling_gemma3.Gemma3RotaryEmbedding From efd7b5eb880b10d70b1a747c3057a896e8938a1a Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 27 Jun 2025 15:00:00 +0200 Subject: [PATCH 7/9] fix issues --- .../patches/patch_transformers.py | 22 +++++++++---------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py index 1f2f25d3..5d1d0819 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py @@ -846,7 +846,7 @@ class patched_GemmaRotaryEmbedding(common_RotaryEmbedding): _PATCHED_CLASS_ = transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding -if pv.Version(transformers.__version__) >= pv.Version("4.53"): +if pv.Version(transformers.__version__) >= pv.Version("4.52"): class patched_Gemma2RotaryEmbedding(common_RotaryEmbedding): _PATCHES_ = ["forward"] @@ -877,19 +877,17 @@ class patched_PhiRotaryEmbedding(common_RotaryEmbedding): _PATCHED_CLASS_ = transformers.models.phi.modeling_phi.PhiRotaryEmbedding -class patched_Phi3RotaryEmbedding(common_RotaryEmbedding): - _PATCHES_ = ["forward"] - _PATCHED_CLASS_ = transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding - - -class patched_Phi4MultimodalRotaryEmbedding(common_RotaryEmbedding): - _PATCHES_ = ["forward"] - _PATCHED_CLASS_ = ( - transformers.models.phi4_multimodal.modeling_phi4_multimodal.Phi4MultimodalRotaryEmbedding - ) +if pv.Version(transformers.__version__) >= pv.Version("4.52"): + class patched_Phi3RotaryEmbedding(common_RotaryEmbedding): + _PATCHES_ = ["forward"] + _PATCHED_CLASS_ = transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding -if pv.Version(transformers.__version__) >= pv.Version("4.53"): + class patched_Phi4MultimodalRotaryEmbedding(common_RotaryEmbedding): + _PATCHES_ = ["forward"] + _PATCHED_CLASS_ = ( + transformers.models.phi4_multimodal.modeling_phi4_multimodal.Phi4MultimodalRotaryEmbedding + ) class patched_SmolLM3RotaryEmbedding(common_RotaryEmbedding): _PATCHES_ = ["forward"] From 10c2a92c2a968b828719a7e1cd76ae5f0b9fa724 Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 27 Jun 2025 17:22:52 +0200 Subject: [PATCH 8/9] fix --- .../torch_export_patches/patches/patch_transformers.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py index 5d1d0819..e4b15154 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py @@ -889,6 +889,9 @@ class patched_Phi4MultimodalRotaryEmbedding(common_RotaryEmbedding): transformers.models.phi4_multimodal.modeling_phi4_multimodal.Phi4MultimodalRotaryEmbedding ) + +if pv.Version(transformers.__version__) >= pv.Version("4.53"): + class patched_SmolLM3RotaryEmbedding(common_RotaryEmbedding): _PATCHES_ = ["forward"] _PATCHED_CLASS_ = transformers.models.smollm3.modeling_smollm3.SmolLM3RotaryEmbedding From 50c9f9da3b2cac91fea0b20d944fc8c664f1163b Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 27 Jun 2025 17:40:11 +0200 Subject: [PATCH 9/9] g --- .../torch_export_patches/patches/patch_transformers.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py index e4b15154..7ddbaa06 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py @@ -877,12 +877,15 @@ class patched_PhiRotaryEmbedding(common_RotaryEmbedding): _PATCHED_CLASS_ = transformers.models.phi.modeling_phi.PhiRotaryEmbedding -if pv.Version(transformers.__version__) >= pv.Version("4.52"): +if pv.Version(transformers.__version__) >= pv.Version("4.51"): class patched_Phi3RotaryEmbedding(common_RotaryEmbedding): _PATCHES_ = ["forward"] _PATCHED_CLASS_ = transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding + +if pv.Version(transformers.__version__) >= pv.Version("4.52"): + class patched_Phi4MultimodalRotaryEmbedding(common_RotaryEmbedding): _PATCHES_ = ["forward"] _PATCHED_CLASS_ = (