
Commit 04686de

Patch LlamaRotaryEmbedding
1 parent 73f5f8f commit 04686de

File tree

1 file changed (+11 lines, -4 lines)

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 11 additions & 4 deletions
@@ -814,10 +814,7 @@ def wrapper(self, x, position_ids):
     return wrapper
 
 
-class patched_Phi3RotaryEmbedding(torch.nn.Module):
-    _PATCHES_ = ["forward"]
-    _PATCHED_CLASS_ = transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding
-
+class common_RotaryEmbedding(torch.nn.Module):
     @torch.no_grad()
     @patched_dynamic_rope_update
     def forward(self, x, position_ids):
@@ -843,6 +840,16 @@ def forward(self, x, position_ids):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
+class patched_LlamaRotaryEmbedding(common_RotaryEmbedding):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding
+
+
+class patched_Phi3RotaryEmbedding(common_RotaryEmbedding):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding
+
+
 class patched_IdeficsEmbedding(torch.nn.Module):
     _PATCHES_ = ["forward"]
     _PATCHED_CLASS_ = transformers.models.idefics.modeling_idefics.IdeficsEmbedding
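
Note: the two new patch classes only declare metadata — _PATCHES_ lists the attributes to override and _PATCHED_CLASS_ names the transformers class to patch — while the patched forward itself is inherited from common_RotaryEmbedding. As a rough sketch (not the library's actual patching code), the hypothetical helpers apply_patch/revert_patch below illustrate how that metadata could be consumed to swap forward in and out of the target class:

from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
    patched_LlamaRotaryEmbedding,
)


def apply_patch(patch_cls):
    # Hypothetical helper, not part of onnx_diagnostic: copy every attribute
    # listed in _PATCHES_ from the patch class onto _PATCHED_CLASS_, keeping
    # the originals so the patch can be undone later.
    target = patch_cls._PATCHED_CLASS_
    saved = {name: getattr(target, name) for name in patch_cls._PATCHES_}
    for name in patch_cls._PATCHES_:
        setattr(target, name, getattr(patch_cls, name))
    return saved


def revert_patch(patch_cls, saved):
    # Hypothetical helper: restore the attributes recorded by apply_patch.
    for name, attr in saved.items():
        setattr(patch_cls._PATCHED_CLASS_, name, attr)


# Example: LlamaRotaryEmbedding.forward resolves to the patched forward
# defined on common_RotaryEmbedding until revert_patch is called.
saved = apply_patch(patched_LlamaRotaryEmbedding)
try:
    ...  # export / trace the model here
finally:
    revert_patch(patched_LlamaRotaryEmbedding, saved)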
