@@ -864,7 +864,7 @@ def wrapper(self, x, position_ids):
     return wrapper
 
 
-def patched_model_bart_eager_attention_forward(
+def common_eager_attention_forward(
     module: torch.nn.Module,
     query: torch.Tensor,
     key: torch.Tensor,
@@ -875,7 +875,6 @@ def patched_model_bart_eager_attention_forward(
     head_mask: Optional[torch.Tensor] = None,
     **kwargs,
 ):
-    """[patch:transformers.models.bart.modeling_bart.eager_attention_forward]"""
     if scaling is None:
         scaling = query.size(-1) ** -0.5
 
@@ -900,6 +899,56 @@ def patched_model_bart_eager_attention_forward(
     return attn_output, attn_weights
 
 
+def patched_model_bart_eager_attention_forward(
+    module: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: Optional[float] = None,
+    dropout: float = 0.0,
+    head_mask: Optional[torch.Tensor] = None,
+    **kwargs,
+):
+    """[patch:transformers.models.bart.modeling_bart.eager_attention_forward]"""
+    return common_eager_attention_forward(
+        module,
+        query,
+        key,
+        value,
+        attention_mask=attention_mask,
+        scaling=scaling,
+        dropout=dropout,
+        head_mask=head_mask,
+        **kwargs,
+    )
+
+
+def patched_modeling_marian_eager_attention_forward(
+    module: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: Optional[float] = None,
+    dropout: float = 0.0,
+    head_mask: Optional[torch.Tensor] = None,
+    **kwargs,
+):
+    """[patch:transformers.models.marian.modeling_marian.eager_attention_forward]"""
+    return common_eager_attention_forward(
+        module,
+        query,
+        key,
+        value,
+        attention_mask=attention_mask,
+        scaling=scaling,
+        dropout=dropout,
+        head_mask=head_mask,
+        **kwargs,
+    )
+
+
 class common_RotaryEmbedding(torch.nn.Module):
     @torch.no_grad()
     @patched_dynamic_rope_update
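The `[patch:<fully.qualified.target>]` docstring convention above keeps the per-model wrappers as thin one-liners while the shared attention math lives in `common_eager_attention_forward`. The repo's actual patch registration machinery is not part of this diff; the sketch below is a hypothetical resolver (the `apply_patches` helper and its `patch_module` argument are assumptions, not code from this commit) showing how such markers could be mapped onto the functions they are meant to replace.

import importlib
import inspect

def apply_patches(patch_module):
    # Hypothetical helper: walk the module containing the patched functions,
    # find callables whose docstring starts with "[patch:<target>]", and
    # overwrite that target attribute with the patched implementation.
    for obj in vars(patch_module).values():
        doc = inspect.getdoc(obj) if callable(obj) else None
        if not doc or not doc.startswith("[patch:"):
            continue
        target = doc[len("[patch:"):doc.index("]")]
        module_name, attr_name = target.rsplit(".", 1)
        setattr(importlib.import_module(module_name), attr_name, obj)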