@@ -1680,53 +1680,26 @@ def patched_sdpa_attention_forward(
     if is_causal is None:
         # The last condition is for encoder (decoder) models which specify this by passing their own `is_causal` flag
         # This is mainly due to those models having mixed implementations for encoder, decoder, and encoder-decoder attns
-        def is_causal_is_true(
-            query, key, value, attention_mask, dropout, scaling, **sdpa_kwargs
-        ):
-            return torch.nn.functional.scaled_dot_product_attention(
-                query,
-                key,
-                value,
-                attn_mask=attention_mask,
-                dropout_p=dropout,
-                scale=scaling,
-                is_causal=True,
-                **sdpa_kwargs,
-            )
-
-        def is_causal_is_false(
-            query, key, value, attention_mask, dropout, scaling, **sdpa_kwargs
-        ):
-            return torch.nn.functional.scaled_dot_product_attention(
-                query,
-                key,
-                value,
-                attn_mask=attention_mask,
-                dropout_p=dropout,
-                scale=scaling,
-                is_causal=False,
-                **sdpa_kwargs,
-            )
-
-        attn_output = torch.cond(
-            query.shape[2] > 1
-            and attention_mask is None
-            and getattr(module, "is_causal", True),
-            is_causal_is_true,
-            is_causal_is_false,
-            [query, key, value, attention_mask, dropout, scaling],
-        )
-    else:
-        attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query,
-            key,
-            value,
-            attn_mask=attention_mask,
-            dropout_p=dropout,
-            scale=scaling,
-            is_causal=is_causal,
-            **sdpa_kwargs,
-        )
+        # is_causal = query.shape[2] > 1 and attention_mask is None and getattr(module, "is_causal", True)
+        # NOTE: query.shape[2] == 1 or > 1 should have the same output for causal attention
+        # so we simplify the condition to:
+        is_causal = attention_mask is None and getattr(module, "is_causal", True)
+
+        # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
+        # We convert it to a bool for the SDPA kernel that only accepts bools.
+        if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
+            is_causal = is_causal.item()
+
+    attn_output = torch.nn.functional.scaled_dot_product_attention(
+        query,
+        key,
+        value,
+        attn_mask=attention_mask,
+        dropout_p=dropout,
+        scale=scaling,
+        is_causal=is_causal,
+        **sdpa_kwargs,
+    )
     attn_output = attn_output.transpose(1, 2).contiguous()
 
     return attn_output, None
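
For context, a minimal sketch (not part of this commit) of the tracing behavior the new comment describes: under torch.jit.trace, shape lookups such as query.shape[2] are recorded as tensors, so a comparison used to derive is_causal can come back as a 0-dim bool tensor, while torch.nn.functional.scaled_dot_product_attention only accepts a plain Python bool for is_causal. The helper name sdpa_with_flag and the toy shapes below are made up for illustration.

import torch
import torch.nn.functional as F


def sdpa_with_flag(query, key, value):
    # Under torch.jit.trace, query.shape[2] may be recorded as a tensor, so this
    # comparison can produce a 0-dim bool tensor instead of a Python bool.
    is_causal = query.shape[2] > 1
    if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
        # SDPA expects a plain bool for `is_causal`, so convert the traced flag back.
        is_causal = is_causal.item()
    return F.scaled_dot_product_attention(query, key, value, is_causal=is_causal)


q = k = v = torch.randn(1, 2, 4, 8)  # (batch, heads, seq_len, head_dim)
print(sdpa_with_flag(q, k, v).shape)                 # eager: the flag is already a bool
traced = torch.jit.trace(sdpa_with_flag, (q, k, v))  # traced: the flag may be a tensor, hence .item()

Since the simplified condition in this change no longer depends on query.shape[2], is_causal can be computed directly, and the data-dependent torch.cond branching from the old code is no longer needed.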