@@ -1351,6 +1351,25 @@ def patched_sdpa_attention_forward(
         "`sdpa` attention does not support `output_attentions=True`."
         " Please set your attention to `eager` if you want any of these features."
     )
+    torch._check(
+        attention_mask is None or attention_mask.shape[3] == key.shape[2],
+        "Attention mask shape incompatible with key shape.",
+    )
+    torch._check(
+        query.shape[0] == key.shape[0] or query.shape[0] == 1,
+        lambda: (
+            f"broadcast issue query (1): {query.shape}, key: {key.shape}, "
+            f"value: {value.shape}"
+        ),
+    )
+    torch._check(
+        key.shape[0] == value.shape[0] or key.shape[0] == 1,
+        lambda: (
+            f"broadcast issue query (2): {query.shape}, key: {key.shape}, "
+            f"value: {value.shape}"
+        ),
+    )
+
     sdpa_kwargs = {}
     if hasattr(module, "num_key_value_groups"):
         if not transformers.integrations.sdpa_attention.use_gqa_in_sdpa(attention_mask, key):
@@ -1367,49 +1386,50 @@ def patched_sdpa_attention_forward(
         attention_mask = attention_mask[:, :, :, : key.shape[-2]]
 
     if patch_is_causal:
+        # transformers>=4.55
         is_causal = is_causal if is_causal is not None else getattr(module, "is_causal", True)
 
         # PATCHED: remove the test query.shape[2] > 1
         # is_causal = query.shape[2] > 1 and attention_mask is None and is_causal
         # and we split the test to keep the minimum in torch.cond
         is_causal = attention_mask is None and is_causal
-    elif is_causal is None:
-        is_causal = attention_mask is None
 
-    torch._check(
-        attention_mask is None or attention_mask.shape[3] == key.shape[2],
-        "Attention mask shape incompatible with key shape.",
-    )
-    torch._check(
-        query.shape[0] == key.shape[0] or query.shape[0] == 1,
-        lambda: (
-            f"broadcast issue query (1): {query.shape}, key: {key.shape}, "
-            f"value: {value.shape}"
-        ),
-    )
-    torch._check(
-        key.shape[0] == value.shape[0] or key.shape[0] == 1,
-        lambda: (
-            f"broadcast issue query (2): {query.shape}, key: {key.shape}, "
-            f"value: {value.shape}"
-        ),
-    )
-    if not is_causal or not patch_is_causal:
-        return (
-            torch.nn.functional.scaled_dot_product_attention(
-                query,
-                key,
-                value,
-                attn_mask=attention_mask,
-                dropout_p=dropout,
-                scale=scaling,
-                is_causal=is_causal,
-                **sdpa_kwargs,
+        if not is_causal:
+            return (
+                torch.nn.functional.scaled_dot_product_attention(
+                    query,
+                    key,
+                    value,
+                    attn_mask=attention_mask,
+                    dropout_p=dropout,
+                    scale=scaling,
+                    is_causal=is_causal,
+                    **sdpa_kwargs,
+                )
+                .transpose(1, 2)
+                .contiguous(),
+                None,
+            )
+    else:
+        # transformers<4.55
+        if is_causal is None and attention_mask is not None:
+            is_causal = False
+        if is_causal is not None:
+            return (
+                torch.nn.functional.scaled_dot_product_attention(
+                    query,
+                    key,
+                    value,
+                    attn_mask=attention_mask,
+                    dropout_p=dropout,
+                    scale=scaling,
+                    is_causal=is_causal,
+                    **sdpa_kwargs,
+                )
+                .transpose(1, 2)
+                .contiguous(),
+                None,
             )
-            .transpose(1, 2)
-            .contiguous(),
-            None,
-        )
 
     # To avoid the following errors:
     # is_causal=query.shape[2] > 1