diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst
index 12192a1b..6970b954 100644
--- a/CHANGELOGS.rst
+++ b/CHANGELOGS.rst
@@ -4,6 +4,7 @@ Change Logs
 0.7.2
 +++++
+* :pr:`170`: patches LlamaRotaryEmbedding
 * :pr:`168`, :pr:`169`: introduces patch_diffusers
 * :pr:`166`: improves handling of StaticCache
 * :pr:`165`: support for task text-to-image
diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
index 017b1985..7ddbaa06 100644
--- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
+++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
@@ -2,6 +2,7 @@
 from dataclasses import dataclass
 from functools import wraps
 from typing import Any, Callable, Dict, List, Optional, Tuple
+import packaging.version as pv
 import torch
 import transformers
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
@@ -814,10 +815,7 @@ def wrapper(self, x, position_ids):
     return wrapper
 
 
-class patched_Phi3RotaryEmbedding(torch.nn.Module):
-    _PATCHES_ = ["forward"]
-    _PATCHED_CLASS_ = transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding
-
+class common_RotaryEmbedding(torch.nn.Module):
     @torch.no_grad()
     @patched_dynamic_rope_update
     def forward(self, x, position_ids):
@@ -843,6 +841,65 @@ def forward(self, x, position_ids):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
+class patched_GemmaRotaryEmbedding(common_RotaryEmbedding):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding
+
+
+if pv.Version(transformers.__version__) >= pv.Version("4.52"):
+
+    class patched_Gemma2RotaryEmbedding(common_RotaryEmbedding):
+        _PATCHES_ = ["forward"]
+        _PATCHED_CLASS_ = transformers.models.gemma2.modeling_gemma2.Gemma2RotaryEmbedding
+
+    class patched_Gemma3RotaryEmbedding(common_RotaryEmbedding):
+        _PATCHES_ = ["forward"]
+        _PATCHED_CLASS_ = transformers.models.gemma3.modeling_gemma3.Gemma3RotaryEmbedding
+
+
+class patched_LlamaRotaryEmbedding(common_RotaryEmbedding):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding
+
+
+class patched_MistralRotaryEmbedding(common_RotaryEmbedding):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding
+
+
+class patched_MixtralRotaryEmbedding(common_RotaryEmbedding):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding
+
+
+class patched_PhiRotaryEmbedding(common_RotaryEmbedding):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.phi.modeling_phi.PhiRotaryEmbedding
+
+
+if pv.Version(transformers.__version__) >= pv.Version("4.51"):
+
+    class patched_Phi3RotaryEmbedding(common_RotaryEmbedding):
+        _PATCHES_ = ["forward"]
+        _PATCHED_CLASS_ = transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding
+
+
+if pv.Version(transformers.__version__) >= pv.Version("4.52"):
+
+    class patched_Phi4MultimodalRotaryEmbedding(common_RotaryEmbedding):
+        _PATCHES_ = ["forward"]
+        _PATCHED_CLASS_ = (
+            transformers.models.phi4_multimodal.modeling_phi4_multimodal.Phi4MultimodalRotaryEmbedding
+        )
+
+
+if pv.Version(transformers.__version__) >= pv.Version("4.53"):
+
+    class patched_SmolLM3RotaryEmbedding(common_RotaryEmbedding):
+        _PATCHES_ = ["forward"]
+        _PATCHED_CLASS_ = transformers.models.smollm3.modeling_smollm3.SmolLM3RotaryEmbedding
+
+
 class patched_IdeficsEmbedding(torch.nn.Module):
     _PATCHES_ = ["forward"]
     _PATCHED_CLASS_ = transformers.models.idefics.modeling_idefics.IdeficsEmbedding
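
Note on the pattern used above: every new rotary-embedding patch class declares which methods it overrides (``_PATCHES_``) and which ``transformers`` class it targets (``_PATCHED_CLASS_``), while the shared export-friendly implementation lives in ``common_RotaryEmbedding``; version-specific classes are only defined when ``packaging.version`` confirms the installed ``transformers`` release exposes them. The sketch below illustrates how such metadata could be consumed around an export call; ``apply_patch`` and ``revert_patch`` are hypothetical helpers written for illustration, not the actual machinery of ``onnx_diagnostic.torch_export_patches``.

# Hypothetical sketch (not the repository's real helper): applying and reverting
# a patch class that declares _PATCHES_ and _PATCHED_CLASS_.
from typing import Any, Dict, Type


def apply_patch(patch_cls: Type) -> Dict[str, Any]:
    """Swaps the methods listed in ``_PATCHES_`` on ``_PATCHED_CLASS_`` for the
    implementations defined on ``patch_cls`` and returns the originals."""
    target = patch_cls._PATCHED_CLASS_
    originals = {}
    for name in patch_cls._PATCHES_:
        originals[name] = getattr(target, name)
        setattr(target, name, getattr(patch_cls, name))
    return originals


def revert_patch(patch_cls: Type, originals: Dict[str, Any]) -> None:
    """Restores the methods saved by ``apply_patch``."""
    target = patch_cls._PATCHED_CLASS_
    for name, original in originals.items():
        setattr(target, name, original)


# Possible usage around an export call (names assumed from the diff above):
# originals = apply_patch(patched_LlamaRotaryEmbedding)
# try:
#     ...  # torch.export.export / torch.onnx.export on the model
# finally:
#     revert_patch(patched_LlamaRotaryEmbedding, originals)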