1 change: 1 addition & 0 deletions CHANGELOGS.rst
@@ -4,6 +4,7 @@ Change Logs
0.7.2
+++++

* :pr:`170`: patches LlamaRotaryEmbedding
* :pr:`168`, :pr:`169`: introduces patch_diffusers
* :pr:`166`: improves handling of StaticCache
* :pr:`165`: support for task text-to-image
65 changes: 61 additions & 4 deletions onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
@@ -2,6 +2,7 @@
from dataclasses import dataclass
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple
import packaging.version as pv
import torch
import transformers
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
@@ -814,10 +815,7 @@ def wrapper(self, x, position_ids):
return wrapper


class patched_Phi3RotaryEmbedding(torch.nn.Module):
_PATCHES_ = ["forward"]
_PATCHED_CLASS_ = transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding

class common_RotaryEmbedding(torch.nn.Module):
@torch.no_grad()
@patched_dynamic_rope_update
def forward(self, x, position_ids):
@@ -843,6 +841,65 @@ def forward(self, x, position_ids):
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


class patched_GemmaRotaryEmbedding(common_RotaryEmbedding):
_PATCHES_ = ["forward"]
_PATCHED_CLASS_ = transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding


if pv.Version(transformers.__version__) >= pv.Version("4.52"):

class patched_Gemma2RotaryEmbedding(common_RotaryEmbedding):
_PATCHES_ = ["forward"]
_PATCHED_CLASS_ = transformers.models.gemma2.modeling_gemma2.Gemma2RotaryEmbedding

class patched_Gemma3RotaryEmbedding(common_RotaryEmbedding):
_PATCHES_ = ["forward"]
_PATCHED_CLASS_ = transformers.models.gemma3.modeling_gemma3.Gemma3RotaryEmbedding


class patched_LlamaRotaryEmbedding(common_RotaryEmbedding):
_PATCHES_ = ["forward"]
_PATCHED_CLASS_ = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding


class patched_MistralRotaryEmbedding(common_RotaryEmbedding):
_PATCHES_ = ["forward"]
_PATCHED_CLASS_ = transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding


class patched_MixtralRotaryEmbedding(common_RotaryEmbedding):
_PATCHES_ = ["forward"]
_PATCHED_CLASS_ = transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding


class patched_PhiRotaryEmbedding(common_RotaryEmbedding):
_PATCHES_ = ["forward"]
_PATCHED_CLASS_ = transformers.models.phi.modeling_phi.PhiRotaryEmbedding


if pv.Version(transformers.__version__) >= pv.Version("4.51"):

class patched_Phi3RotaryEmbedding(common_RotaryEmbedding):
_PATCHES_ = ["forward"]
_PATCHED_CLASS_ = transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding


if pv.Version(transformers.__version__) >= pv.Version("4.52"):

class patched_Phi4MultimodalRotaryEmbedding(common_RotaryEmbedding):
_PATCHES_ = ["forward"]
_PATCHED_CLASS_ = (
transformers.models.phi4_multimodal.modeling_phi4_multimodal.Phi4MultimodalRotaryEmbedding
)


if pv.Version(transformers.__version__) >= pv.Version("4.53"):

class patched_SmolLM3RotaryEmbedding(common_RotaryEmbedding):
_PATCHES_ = ["forward"]
_PATCHED_CLASS_ = transformers.models.smollm3.modeling_smollm3.SmolLM3RotaryEmbedding


class patched_IdeficsEmbedding(torch.nn.Module):
_PATCHES_ = ["forward"]
_PATCHED_CLASS_ = transformers.models.idefics.modeling_idefics.IdeficsEmbedding
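
Note: the classes added in this hunk all follow the same patch-class convention: a subclass of common_RotaryEmbedding that lists the methods to replace in _PATCHES_ and the transformers class to patch in _PATCHED_CLASS_, defined only when the installed transformers version ships the target model. A minimal sketch of that pattern, using names taken from the diff (the machinery that discovers and applies these patch classes is assumed and not shown here):

import packaging.version as pv
import torch
import transformers

class common_RotaryEmbedding(torch.nn.Module):
    # shared patched forward lives here (see the hunk above); elided in this sketch
    ...

# Guarded definition: the diff gates SmolLM3 on transformers >= 4.53,
# so the patch class is only created when the target class can be imported.
if pv.Version(transformers.__version__) >= pv.Version("4.53"):

    class patched_SmolLM3RotaryEmbedding(common_RotaryEmbedding):
        _PATCHES_ = ["forward"]  # methods replaced on the target class
        _PATCHED_CLASS_ = transformers.models.smollm3.modeling_smollm3.SmolLM3RotaryEmbedding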