1 parent e7b9dc1 · commit ae0b77c
onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
@@ -1313,7 +1313,7 @@ def patched_sdpa_attention_forward(
     is_causal = attention_mask is None and is_causal

     torch._check(
-        attention_mask.shape[3] == key.shape[2],
+        attention_mask is None or attention_mask.shape[3] == key.shape[2],
         "Attention mask shape incompatible with key shape.",
     )
     attn_output = torch.nn.functional.scaled_dot_product_attention(
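The fix short-circuits the shape check when `attention_mask` is `None`, which is exactly the case the line above sets up (`is_causal` attention carries no explicit mask); the old condition would dereference `.shape` on `None`. A minimal sketch of the guard's logic, using a hypothetical helper `check_mask_shape` with plain shape tuples standing in for tensor `.shape` values:

```python
def check_mask_shape(mask_shape, key_shape):
    # Hypothetical stand-in for the patched torch._check condition:
    # passes when the mask is absent (causal path) or when the mask's
    # last dimension matches the key's sequence dimension (dim 2).
    return mask_shape is None or mask_shape[3] == key_shape[2]

# No mask: the unpatched condition would raise AttributeError on
# None.shape; the patched one short-circuits to True.
print(check_mask_shape(None, (2, 8, 16, 64)))           # True
print(check_mask_shape((2, 1, 1, 16), (2, 8, 16, 64)))  # True
print(check_mask_shape((2, 1, 1, 10), (2, 8, 16, 64)))  # False
```

Because `or` evaluates lazily, the `attention_mask.shape[3]` access is never reached when the mask is `None`, so `torch._check` receives a plain `True` instead of crashing.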