
Commit feac0fa

spell
1 parent 18ff67a commit feac0fa

2 files changed, +7 -6 lines changed

_doc/recipes/plot_export_dim1.py

Lines changed: 2 additions & 2 deletions
@@ -75,8 +75,8 @@ def forward(self, x, y, z):
 
 
 # %%
-# Final try with pathes...
-# ++++++++++++++++++++++++
+# Final try with patches...
+# +++++++++++++++++++++++++
 
 print("-- export shape:", string_type((x, y, z), with_shape=True))
 print("-- dynamic shapes:", string_type((ds, ds, ds)))

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 5 additions & 4 deletions
@@ -1351,10 +1351,6 @@ def patched_sdpa_attention_forward(
             "`sdpa` attention does not support `output_attentions=True`."
             " Please set your attention to `eager` if you want any of these features."
         )
-    torch._check(
-        attention_mask is None or attention_mask.shape[3] == key.shape[2],
-        "Attention mask shape incompatible with key shape.",
-    )
     torch._check(
         query.shape[0] == key.shape[0] or query.shape[0] == 1,
         lambda: (
@@ -1385,6 +1381,11 @@ def patched_sdpa_attention_forward(
     if attention_mask is not None and attention_mask.ndim == 4:
         attention_mask = attention_mask[:, :, :, : key.shape[-2]]
 
+    torch._check(
+        attention_mask is None or attention_mask.shape[3] == key.shape[2],
+        lambda: "Attention mask shape incompatible with key shape.",
+    )
+
     if patch_sdpa_is_causal:
         # transformers>=4.55
         is_causal = is_causal if is_causal is not None else getattr(module, "is_causal", True)
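
The net change in this file: the attention-mask shape check moves below the line that slices the mask to the key length, so it validates the mask that is actually used, and its message becomes a lambda because `torch._check` expects an optional callable that builds the error string lazily. A minimal standalone sketch of the reordered logic; `sliced_mask_check` and the tensor shapes are hypothetical, not part of the library:

import torch

def sliced_mask_check(attention_mask, key):
    # Mirror the patched ordering: first trim a 4D mask to the key's
    # sequence length, then verify that the last dimensions agree.
    if attention_mask is not None and attention_mask.ndim == 4:
        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
    # torch._check takes a callable message, so the string is only built
    # when the condition fails; the check also plays well with torch.export.
    torch._check(
        attention_mask is None or attention_mask.shape[3] == key.shape[2],
        lambda: "Attention mask shape incompatible with key shape.",
    )
    return attention_mask

key = torch.randn(2, 8, 16, 64)   # (batch, heads, key_len, head_dim)
mask = torch.ones(2, 1, 16, 32)   # last dim longer than key_len on purpose
print(sliced_mask_check(mask, key).shape)  # torch.Size([2, 1, 16, 16])

Checked before the slice, this mask would fail; checked after, it passes, which is why the commit moves the call rather than deleting it.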
