 from dataclasses import dataclass
 from functools import wraps
 from typing import Any, Callable, Dict, List, Optional, Tuple
+import packaging.version as pv
 import torch
 import transformers
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
@@ -814,10 +815,7 @@ def wrapper(self, x, position_ids):
     return wrapper


-class patched_Phi3RotaryEmbedding(torch.nn.Module):
-    _PATCHES_ = ["forward"]
-    _PATCHED_CLASS_ = transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding
-
+class common_RotaryEmbedding(torch.nn.Module):
     @torch.no_grad()
     @patched_dynamic_rope_update
     def forward(self, x, position_ids):
@@ -843,6 +841,65 @@ def forward(self, x, position_ids):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


+class patched_GemmaRotaryEmbedding(common_RotaryEmbedding):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding
+
+
+if pv.Version(transformers.__version__) >= pv.Version("4.52"):
+
+    class patched_Gemma2RotaryEmbedding(common_RotaryEmbedding):
+        _PATCHES_ = ["forward"]
+        _PATCHED_CLASS_ = transformers.models.gemma2.modeling_gemma2.Gemma2RotaryEmbedding
+
+    class patched_Gemma3RotaryEmbedding(common_RotaryEmbedding):
+        _PATCHES_ = ["forward"]
+        _PATCHED_CLASS_ = transformers.models.gemma3.modeling_gemma3.Gemma3RotaryEmbedding
+
+
+class patched_LlamaRotaryEmbedding(common_RotaryEmbedding):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding
+
+
+class patched_MistralRotaryEmbedding(common_RotaryEmbedding):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding
+
+
+class patched_MixtralRotaryEmbedding(common_RotaryEmbedding):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding
+
+
+class patched_PhiRotaryEmbedding(common_RotaryEmbedding):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.phi.modeling_phi.PhiRotaryEmbedding
+
+
+if pv.Version(transformers.__version__) >= pv.Version("4.51"):
+
+    class patched_Phi3RotaryEmbedding(common_RotaryEmbedding):
+        _PATCHES_ = ["forward"]
+        _PATCHED_CLASS_ = transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding
+
+
+if pv.Version(transformers.__version__) >= pv.Version("4.52"):
+
+    class patched_Phi4MultimodalRotaryEmbedding(common_RotaryEmbedding):
+        _PATCHES_ = ["forward"]
+        _PATCHED_CLASS_ = (
+            transformers.models.phi4_multimodal.modeling_phi4_multimodal.Phi4MultimodalRotaryEmbedding
+        )
+
+
+if pv.Version(transformers.__version__) >= pv.Version("4.53"):
+
+    class patched_SmolLM3RotaryEmbedding(common_RotaryEmbedding):
+        _PATCHES_ = ["forward"]
+        _PATCHED_CLASS_ = transformers.models.smollm3.modeling_smollm3.SmolLM3RotaryEmbedding
+
+
 class patched_IdeficsEmbedding(torch.nn.Module):
     _PATCHES_ = ["forward"]
     _PATCHED_CLASS_ = transformers.models.idefics.modeling_idefics.IdeficsEmbedding
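
A minimal sketch of how these patch classes might be consumed, assuming only the convention visible in the diff: _PATCHES_ lists the method names to override and _PATCHED_CLASS_ names the transformers class to monkey-patch. The apply_patches helper below is hypothetical and not part of this commit; the project may register and apply patches differently.

def apply_patches(patch_cls):
    """Copy each method listed in _PATCHES_ from patch_cls onto _PATCHED_CLASS_."""
    target = patch_cls._PATCHED_CLASS_
    for name in patch_cls._PATCHES_:
        # Keep the original implementation around so the patch can be reverted.
        setattr(target, f"_original_{name}", getattr(target, name))
        setattr(target, name, getattr(patch_cls, name))

# Hypothetical usage: replace LlamaRotaryEmbedding.forward with the shared
# common_RotaryEmbedding implementation defined above.
# apply_patches(patched_LlamaRotaryEmbedding)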