Commit 4910cba

add one more patch
1 parent cb27ace · commit 4910cba

File tree

1 file changed: +51 -2 lines changed


onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 51 additions & 2 deletions
@@ -864,7 +864,7 @@ def wrapper(self, x, position_ids):
     return wrapper


-def patched_model_bart_eager_attention_forward(
+def common_eager_attention_forward(
     module: torch.nn.Module,
     query: torch.Tensor,
     key: torch.Tensor,
@@ -875,7 +875,6 @@ def patched_model_bart_eager_attention_forward(
     head_mask: Optional[torch.Tensor] = None,
     **kwargs,
 ):
-    """[patch:transformers.models.bart.modeling_bart.eager_attention_forward]"""
     if scaling is None:
         scaling = query.size(-1) ** -0.5

@@ -900,6 +899,56 @@ def patched_model_bart_eager_attention_forward(
     return attn_output, attn_weights


+def patched_model_bart_eager_attention_forward(
+    module: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: Optional[float] = None,
+    dropout: float = 0.0,
+    head_mask: Optional[torch.Tensor] = None,
+    **kwargs,
+):
+    """[patch:transformers.models.bart.modeling_bart.eager_attention_forward]"""
+    return common_eager_attention_forward(
+        module,
+        query,
+        key,
+        value,
+        attention_mask=attention_mask,
+        scaling=scaling,
+        dropout=dropout,
+        head_mask=head_mask,
+        **kwargs,
+    )
+
+
+def patched_modeling_marian_eager_attention_forward(
+    module: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: Optional[float] = None,
+    dropout: float = 0.0,
+    head_mask: Optional[torch.Tensor] = None,
+    **kwargs,
+):
+    """[patch:transformers.models.marian.modeling_marian.eager_attention_forward]"""
+    return common_eager_attention_forward(
+        module,
+        query,
+        key,
+        value,
+        attention_mask=attention_mask,
+        scaling=scaling,
+        dropout=dropout,
+        head_mask=head_mask,
+        **kwargs,
+    )
+
+
 class common_RotaryEmbedding(torch.nn.Module):
     @torch.no_grad()
     @patched_dynamic_rope_update
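The commit renames the BART-specific patch to common_eager_attention_forward and re-adds thin per-model wrappers for BART and Marian; the [patch:...] docstrings appear to tell the patching machinery which upstream transformers function each wrapper replaces. Only two lines of the shared helper's body are visible in this diff (the scaling default and the final return), so the following is a minimal sketch of what it plausibly computes, assuming it mirrors the standard eager attention path in transformers: scaled QK^T, additive mask, softmax, optional per-head mask, dropout, then the weighted sum over values. The middle section is an assumption, not taken from the commit.

import torch
from typing import Optional

def common_eager_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,          # (batch, heads, tgt_len, head_dim)
    key: torch.Tensor,            # (batch, heads, src_len, head_dim)
    value: torch.Tensor,          # (batch, heads, src_len, head_dim)
    attention_mask: Optional[torch.Tensor],
    scaling: Optional[float] = None,
    dropout: float = 0.0,
    head_mask: Optional[torch.Tensor] = None,
    **kwargs,
):
    # Confirmed by the diff: default scaling is 1/sqrt(head_dim).
    if scaling is None:
        scaling = query.size(-1) ** -0.5

    # Assumed body, mirroring transformers' standard eager attention:
    # raw scores, additive mask, softmax over the key dimension.
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)

    # Optional per-head mask, broadcast over batch and sequence dims.
    if head_mask is not None:
        attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)

    attn_weights = torch.nn.functional.dropout(
        attn_weights, p=dropout, training=module.training
    )
    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()

    # Confirmed by the diff: the helper returns both tensors.
    return attn_output, attn_weights

The factoring is reasonable because the upstream eager_attention_forward implementations in modeling_bart and modeling_marian are near-identical copies, so one shared body can back both patch targets.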
