
Commit 6eb85b7

Patches LlamaRotaryEmbedding (#170)

* Patch LlamaRotaryEmbedding
* adding more patched rotary
* fix typo
* use version
* fix version
* typo
* fix issues
* fix
* g

1 parent 73f5f8f

File tree: 2 files changed (+62 −4 lines)


CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ Change Logs
 0.7.2
 +++++
 
+* :pr:`170`: patches LlamaRotaryEmbedding
 * :pr:`168`, :pr:`169`: introduces patch_diffusers
 * :pr:`166`: improves handling of StaticCache
 * :pr:`165`: support for task text-to-image

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 61 additions & 4 deletions
@@ -2,6 +2,7 @@
 from dataclasses import dataclass
 from functools import wraps
 from typing import Any, Callable, Dict, List, Optional, Tuple
+import packaging.version as pv
 import torch
 import transformers
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
@@ -814,10 +815,7 @@ def wrapper(self, x, position_ids):
     return wrapper
 
 
-class patched_Phi3RotaryEmbedding(torch.nn.Module):
-    _PATCHES_ = ["forward"]
-    _PATCHED_CLASS_ = transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding
-
+class common_RotaryEmbedding(torch.nn.Module):
     @torch.no_grad()
     @patched_dynamic_rope_update
     def forward(self, x, position_ids):
@@ -843,6 +841,65 @@ def forward(self, x, position_ids):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
+class patched_GemmaRotaryEmbedding(common_RotaryEmbedding):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding
+
+
+if pv.Version(transformers.__version__) >= pv.Version("4.52"):
+
+    class patched_Gemma2RotaryEmbedding(common_RotaryEmbedding):
+        _PATCHES_ = ["forward"]
+        _PATCHED_CLASS_ = transformers.models.gemma2.modeling_gemma2.Gemma2RotaryEmbedding
+
+    class patched_Gemma3RotaryEmbedding(common_RotaryEmbedding):
+        _PATCHES_ = ["forward"]
+        _PATCHED_CLASS_ = transformers.models.gemma3.modeling_gemma3.Gemma3RotaryEmbedding
+
+
+class patched_LlamaRotaryEmbedding(common_RotaryEmbedding):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding
+
+
+class patched_MistralRotaryEmbedding(common_RotaryEmbedding):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding
+
+
+class patched_MixtralRotaryEmbedding(common_RotaryEmbedding):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding
+
+
+class patched_PhiRotaryEmbedding(common_RotaryEmbedding):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.phi.modeling_phi.PhiRotaryEmbedding
+
+
+if pv.Version(transformers.__version__) >= pv.Version("4.51"):
+
+    class patched_Phi3RotaryEmbedding(common_RotaryEmbedding):
+        _PATCHES_ = ["forward"]
+        _PATCHED_CLASS_ = transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding
+
+
+if pv.Version(transformers.__version__) >= pv.Version("4.52"):
+
+    class patched_Phi4MultimodalRotaryEmbedding(common_RotaryEmbedding):
+        _PATCHES_ = ["forward"]
+        _PATCHED_CLASS_ = (
+            transformers.models.phi4_multimodal.modeling_phi4_multimodal.Phi4MultimodalRotaryEmbedding
+        )
+
+
+if pv.Version(transformers.__version__) >= pv.Version("4.53"):
+
+    class patched_SmolLM3RotaryEmbedding(common_RotaryEmbedding):
+        _PATCHES_ = ["forward"]
+        _PATCHED_CLASS_ = transformers.models.smollm3.modeling_smollm3.SmolLM3RotaryEmbedding
+
+
 class patched_IdeficsEmbedding(torch.nn.Module):
     _PATCHES_ = ["forward"]
     _PATCHED_CLASS_ = transformers.models.idefics.modeling_idefics.IdeficsEmbedding
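
Each patch class above carries two markers: `_PATCHED_CLASS_` names the transformers class to modify, and `_PATCHES_` lists the methods to replace on it. The diff does not show how these markers are consumed; the sketch below is a minimal, hypothetical applier illustrating the mechanism (the real one lives elsewhere in onnx_diagnostic and may differ):

```python
from typing import Any, Dict


def apply_patch(patch_cls: type) -> Dict[str, Any]:
    """Swap the methods listed in ``_PATCHES_`` onto ``_PATCHED_CLASS_``.

    Hypothetical helper, not the library's actual API. Returns the
    original attributes so the patch can be undone later.
    """
    target = patch_cls._PATCHED_CLASS_
    saved = {}
    for name in patch_cls._PATCHES_:
        saved[name] = getattr(target, name)
        # e.g. LlamaRotaryEmbedding.forward is replaced by the
        # forward inherited from common_RotaryEmbedding
        setattr(target, name, getattr(patch_cls, name))
    return saved


def undo_patch(patch_cls: type, saved: Dict[str, Any]) -> None:
    """Restore the attributes captured by ``apply_patch``."""
    for name, attr in saved.items():
        setattr(patch_cls._PATCHED_CLASS_, name, attr)
```

Because every rotary patch class now inherits its `forward` from `common_RotaryEmbedding`, applying `patched_LlamaRotaryEmbedding`, `patched_MistralRotaryEmbedding`, and the rest installs the same export-friendly implementation on each model's rotary embedding, while the `pv.Version(transformers.__version__)` guards ensure a patch class is only defined when the installed transformers release actually ships its target class.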
