
Commit e632c4d

patches sdpa_attention_forward for transformers>=5.0 (#267)
* patch sdpa_attention_forward
* doc
* patch
* fix
* code
* disable one test
1 parent b206ad4 commit e632c4d

5 files changed: +111 -8 lines changed


CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ Change Logs
 0.7.16
 ++++++
 
+* :pr:`267`: patches ``sdpa_attention_forward`` because of a control flow (``transformers>=5.0``)
 * :pr:`266`: makes ``patch_torch`` an integer in ``torch_export_patches`` to enable more patches
 
 0.7.15
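
For context, the new patch is applied (and later reverted) by ``torch_export_patches``. A hypothetical usage sketch, assuming the context-manager form implied by the apply/restore code below; the import path and the ``patch_transformers`` argument are assumptions, only ``patch_torch``, ``custom_patches`` and ``verbose`` are visible in this commit:

# Hypothetical sketch, not taken from this commit: wrap torch.export in the
# patch context so transformers resolves the patched sdpa_attention_forward.
import torch
from onnx_diagnostic.torch_export_patches import torch_export_patches  # assumed import path


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return x.sin()


with torch_export_patches(patch_transformers=True, verbose=1):  # assumed argument
    ep = torch.export.export(TinyModel(), (torch.rand(2, 3),))
print(ep.graph)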

_doc/examples/plot_export_hub_codellama.py

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@
 # %%
 # The configuration.
 
-print("config", get_pretrained_config(model_id))
+print("config", get_pretrained_config(model_id, use_only_preinstalled=unit_test_going()))
 
 # %%
 # The task determines the set of inputs which needs

_unittests/ut_torch_models/test_validate_whole_models1.py

Lines changed: 1 addition & 1 deletion
@@ -236,7 +236,7 @@ def test_m_validate_model_vit_model(self):
         onnx_filename = data["onnx_filename"]
         self.assertExists(onnx_filename)
 
-    @requires_torch("2.9")
+    @requires_torch("2.9.99")
     @hide_stdout()
     @ignore_warnings(FutureWarning)
     @requires_transformers("4.55")

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 58 additions & 6 deletions
@@ -459,6 +459,16 @@ def torch_export_patches(
     except ImportError:
         masking_utils = None
 
+    try:
+        import transformers.integrations.sdpa_attention as sdpa_attention
+    except ImportError:
+        sdpa_attention = None
+
+    try:
+        import transformers.modeling_utils as modeling_utils
+    except ImportError:
+        modeling_utils = None
+
     if verbose:
         import transformers
 
@@ -470,7 +480,7 @@ def torch_export_patches(
         patch_transformers_list, verbose=verbose
     )
 
-    if (
+    if (  # vmap
         masking_utils
         and patch_transformers_list.patch_masking_utils
         and hasattr(masking_utils, "_vmap_for_bhqkv")
@@ -505,7 +515,7 @@ def torch_export_patches(
     else:
         f_transformers_sdpa_mask = None
 
-    if (
+    if (  # eager_mask
         masking_utils
         and patch_transformers_list.patch_masking_utils
         and hasattr(masking_utils, "eager_mask")
@@ -532,7 +542,7 @@ def torch_export_patches(
             patch_transformers_list.patched_eager_mask
         )
 
-    if (
+    if (  # sdpa_mask
         masking_utils
         and patch_transformers_list.patch_masking_utils
         and hasattr(masking_utils, "sdpa_mask")
@@ -553,6 +563,29 @@ def torch_export_patches(
             patch_transformers_list.patched_sdpa_mask_recent_torch
        )
 
+    if (  # sdpa_attention_forward
+        sdpa_attention is not None
+        and modeling_utils is not None
+        and hasattr(sdpa_attention, "sdpa_attention_forward")
+        and hasattr(sdpa_attention, "use_gqa_in_sdpa")
+        and hasattr(modeling_utils, "AttentionInterface")
+    ):
+        if verbose:
+            print(
+                "[torch_export_patches] patches "
+                "transformers.integrations.sdpa_attention.sdpa_attention_forward"
+            )
+        f_sdpa_attention_forward = sdpa_attention.sdpa_attention_forward
+        sdpa_attention.sdpa_attention_forward = (
+            patch_transformers_list.patched_sdpa_attention_forward
+        )
+        modeling_utils.sdpa_attention_forward = (
+            patch_transformers_list.patched_sdpa_attention_forward
+        )
+        modeling_utils.AttentionInterface._global_mapping["sdpa"] = (
+            patch_transformers_list.patched_sdpa_attention_forward
+        )
+
     if custom_patches:
         if verbose:
             print("[torch_export_patches] applies custom patches")
@@ -662,7 +695,7 @@ def torch_export_patches(
         patch_transformers_list, revert_patches_info, verbose=verbose
     )
 
-    if (
+    if (  # vmap
         masking_utils
         and patch_transformers_list.patch_masking_utils
         and hasattr(masking_utils, "_vmap_for_bhqkv")
@@ -693,7 +726,7 @@ def torch_export_patches(
                 "transformers.masking_utils.sdpa_mask"
             )
 
-    if (
+    if (  # eager_mask
         masking_utils
         and patch_transformers_list.patch_masking_utils
         and hasattr(masking_utils, "eager_mask")
@@ -720,7 +753,7 @@ def torch_export_patches(
                 "in ALL_MASK_ATTENTION_FUNCTIONS"
            )
 
-    if (
+    if (  # sdpa_mask
         masking_utils
         and patch_transformers_list.patch_masking_utils
         and hasattr(masking_utils, "sdpa_mask")
@@ -740,6 +773,25 @@ def torch_export_patches(
                 "in ALL_MASK_ATTENTION_FUNCTIONS"
            )
 
+    if (  # sdpa_attention_forward
+        sdpa_attention is not None
+        and modeling_utils is not None
+        and hasattr(sdpa_attention, "sdpa_attention_forward")
+        and hasattr(sdpa_attention, "use_gqa_in_sdpa")
+        and hasattr(modeling_utils, "AttentionInterface")
+    ):
+        sdpa_attention.sdpa_attention_forward = f_sdpa_attention_forward
+        modeling_utils.sdpa_attention_forward = f_sdpa_attention_forward
+        modeling_utils.AttentionInterface._global_mapping["sdpa"] = (
+            f_sdpa_attention_forward
+        )
+        if verbose:
+            print(
+                "[torch_export_patches] restored "
+                "transformers.integrations.sdpa_attention."
+                "sdpa_attention_forward"
+            )
+
     ########
     # caches
     ########
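
The apply/restore logic above is plain attribute monkey-patching: keep a reference to the original callable, install the patched one at every lookup site, and put the original back on exit. A minimal generic sketch of that pattern (not the library's actual implementation):

# Generic sketch of the patch/restore pattern, not the library's code.
import contextlib
import types


@contextlib.contextmanager
def swap_attr(owner, name, replacement):
    """Temporarily replace ``owner.name`` and restore it afterwards."""
    original = getattr(owner, name)
    setattr(owner, name, replacement)
    try:
        yield original
    finally:
        setattr(owner, name, original)


# Self-contained demo with a fake module object.
fake_module = types.SimpleNamespace(sdpa_attention_forward=lambda: "original")
with swap_attr(fake_module, "sdpa_attention_forward", lambda: "patched"):
    assert fake_module.sdpa_attention_forward() == "patched"
assert fake_module.sdpa_attention_forward() == "original"

The commit performs the same substitution at three lookup sites, ``sdpa_attention.sdpa_attention_forward``, ``modeling_utils.sdpa_attention_forward`` and the ``"sdpa"`` entry of ``AttentionInterface._global_mapping``, so a model picks up the patched function whichever way it resolves its attention implementation.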

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 50 additions & 0 deletions
@@ -1276,6 +1276,50 @@ def common_eager_attention_forward(
     return attn_output, attn_weights
 
 
+def patched_sdpa_attention_forward(
+    module: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    dropout: float = 0.0,
+    scaling: Optional[float] = None,
+    is_causal: Optional[bool] = None,
+    **kwargs,
+) -> tuple[torch.Tensor, None]:
+    """[patch:transformers.integrations.sdpa_attention.sdpa_attention_forward]"""
+    assert not kwargs.get("output_attentions", False), (
+        "`sdpa` attention does not support `output_attentions=True`."
+        " Please set your attention to `eager` if you want any of these features."
+    )
+    sdpa_kwargs = {}
+    if hasattr(module, "num_key_value_groups"):
+        if not transformers.integrations.sdpa_attention.use_gqa_in_sdpa(attention_mask, key):
+            key = transformers.integrations.sdpa_attention.repeat_kv(
+                key, module.num_key_value_groups
+            )
+            value = transformers.integrations.sdpa_attention.repeat_kv(
+                value, module.num_key_value_groups
+            )
+        else:
+            sdpa_kwargs = {"enable_gqa": True}
+
+    if attention_mask is not None and attention_mask.ndim == 4:
+        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+
+    is_causal = is_causal if is_causal is not None else getattr(module, "is_causal", True)
+    # PATCHED: remove the test query.shape[2] > 1
+    # is_causal = query.shape[2] > 1 and attention_mask is None and is_causal
+    is_causal = attention_mask is None and is_causal
+
+    attn_output = torch.nn.functional.scaled_dot_product_attention(
+        query,
+        key,
+        value,
+        attn_mask=attention_mask,
+        dropout_p=dropout,
+        scale=scaling,
+        is_causal=is_causal,
+        **sdpa_kwargs,
+    )
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    return attn_output, None
+
+
 def patched_model_bart_eager_attention_forward(
     module: torch.nn.Module,
     query: torch.Tensor,
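
The only behavioral change in ``patched_sdpa_attention_forward`` is dropping the ``query.shape[2] > 1`` term from ``is_causal``; this is the control flow mentioned in the changelog, since under ``torch.export`` a comparison on a dynamic sequence dimension forces a guard or a specialization. A hypothetical minimal reproduction of that kind of branch (not part of the commit):

# Hypothetical reproduction, not part of the commit: a branch on a dynamic
# dimension, like the removed ``query.shape[2] > 1`` test, makes torch.export
# guard on (and possibly specialize) that dimension.
import torch


class TinyAttn(torch.nn.Module):
    def forward(self, q, k, v):
        # Mirrors the removed test: causal attention only when there is more
        # than one query token.
        is_causal = q.shape[2] > 1
        return torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=is_causal)


q = k = v = torch.rand(1, 4, 8, 16)
seq = torch.export.Dim("seq", min=1, max=4096)
spec = {2: seq}
try:
    # Depending on the torch version this either refines the dynamic range
    # (so sequence length 1, i.e. the decoding step, is no longer covered) or
    # raises a guard/constraint error; the patched function avoids the branch,
    # so the export stays fully dynamic.
    ep = torch.export.export(TinyAttn(), (q, k, v), dynamic_shapes=(spec, spec, spec))
    print(ep)
except Exception as exc:
    print("export failed:", exc)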
