 from functools import wraps
 from typing import Callable, List, Optional, Tuple
 import packaging.version as pv
+from transformers.utils import logging
 import torch
 import transformers
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+
+logger = logging.get_logger(__name__)
@@ -986,7 +987,7 @@ def wrapper(self, x, position_ids):
     return wrapper


-def common_eager_attention_forward(
+def _common_eager_attention_forward(
     module: torch.nn.Module,
     query: torch.Tensor,
     key: torch.Tensor,
@@ -1033,7 +1034,7 @@ def patched_model_bart_eager_attention_forward(
     **kwargs,
 ):
     """[patch:transformers.models.bart.modeling_bart.eager_attention_forward]"""
-    return common_eager_attention_forward(
+    return _common_eager_attention_forward(
         module,
         query,
         key,
@@ -1058,7 +1059,7 @@ def patched_modeling_marian_eager_attention_forward(
     **kwargs,
 ):
     """[patch:transformers.models.marian.modeling_marian.eager_attention_forward]"""
-    return common_eager_attention_forward(
+    return _common_eager_attention_forward(
         module,
         query,
         key,
@@ -1629,3 +1630,103 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             batch_size, sequence_length, hidden_dim
         )
         return final_hidden_states, router_logits
+
+
+##### Attention #####
+
+try:
+    import transformers.modeling_utils
+
+    patch_modeling_utils = True
+
+    from transformers.integrations.sdpa_attention import use_gqa_in_sdpa, repeat_kv
+
+except ImportError:
+    patch_modeling_utils = False
+
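+# `transformers.integrations.sdpa_attention` (with `use_gqa_in_sdpa` and
+# `repeat_kv`) only ships with recent transformers releases, so the patch
+# below is defined only when those imports succeed.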
+if patch_modeling_utils:
+
+    def patched_sdpa_attention_forward(
+        module: torch.nn.Module,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attention_mask: Optional[torch.Tensor],
+        dropout: float = 0.0,
+        scaling: Optional[float] = None,
+        is_causal: Optional[bool] = None,
+        **kwargs,
+    ) -> tuple[torch.Tensor, None]:
+        """manual patch for function ```transformers.integrations.sdpa_attention.sdpa_attention_forward```."""
+        if kwargs.get("output_attentions", False) or kwargs.get("head_mask") is not None:
+            logger.warning_once(
+                "`sdpa` attention does not support `output_attentions=True` or `head_mask`."
+                " Please set your attention to `eager` if you want any of these features."
+            )
+        sdpa_kwargs = {}
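+        # Grouped-query attention: either repeat the K/V heads to match the
+        # query heads, or let SDPA group them natively via `enable_gqa`.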
+        if hasattr(module, "num_key_value_groups"):
+            if not use_gqa_in_sdpa(attention_mask, key):
+                key = repeat_kv(key, module.num_key_value_groups)
+                value = repeat_kv(value, module.num_key_value_groups)
+            else:
+                sdpa_kwargs = {"enable_gqa": True}
+
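+        # A 4D additive mask may cover more key positions than the keys
+        # actually have (e.g. with a cache); slice it to the key length.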
+        if attention_mask is not None and attention_mask.ndim == 4:
+            attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        # Note that it is important to check first for the shape, otherwise compile will fail with `argument 'is_causal' must be bool, not SymBool`
+        if is_causal is None:
+            # The last condition is for encoder (decoder) models which specify this by passing their own `is_causal` flag
+            # This is mainly due to those models having mixed implementations for encoder, decoder, and encoder-decoder attns
+            def is_causal_is_true(
+                query, key, value, attention_mask, dropout, scaling, **sdpa_kwargs
+            ):
+                return torch.nn.functional.scaled_dot_product_attention(
+                    query,
+                    key,
+                    value,
+                    attn_mask=attention_mask,
+                    dropout_p=dropout,
+                    scale=scaling,
+                    is_causal=True,
+                    **sdpa_kwargs,
+                )
+
+            def is_causal_is_false(
+                query, key, value, attention_mask, dropout, scaling, **sdpa_kwargs
+            ):
+                return torch.nn.functional.scaled_dot_product_attention(
+                    query,
+                    key,
+                    value,
+                    attn_mask=attention_mask,
+                    dropout_p=dropout,
+                    scale=scaling,
+                    is_causal=False,
+                    **sdpa_kwargs,
+                )
+
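+            # torch.cond traces both branches, so the exported graph keeps this
+            # dispatch without data-dependent Python control flow on the predicate.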
+            attn_output = torch.cond(
+                query.shape[2] > 1
+                and attention_mask is None
+                and getattr(module, "is_causal", True),
+                is_causal_is_true,
+                is_causal_is_false,
+                [query, key, value, attention_mask, dropout, scaling],
+            )
+        else:
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query,
+                key,
+                value,
+                attn_mask=attention_mask,
+                dropout_p=dropout,
+                scale=scaling,
+                is_causal=is_causal,
+                **sdpa_kwargs,
+            )
+        attn_output = attn_output.transpose(1, 2).contiguous()
+
+        return attn_output, None
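
A minimal smoke test for the patched forward, not part of the commit above. It assumes `patched_sdpa_attention_forward` is in scope and that the imports in the try/except succeeded, and it uses the (batch, num_heads, seq_len, head_dim) layout expected by transformers attention code. `DummyAttention` is a hypothetical stand-in; passing an explicit `is_causal` keeps the call on the plain SDPA path rather than `torch.cond`.

import torch


class DummyAttention(torch.nn.Module):
    """Hypothetical module; it defines no `num_key_value_groups`, so the GQA branch is skipped."""


query = torch.randn(2, 4, 8, 16)  # (batch, num_heads, seq_len, head_dim)
key = torch.randn(2, 4, 8, 16)
value = torch.randn(2, 4, 8, 16)

attn_output, attn_weights = patched_sdpa_attention_forward(
    DummyAttention(), query, key, value, attention_mask=None, is_causal=True
)
assert attn_output.shape == (2, 8, 4, 16)  # transposed back to (batch, seq_len, heads, head_dim)
assert attn_weights is None  # the patched forward never returns attention weights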