
Commit 3430eb5

only examine attention_mask shape when it's available

1 parent 2badb72

File tree

1 file changed (+3, -1 lines)

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 3 additions & 1 deletion
@@ -1947,7 +1947,9 @@ def patched_sdpa_attention_forward(
 
     # From causal_mask generation, attention_mask is 4D, and the last dim
     # should be the same as key's seq_len
-    torch._check(attention_mask.shape[3] == key.shape[2])
+    torch._check(
+        attention_mask.shape[3] == key.shape[2] if attention_mask is not None else True
+    )
 
     attn_output = torch.nn.functional.scaled_dot_product_attention(
         query,
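
For context, a minimal sketch of the pattern this commit applies: attention_mask is an optional argument in SDPA attention, so the shape check must be guarded, since indexing .shape[3] on None would raise before torch._check ever runs. This is an illustrative stand-in assuming plain PyTorch; sdpa_forward_sketch is a hypothetical name, not the repository's function.

# A minimal sketch, assuming plain PyTorch. sdpa_forward_sketch is a
# hypothetical stand-in for the patched function, not the repository's code.
import torch

def sdpa_forward_sketch(query, key, value, attention_mask=None):
    # attention_mask is optional: only examine its shape when it's available.
    # When present, the 4D mask's last dim should equal key's seq_len.
    torch._check(
        attention_mask.shape[3] == key.shape[2] if attention_mask is not None else True
    )
    return torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask
    )

# Usage: the check runs when a mask is passed and is skipped when it is None.
q = k = v = torch.randn(1, 8, 16, 64)      # (batch, heads, seq_len, head_dim)
mask = torch.zeros(1, 1, 16, 16)           # 4D additive mask; last dim == 16
out_masked = sdpa_forward_sketch(q, k, v, mask)
out_nomask = sdpa_forward_sketch(q, k, v)  # no mask: shape check is a no-op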
