Commit d3f8c0b

fix
1 parent 9b86d3e commit d3f8c0b

2 files changed (+55, -37 lines)

_unittests/ut_torch_export_patches/test_patch_transformers.py

Lines changed: 0 additions & 2 deletions
@@ -44,7 +44,6 @@ def test_sdpa_mask_recent_torch(self):
         got = patched_sdpa_mask_recent_torch(**kwargs)
         self.assertEqualArray(expected, got)

-    @requires_transformers("4.55")
     def test_sdpa_attention_forward_not_causal(self):
         sdpa_attention_forward = sdpa_attention.sdpa_attention_forward
         patched_sdpa_attention_forward = patch_transformers.patched_sdpa_attention_forward
@@ -76,7 +75,6 @@ def test_sdpa_attention_forward_not_causal(self):
         got = patched_sdpa_attention_forward(**torch_deepcopy(kwargs))[0]
         self.assertEqualArray(expected, got)

-    @requires_transformers("4.55")
     def test_sdpa_attention_forward_causal(self):
         sdpa_attention_forward = sdpa_attention.sdpa_attention_forward
         patched_sdpa_attention_forward = patch_transformers.patched_sdpa_attention_forward
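The two deleted lines are the @requires_transformers("4.55") decorators that previously skipped these tests on transformers releases older than 4.55. Since the patch below now carries an explicit transformers<4.55 branch, the tests are expected to run on both version ranges, so the gate is dropped. For illustration only, a version-gating decorator of this kind could look like the following sketch; the real requires_transformers helper in this repository may be implemented differently.

# Hypothetical sketch of a version-gating decorator in the spirit of
# requires_transformers("4.55"); the actual helper in onnx_diagnostic may differ.
import unittest

from packaging.version import Version
import transformers


def requires_transformers_sketch(min_version: str):
    """Skip the decorated test when the installed transformers is older than min_version."""

    def wrap(test_func):
        return unittest.skipIf(
            Version(transformers.__version__) < Version(min_version),
            f"requires transformers>={min_version}",
        )(test_func)

    return wrap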

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 55 additions & 35 deletions
@@ -1351,6 +1351,25 @@ def patched_sdpa_attention_forward(
             "`sdpa` attention does not support `output_attentions=True`."
             " Please set your attention to `eager` if you want any of these features."
         )
+    torch._check(
+        attention_mask is None or attention_mask.shape[3] == key.shape[2],
+        "Attention mask shape incompatible with key shape.",
+    )
+    torch._check(
+        query.shape[0] == key.shape[0] or query.shape[0] == 1,
+        lambda: (
+            f"broadcast issue query (1): {query.shape}, key: {key.shape}, "
+            f"value: {value.shape}"
+        ),
+    )
+    torch._check(
+        key.shape[0] == value.shape[0] or key.shape[0] == 1,
+        lambda: (
+            f"broadcast issue query (2): {query.shape}, key: {key.shape}, "
+            f"value: {value.shape}"
+        ),
+    )
+
     sdpa_kwargs = {}
     if hasattr(module, "num_key_value_groups"):
         if not transformers.integrations.sdpa_attention.use_gqa_in_sdpa(attention_mask, key):
@@ -1367,49 +1386,50 @@ def patched_sdpa_attention_forward(
         attention_mask = attention_mask[:, :, :, : key.shape[-2]]

     if patch_is_causal:
+        # transformers>=4.55
         is_causal = is_causal if is_causal is not None else getattr(module, "is_causal", True)

         # PATCHED: remove the test query.shape[2] > 1
         # is_causal = query.shape[2] > 1 and attention_mask is None and is_causal
         # and we split the test to keep the minimum in torch.cond
         is_causal = attention_mask is None and is_causal
-    elif is_causal is None:
-        is_causal = attention_mask is None

-    torch._check(
-        attention_mask is None or attention_mask.shape[3] == key.shape[2],
-        "Attention mask shape incompatible with key shape.",
-    )
-    torch._check(
-        query.shape[0] == key.shape[0] or query.shape[0] == 1,
-        lambda: (
-            f"broadcast issue query (1): {query.shape}, key: {key.shape}, "
-            f"value: {value.shape}"
-        ),
-    )
-    torch._check(
-        key.shape[0] == value.shape[0] or key.shape[0] == 1,
-        lambda: (
-            f"broadcast issue query (2): {query.shape}, key: {key.shape}, "
-            f"value: {value.shape}"
-        ),
-    )
-    if not is_causal or not patch_is_causal:
-        return (
-            torch.nn.functional.scaled_dot_product_attention(
-                query,
-                key,
-                value,
-                attn_mask=attention_mask,
-                dropout_p=dropout,
-                scale=scaling,
-                is_causal=is_causal,
-                **sdpa_kwargs,
+        if not is_causal:
+            return (
+                torch.nn.functional.scaled_dot_product_attention(
+                    query,
+                    key,
+                    value,
+                    attn_mask=attention_mask,
+                    dropout_p=dropout,
+                    scale=scaling,
+                    is_causal=is_causal,
+                    **sdpa_kwargs,
+                )
+                .transpose(1, 2)
+                .contiguous(),
+                None,
+            )
+    else:
+        # transformers<4.55
+        if is_causal is None and attention_mask is not None:
+            is_causal = False
+        if is_causal is not None:
+            return (
+                torch.nn.functional.scaled_dot_product_attention(
+                    query,
+                    key,
+                    value,
+                    attn_mask=attention_mask,
+                    dropout_p=dropout,
+                    scale=scaling,
+                    is_causal=is_causal,
+                    **sdpa_kwargs,
+                )
+                .transpose(1, 2)
+                .contiguous(),
+                None,
             )
-            .transpose(1, 2)
-            .contiguous(),
-            None,
-        )

     # To avoid the following errors:
     # is_causal=query.shape[2] > 1