Skip to content

Commit 87e9693

Browse files
committed
adding more patched rotary
1 parent 04686de commit 87e9693

File tree

2 files changed

+43
-0
lines changed

2 files changed

+43
-0
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.7.2
55
+++++
66

7+
* :pr:`170`: patches LlamaRotaryEmbedding
78
* :pr:`168`, :pr:`169`: introduces patch_diffusers
89
* :pr:`166`: improves handling of StaticCache
910
* :pr:`165`: support for task text-to-image

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -840,16 +840,58 @@ def forward(self, x, position_ids):
840840
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
841841

842842

843+
class patched_GemmaRotaryEmbedding(common_RotaryEmbedding):
    """Patch descriptor for Gemma's rotary embedding.

    Replaces ``forward`` on the patched class with the export-friendly
    implementation inherited from ``common_RotaryEmbedding``.
    """

    # Methods of the target class that are swapped out by the patching machinery.
    _PATCHES_ = ["forward"]
    # The transformers class whose methods are replaced.
    _PATCHED_CLASS_ = transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding
846+
847+
848+
class patched_Gemma2RotaryEmbedding(common_RotaryEmbedding):
    """Patch descriptor for Gemma2's rotary embedding.

    Replaces ``forward`` on the patched class with the export-friendly
    implementation inherited from ``common_RotaryEmbedding``.
    """

    # Methods of the target class that are swapped out by the patching machinery.
    _PATCHES_ = ["forward"]
    # The transformers class whose methods are replaced.
    _PATCHED_CLASS_ = transformers.models.gemma2.modeling_gemma2.Gemma2RotaryEmbedding
851+
852+
853+
class patched_Gemma3RotaryEmbedding(common_RotaryEmbedding):
    """Patch descriptor for Gemma3's rotary embedding.

    Replaces ``forward`` on the patched class with the export-friendly
    implementation inherited from ``common_RotaryEmbedding``.
    """

    # Methods of the target class that are swapped out by the patching machinery.
    _PATCHES_ = ["forward"]
    # The transformers class whose methods are replaced.
    _PATCHED_CLASS_ = transformers.models.gemma3.modeling_gemma3.Gemma3RotaryEmbedding
856+
857+
843858
class patched_LlamaRotaryEmbedding(common_RotaryEmbedding):
    """Patch descriptor for Llama's rotary embedding.

    Replaces ``forward`` on the patched class with the export-friendly
    implementation inherited from ``common_RotaryEmbedding``.
    """

    # Methods of the target class that are swapped out by the patching machinery.
    _PATCHES_ = ["forward"]
    # The transformers class whose methods are replaced.
    _PATCHED_CLASS_ = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding
846861

847862

863+
class patched_MistralRotaryEmbedding(common_RotaryEmbedding):
    """Patch descriptor for Mistral's rotary embedding.

    Replaces ``forward`` on the patched class with the export-friendly
    implementation inherited from ``common_RotaryEmbedding``.
    """

    # Methods of the target class that are swapped out by the patching machinery.
    _PATCHES_ = ["forward"]
    # The transformers class whose methods are replaced.
    _PATCHED_CLASS_ = transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding
866+
867+
868+
class patched_MixtralRotaryEmbedding(common_RotaryEmbedding):
    """Patch descriptor for Mixtral's rotary embedding.

    Replaces ``forward`` on the patched class with the export-friendly
    implementation inherited from ``common_RotaryEmbedding``.
    """

    # Methods of the target class that are swapped out by the patching machinery.
    _PATCHES_ = ["forward"]
    # The transformers class whose methods are replaced.
    _PATCHED_CLASS_ = transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding
871+
872+
873+
class patched_PhiRotaryEmbedding(common_RotaryEmbedding):
    """Patch descriptor for Phi's rotary embedding.

    Replaces ``forward`` on the patched class with the export-friendly
    implementation inherited from ``common_RotaryEmbedding``.
    """

    # Methods of the target class that are swapped out by the patching machinery.
    _PATCHES_ = ["forward"]
    # The transformers class whose methods are replaced.
    _PATCHED_CLASS_ = transformers.models.phi.modeling_phi.PhiRotaryEmbedding
876+
877+
848878
class patched_Phi3RotaryEmbedding(common_RotaryEmbedding):
    """Patch descriptor for Phi3's rotary embedding.

    Replaces ``forward`` on the patched class with the export-friendly
    implementation inherited from ``common_RotaryEmbedding``.
    """

    # Methods of the target class that are swapped out by the patching machinery.
    _PATCHES_ = ["forward"]
    # The transformers class whose methods are replaced.
    _PATCHED_CLASS_ = transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding
851881

852882

883+
class patched_Phi4MultimodalRotaryEmbedding(common_RotaryEmbedding):
    """Patch descriptor for Phi4-Multimodal's rotary embedding.

    Replaces ``forward`` on the patched class with the export-friendly
    implementation inherited from ``common_RotaryEmbedding``.
    """

    # Methods of the target class that are swapped out by the patching machinery.
    _PATCHES_ = ["forward"]
    # The transformers class whose methods are replaced; wrapped in
    # parentheses only because the dotted path exceeds the line length.
    _PATCHED_CLASS_ = (
        transformers.models.phi4_multimodal.modeling_phi4_multimodal.Phi4MultimodalRotaryEmbedding
    )
888+
889+
890+
class patched_SmolLM3RotaryEmbedding(common_RotaryEmbedding):
    """Patch descriptor for SmolLM3's rotary embedding.

    Replaces ``forward`` on the patched class with the export-friendly
    implementation inherited from ``common_RotaryEmbedding``.
    """

    # Methods of the target class that are swapped out by the patching machinery.
    _PATCHES_ = ["forward"]
    # BUG FIX: the SmolLM3 model lives in ``transformers.models.smollm3``
    # (module ``modeling_smollm3``); the original path used ``smallm3`` twice,
    # which raises AttributeError as soon as this module is imported.
    _PATCHED_CLASS_ = transformers.models.smollm3.modeling_smollm3.SmolLM3RotaryEmbedding
893+
894+
853895
class patched_IdeficsEmbedding(torch.nn.Module):
854896
_PATCHES_ = ["forward"]
855897
_PATCHED_CLASS_ = transformers.models.idefics.modeling_idefics.IdeficsEmbedding

0 commit comments

Comments
 (0)