Commit 8bd2fa1

set is_causal
1 parent f413ea7 commit 8bd2fa1

File tree

1 file changed: +20 additions, -47 deletions


onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 20 additions & 47 deletions
@@ -1680,53 +1680,26 @@ def patched_sdpa_attention_forward(
     if is_causal is None:
         # The last condition is for encoder (decoder) models which specify this by passing their own `is_causal` flag
         # This is mainly due to those models having mixed implementations for encoder, decoder, and encoder-decoder attns
-        def is_causal_is_true(
-            query, key, value, attention_mask, dropout, scaling, **sdpa_kwargs
-        ):
-            return torch.nn.functional.scaled_dot_product_attention(
-                query,
-                key,
-                value,
-                attn_mask=attention_mask,
-                dropout_p=dropout,
-                scale=scaling,
-                is_causal=True,
-                **sdpa_kwargs,
-            )
-
-        def is_causal_is_false(
-            query, key, value, attention_mask, dropout, scaling, **sdpa_kwargs
-        ):
-            return torch.nn.functional.scaled_dot_product_attention(
-                query,
-                key,
-                value,
-                attn_mask=attention_mask,
-                dropout_p=dropout,
-                scale=scaling,
-                is_causal=False,
-                **sdpa_kwargs,
-            )
-
-        attn_output = torch.cond(
-            query.shape[2] > 1
-            and attention_mask is None
-            and getattr(module, "is_causal", True),
-            is_causal_is_true,
-            is_causal_is_false,
-            [query, key, value, attention_mask, dropout, scaling],
-        )
-    else:
-        attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query,
-            key,
-            value,
-            attn_mask=attention_mask,
-            dropout_p=dropout,
-            scale=scaling,
-            is_causal=is_causal,
-            **sdpa_kwargs,
-        )
+        # is_causal = query.shape[2] > 1 and attention_mask is None and getattr(module, "is_causal", True)
+        # NOTE: query.shape[2] == 1 or > 1 should have the same output for causal attention
+        # so we simplify the condition to:
+        is_causal = attention_mask is None and getattr(module, "is_causal", True)
+
+    # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
+    # We convert it to a bool for the SDPA kernel that only accepts bools.
+    if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
+        is_causal = is_causal.item()
+
+    attn_output = torch.nn.functional.scaled_dot_product_attention(
+        query,
+        key,
+        value,
+        attn_mask=attention_mask,
+        dropout_p=dropout,
+        scale=scaling,
+        is_causal=is_causal,
+        **sdpa_kwargs,
+    )
     attn_output = attn_output.transpose(1, 2).contiguous()
 
     return attn_output, None
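
For context, a minimal standalone sketch of the path this commit takes when no mask is supplied. It is not part of the commit; DummyModule and the tensor shapes are made up for illustration. The point it shows: the patched condition resolves `is_causal` to a plain Python bool, so the forward no longer needs the data-dependent `torch.cond` branching that the removed code used.

import torch

class DummyModule:
    # Stand-in for an attention module exposing an `is_causal` flag.
    is_causal = True

module = DummyModule()
query = torch.randn(1, 2, 5, 8)  # (batch, num_heads, q_len, head_dim)
key = torch.randn(1, 2, 5, 8)
value = torch.randn(1, 2, 5, 8)
attention_mask = None
dropout, scaling = 0.0, None

# Same condition as the added line 1686: evaluates to a plain Python bool,
# so there is no data-dependent control flow left for export/tracing to handle.
is_causal = attention_mask is None and getattr(module, "is_causal", True)

attn_output = torch.nn.functional.scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=attention_mask,
    dropout_p=dropout,
    scale=scaling,
    is_causal=is_causal,
)
print(attn_output.shape)  # torch.Size([1, 2, 5, 8])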
