Merged
Changes from 4 commits
1 change: 1 addition & 0 deletions CHANGELOGS.rst
@@ -4,6 +4,7 @@ Change Logs
0.7.2
+++++

* :pr:`170`: patches LlamaRotaryEmbedding
* :pr:`168`, :pr:`169`: introduces patch_diffusers
* :pr:`166`: improves handling of StaticCache
* :pr:`165`: support for task text-to-image
60 changes: 56 additions & 4 deletions onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
@@ -2,6 +2,7 @@
from dataclasses import dataclass
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple
import packaging.version as pv
import torch
import transformers
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
@@ -814,10 +815,7 @@ def wrapper(self, x, position_ids):
    return wrapper


class patched_Phi3RotaryEmbedding(torch.nn.Module):
    _PATCHES_ = ["forward"]
    _PATCHED_CLASS_ = transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding

class common_RotaryEmbedding(torch.nn.Module):
    @torch.no_grad()
    @patched_dynamic_rope_update
    def forward(self, x, position_ids):
@@ -843,6 +841,60 @@ def forward(self, x, position_ids):
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


class patched_GemmaRotaryEmbedding(common_RotaryEmbedding):
    _PATCHES_ = ["forward"]
    _PATCHED_CLASS_ = transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding


class patched_Gemma2RotaryEmbedding(common_RotaryEmbedding):
    _PATCHES_ = ["forward"]
    _PATCHED_CLASS_ = transformers.models.gemma2.modeling_gemma2.Gemma2RotaryEmbedding


class patched_Gemma3RotaryEmbedding(common_RotaryEmbedding):
    _PATCHES_ = ["forward"]
    _PATCHED_CLASS_ = transformers.models.gemma3.modeling_gemma3.Gemma3RotaryEmbedding


class patched_LlamaRotaryEmbedding(common_RotaryEmbedding):
    _PATCHES_ = ["forward"]
    _PATCHED_CLASS_ = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding


class patched_MistralRotaryEmbedding(common_RotaryEmbedding):
    _PATCHES_ = ["forward"]
    _PATCHED_CLASS_ = transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding


class patched_MixtralRotaryEmbedding(common_RotaryEmbedding):
    _PATCHES_ = ["forward"]
    _PATCHED_CLASS_ = transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding


class patched_PhiRotaryEmbedding(common_RotaryEmbedding):
    _PATCHES_ = ["forward"]
    _PATCHED_CLASS_ = transformers.models.phi.modeling_phi.PhiRotaryEmbedding


class patched_Phi3RotaryEmbedding(common_RotaryEmbedding):
    _PATCHES_ = ["forward"]
    _PATCHED_CLASS_ = transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding


class patched_Phi4MultimodalRotaryEmbedding(common_RotaryEmbedding):
    _PATCHES_ = ["forward"]
    _PATCHED_CLASS_ = (
        transformers.models.phi4_multimodal.modeling_phi4_multimodal.Phi4MultimodalRotaryEmbedding
    )


if pv.Version(transformers.__version__) >= pv.Version("4.53"):

    class patched_SmolLM3RotaryEmbedding(common_RotaryEmbedding):
        _PATCHES_ = ["forward"]
        _PATCHED_CLASS_ = transformers.models.smollm3.modeling_smollm3.SmolLM3RotaryEmbedding


class patched_IdeficsEmbedding(torch.nn.Module):
    _PATCHES_ = ["forward"]
    _PATCHED_CLASS_ = transformers.models.idefics.modeling_idefics.IdeficsEmbedding
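
For context, each class above only declares which methods to transplant (`_PATCHES_`) and onto which transformers class (`_PATCHED_CLASS_`); the shared export-friendly forward lives in common_RotaryEmbedding. The snippet below is a minimal sketch of how such a declaration could be consumed, assuming only that two-attribute contract; apply_patch is a hypothetical helper written for illustration, not the library's actual entry point.

# Minimal sketch, assuming only the _PATCHES_/_PATCHED_CLASS_ contract shown above.
# apply_patch is a hypothetical name, not onnx_diagnostic's real API.
def apply_patch(patch_cls):
    """Copy every method listed in _PATCHES_ onto the target transformers class."""
    target = patch_cls._PATCHED_CLASS_
    # keep the original attributes so the patch can be reverted after export
    originals = {name: getattr(target, name) for name in patch_cls._PATCHES_]} if False else {
        name: getattr(target, name) for name in patch_cls._PATCHES_
    }
    for name in patch_cls._PATCHES_:
        setattr(target, name, getattr(patch_cls, name))
    return originals

# Hypothetical usage: route LlamaRotaryEmbedding.forward through the
# export-friendly implementation defined by patched_LlamaRotaryEmbedding.
# saved = apply_patch(patched_LlamaRotaryEmbedding)

Returning the original attributes keeps the operation reversible, which matters when the patches are only meant to be active for the duration of a torch.export run.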