Commit f413ea7

draft-patched_sdpa

1 parent: cd1a19f
2 files changed: 139 additions, 3 deletions

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 35 additions & 0 deletions
@@ -415,6 +415,11 @@ def torch_export_patches(
     except ImportError:
         masking_utils = None
 
+    try:
+        import transformers.modeling_utils as modeling_utils
+    except ImportError:
+        modeling_utils = None
+
     if verbose:
         import transformers
 
@@ -509,6 +514,23 @@ def torch_export_patches(
                 patch_transformers_list.patched_sdpa_mask_recent_torch
             )
 
+        if (
+            modeling_utils
+            and patch_transformers_list.patch_modeling_utils
+            and "sdpa" in modeling_utils.ALL_ATTENTION_FUNCTIONS
+        ):
+            if verbose:
+                print(
+                    "[torch_export_patches] patches "
+                    "transformers.modeling_utils.sdpa_attention_forward"
+                )
+            f_transformers_sdpa_attention_forward = modeling_utils.ALL_ATTENTION_FUNCTIONS[
+                "sdpa"
+            ]
+            modeling_utils.ALL_ATTENTION_FUNCTIONS["sdpa"] = (
+                patch_transformers_list.patched_sdpa_attention_forward
+            )
+
     if custom_patches:
         if verbose:
             print("[torch_export_patches] applies custom patches")
@@ -688,6 +710,19 @@ def torch_export_patches(
                     "transformers.masking_utils.sdpa_mask "
                     "in ALL_MASK_ATTENTION_FUNCTIONS"
                 )
+        if (
+            modeling_utils
+            and patch_transformers_list.patch_modeling_utils
+            and "sdpa" in modeling_utils.ALL_ATTENTION_FUNCTIONS
+        ):
+            modeling_utils.ALL_ATTENTION_FUNCTIONS["sdpa"] = (
+                f_transformers_sdpa_attention_forward
+            )
+            if verbose:
+                print(
+                    "[torch_export_patches] restored "
+                    "transformers.modeling_utils.sdpa_attention_forward"
+                )
 
     ########
     # caches
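
As a reading aid, here is a minimal usage sketch (not part of the commit) of the swap-and-restore behaviour implemented by the hunks above. It assumes torch_export_patches is re-exported at the package root and accepts a patch_transformers flag; only `verbose` is visible in this diff, so both of those names are assumptions.

# Sketch only: the package-root import and the patch_transformers flag are assumptions.
import transformers.modeling_utils as modeling_utils
from onnx_diagnostic.torch_export_patches import torch_export_patches

original_sdpa = modeling_utils.ALL_ATTENTION_FUNCTIONS["sdpa"]
with torch_export_patches(patch_transformers=True, verbose=1):
    # inside the context, "sdpa" resolves to patched_sdpa_attention_forward
    assert modeling_utils.ALL_ATTENTION_FUNCTIONS["sdpa"] is not original_sdpa
# on exit, the saved f_transformers_sdpa_attention_forward is written back
assert modeling_utils.ALL_ATTENTION_FUNCTIONS["sdpa"] is original_sdpa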

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 104 additions & 3 deletions
@@ -4,6 +4,7 @@
 from functools import wraps
 from typing import Callable, List, Optional, Tuple
 import packaging.version as pv
+from transformers.modeling_utils import logger  # transformers logger, provides warning_once
 import torch
 import transformers
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
@@ -986,7 +987,7 @@ def wrapper(self, x, position_ids):
     return wrapper
 
 
-def common_eager_attention_forward(
+def _common_eager_attention_forward(
     module: torch.nn.Module,
     query: torch.Tensor,
     key: torch.Tensor,
@@ -1033,7 +1034,7 @@ def patched_model_bart_eager_attention_forward(
     **kwargs,
 ):
     """[patch:transformers.models.bart.modeling_bart.eager_attention_forward]"""
-    return common_eager_attention_forward(
+    return _common_eager_attention_forward(
         module,
         query,
         key,
@@ -1058,7 +1059,7 @@ def patched_modeling_marian_eager_attention_forward(
     **kwargs,
 ):
     """[patch:transformers.models.marian.modeling_marian.eager_attention_forward]"""
-    return common_eager_attention_forward(
+    return _common_eager_attention_forward(
         module,
         query,
         key,
@@ -1629,3 +1630,103 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             batch_size, sequence_length, hidden_dim
         )
         return final_hidden_states, router_logits
+
+
+##### Attention #####
+
+try:
+    import transformers.modeling_utils
+
+    patch_modeling_utils = True
+
+    from transformers.integrations.sdpa_attention import use_gqa_in_sdpa, repeat_kv
+
+except ImportError:
+    patch_modeling_utils = False
+
+if patch_modeling_utils:
+
+    def patched_sdpa_attention_forward(
+        module: torch.nn.Module,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attention_mask: Optional[torch.Tensor],
+        dropout: float = 0.0,
+        scaling: Optional[float] = None,
+        is_causal: Optional[bool] = None,
+        **kwargs,
+    ) -> tuple[torch.Tensor, None]:
+        """manual patch for function ``transformers.integrations.sdpa_attention.sdpa_attention_forward``."""
+        if kwargs.get("output_attentions", False) or kwargs.get("head_mask") is not None:
+            logger.warning_once(
+                "`sdpa` attention does not support `output_attentions=True` or `head_mask`."
+                " Please set your attention to `eager` if you want any of these features."
+            )
+        sdpa_kwargs = {}
+        if hasattr(module, "num_key_value_groups"):
+            if not use_gqa_in_sdpa(attention_mask, key):
+                key = repeat_kv(key, module.num_key_value_groups)
+                value = repeat_kv(value, module.num_key_value_groups)
+            else:
+                sdpa_kwargs = {"enable_gqa": True}
+
+        if attention_mask is not None and attention_mask.ndim == 4:
+            attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        # Note that it is important to check first for the shape, otherwise compile will fail with `argument 'is_causal' must be bool, not SymBool`
+        if is_causal is None:
+            # The last condition is for encoder (decoder) models which specify this by passing their own `is_causal` flag
+            # This is mainly due to those models having mixed implementations for encoder, decoder, and encoder-decoder attns
+            def is_causal_is_true(
+                query, key, value, attention_mask, dropout, scaling, **sdpa_kwargs
+            ):
+                return torch.nn.functional.scaled_dot_product_attention(
+                    query,
+                    key,
+                    value,
+                    attn_mask=attention_mask,
+                    dropout_p=dropout,
+                    scale=scaling,
+                    is_causal=True,
+                    **sdpa_kwargs,
+                )
+
+            def is_causal_is_false(
+                query, key, value, attention_mask, dropout, scaling, **sdpa_kwargs
+            ):
+                return torch.nn.functional.scaled_dot_product_attention(
+                    query,
+                    key,
+                    value,
+                    attn_mask=attention_mask,
+                    dropout_p=dropout,
+                    scale=scaling,
+                    is_causal=False,
+                    **sdpa_kwargs,
+                )
+
+            attn_output = torch.cond(
+                query.shape[2] > 1
+                and attention_mask is None
+                and getattr(module, "is_causal", True),
+                is_causal_is_true,
+                is_causal_is_false,
+                [query, key, value, attention_mask, dropout, scaling],
+            )
+        else:
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query,
+                key,
+                value,
+                attn_mask=attention_mask,
+                dropout_p=dropout,
+                scale=scaling,
+                is_causal=is_causal,
+                **sdpa_kwargs,
+            )
+        attn_output = attn_output.transpose(1, 2).contiguous()
+
+        return attn_output, None
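
The core of the patched function is the `is_causal` dispatch: rather than turning a possibly symbolic `query.shape[2]` into a Python boolean, it routes both code paths through torch.cond so torch.export can keep the decision in the graph. Below is a standalone sketch of that pattern (plain PyTorch, not from the commit; the function name and shapes are illustrative).

# A Python `if query.shape[2] > 1:` would specialize or fail on a dynamic dimension;
# torch.cond records both branches and defers the choice to runtime.
import torch

def sdpa_causal_or_not(query, key, value):
    def causal(q, k, v):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)

    def not_causal(q, k, v):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=False)

    # the predicate may be a SymBool under export; both branches stay in the graph
    return torch.cond(query.shape[2] > 1, causal, not_causal, [query, key, value])

q = torch.rand(2, 4, 8, 16)  # (batch, heads, seq, head_dim)
print(sdpa_causal_or_not(q, q, q).shape)  # torch.Size([2, 4, 8, 16])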
