1 change: 1 addition & 0 deletions CHANGELOGS.rst
@@ -4,6 +4,7 @@ Change Logs
0.7.16
++++++

* :pr:`267`: patches ``sdpa_attention_forward`` because of a control flow (``transformers>=5.0``)
* :pr:`266`: makes ``patch_torch`` an integer in ``torch_export_patches`` to enable more patches

0.7.15
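The control flow referenced in :pr:`267` is the ``query.shape[2] > 1`` test inside ``sdpa_attention_forward``: it branches on a dynamic sequence length, which ``torch.export`` cannot trace symbolically. A minimal sketch of the problem, assuming a toy module (``TinyAttention`` is illustrative only, not part of this PR):

import torch


class TinyAttention(torch.nn.Module):
    # Reproduces the data-dependent branch found in sdpa_attention_forward:
    # is_causal depends on the dynamic query length.
    def forward(self, query, key, value):
        is_causal = query.shape[2] > 1  # guard on a dynamic dimension
        return torch.nn.functional.scaled_dot_product_attention(
            query, key, value, is_causal=is_causal
        )


q = k = v = torch.rand(1, 4, 8, 16)  # (batch, heads, seq, head_dim)
# With a dynamic sequence length, the comparison above forces a guard:
# export may specialize the graph to seq > 1 or fail with a
# data-dependent control flow error, depending on the torch version.
ep = torch.export.export(
    TinyAttention(),
    (q, k, v),
    dynamic_shapes=({2: torch.export.Dim("seq")},) * 3,
)
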
2 changes: 1 addition & 1 deletion _doc/examples/plot_export_hub_codellama.py
@@ -44,7 +44,7 @@
# %%
# The configuration.

print("config", get_pretrained_config(model_id))
print("config", get_pretrained_config(model_id, use_only_preinstalled=unit_test_going()))

# %%
# The task determines the set of inputs which needs
2 changes: 1 addition & 1 deletion _unittests/ut_torch_models/test_validate_whole_models1.py
@@ -236,7 +236,7 @@ def test_m_validate_model_vit_model(self):
onnx_filename = data["onnx_filename"]
self.assertExists(onnx_filename)

@requires_torch("2.9")
@requires_torch("2.9.99")
@hide_stdout()
@ignore_warnings(FutureWarning)
@requires_transformers("4.55")
64 changes: 58 additions & 6 deletions onnx_diagnostic/torch_export_patches/onnx_export_errors.py
@@ -459,6 +459,16 @@ def torch_export_patches(
except ImportError:
masking_utils = None

try:
import transformers.integrations.sdpa_attention as sdpa_attention
except ImportError:
sdpa_attention = None

try:
import transformers.modeling_utils as modeling_utils
except ImportError:
modeling_utils = None

if verbose:
import transformers

@@ -470,7 +480,7 @@
patch_transformers_list, verbose=verbose
)

if (
if ( # vmap
masking_utils
and patch_transformers_list.patch_masking_utils
and hasattr(masking_utils, "_vmap_for_bhqkv")
@@ -505,7 +515,7 @@
else:
f_transformers_sdpa_mask = None

if (
if ( # eager_mask
masking_utils
and patch_transformers_list.patch_masking_utils
and hasattr(masking_utils, "eager_mask")
@@ -532,7 +542,7 @@
patch_transformers_list.patched_eager_mask
)

if (
if ( # sdpa_mask
masking_utils
and patch_transformers_list.patch_masking_utils
and hasattr(masking_utils, "sdpa_mask")
@@ -553,6 +563,29 @@
patch_transformers_list.patched_sdpa_mask_recent_torch
)

if ( # sdpa_attention_forward
sdpa_attention is not None
and modeling_utils is not None
and hasattr(sdpa_attention, "sdpa_attention_forward")
and hasattr(sdpa_attention, "use_gqa_in_sdpa")
and hasattr(modeling_utils, "AttentionInterface")
):
if verbose:
print(
"[torch_export_patches] patches "
"transformers.integrations.sdpa_attention.sdpa_attention_forward"
)
f_sdpa_attention_forward = sdpa_attention.sdpa_attention_forward
sdpa_attention.sdpa_attention_forward = (
patch_transformers_list.patched_sdpa_attention_forward
)
modeling_utils.sdpa_attention_forward = (
patch_transformers_list.patched_sdpa_attention_forward
)
modeling_utils.AttentionInterface._global_mapping["sdpa"] = (
patch_transformers_list.patched_sdpa_attention_forward
)

if custom_patches:
if verbose:
print("[torch_export_patches] applies custom patches")
@@ -662,7 +695,7 @@ def torch_export_patches(
patch_transformers_list, revert_patches_info, verbose=verbose
)

if (
if ( # vmap
masking_utils
and patch_transformers_list.patch_masking_utils
and hasattr(masking_utils, "_vmap_for_bhqkv")
@@ -693,7 +726,7 @@
"transformers.masking_utils.sdpa_mask"
)

if (
if ( # eager_mask
masking_utils
and patch_transformers_list.patch_masking_utils
and hasattr(masking_utils, "eager_mask")
@@ -720,7 +753,7 @@
"in ALL_MASK_ATTENTION_FUNCTIONS"
)

if (
if ( # sdpa_mask
masking_utils
and patch_transformers_list.patch_masking_utils
and hasattr(masking_utils, "sdpa_mask")
@@ -740,6 +773,25 @@
"in ALL_MASK_ATTENTION_FUNCTIONS"
)

if ( # sdpa_attention_forward
sdpa_attention is not None
and modeling_utils is not None
and hasattr(sdpa_attention, "sdpa_attention_forward")
and hasattr(sdpa_attention, "use_gqa_in_sdpa")
and hasattr(modeling_utils, "AttentionInterface")
):
sdpa_attention.sdpa_attention_forward = f_sdpa_attention_forward
modeling_utils.sdpa_attention_forward = f_sdpa_attention_forward
modeling_utils.AttentionInterface._global_mapping["sdpa"] = (
f_sdpa_attention_forward
)
if verbose:
print(
"[torch_export_patches] restored "
"transformers.integrations.sdpa_attention."
"sdpa_attention_forward"
)

########
# caches
########
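The block above keeps the original callable in ``f_sdpa_attention_forward`` and swaps the patched version into three places: the ``sdpa_attention`` module, the ``modeling_utils`` alias, and the ``"sdpa"`` entry of ``AttentionInterface._global_mapping``; the restore path later undoes all three assignments. A generic sketch of this save/patch/restore pattern, assuming a helper ``swap_attribute`` that is not part of ``onnx_diagnostic``:

import contextlib


@contextlib.contextmanager
def swap_attribute(owner, name, replacement):
    # Temporarily replaces ``owner.<name>`` and restores the original on exit,
    # mirroring what torch_export_patches does for sdpa_attention_forward.
    original = getattr(owner, name)
    setattr(owner, name, replacement)
    try:
        yield original
    finally:
        setattr(owner, name, original)


# Usage sketch, following the names patched above:
# with swap_attribute(sdpa_attention, "sdpa_attention_forward",
#                     patched_sdpa_attention_forward):
#     ...  # export the model while the patch is active
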
50 changes: 50 additions & 0 deletions onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
@@ -1276,6 +1276,56 @@ def common_eager_attention_forward(
return attn_output, attn_weights


def patched_sdpa_attention_forward(
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
dropout: float = 0.0,
scaling: Optional[float] = None,
is_causal: Optional[bool] = None,
**kwargs,
) -> tuple[torch.Tensor, None]:
"""[patch:transformers.integrations.sdpa_attention.sdpa_attention_forward]"""
assert not kwargs.get("output_attentions", False), (
"`sdpa` attention does not support `output_attentions=True`."
" Please set your attention to `eager` if you want any of these features."
)
sdpa_kwargs = {}
if hasattr(module, "num_key_value_groups"):
if not transformers.integrations.sdpa_attention.use_gqa_in_sdpa(attention_mask, key):
key = transformers.integrations.sdpa_attention.repeat_kv(
key, module.num_key_value_groups
)
value = transformers.integrations.sdpa_attention.repeat_kv(
value, module.num_key_value_groups
)
else:
sdpa_kwargs = {"enable_gqa": True}

if attention_mask is not None and attention_mask.ndim == 4:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]

is_causal = is_causal if is_causal is not None else getattr(module, "is_causal", True)
# PATCHED: remove the test query.shape[2] > 1
# is_causal = query.shape[2] > 1 and attention_mask is None and is_causal
is_causal = attention_mask is None and is_causal

attn_output = torch.nn.functional.scaled_dot_product_attention(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=dropout,
scale=scaling,
is_causal=is_causal,
**sdpa_kwargs,
)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, None


def patched_model_bart_eager_attention_forward(
module: torch.nn.Module,
query: torch.Tensor,
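A quick way to exercise the patched function is to call it directly with a dummy module; apart from dropping the ``query.shape[2] > 1`` test, it follows the same path as the stock ``sdpa_attention_forward``. A minimal sketch, assuming this PR is installed (``DummyAttention`` is illustrative and deliberately omits ``num_key_value_groups`` so the GQA branch is skipped):

import torch

from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
    patched_sdpa_attention_forward,
)


class DummyAttention(torch.nn.Module):
    # Minimal stand-in exposing only the attribute the patched function reads.
    is_causal = True


q = torch.rand(2, 4, 8, 16)  # (batch, heads, seq, head_dim)
k = torch.rand(2, 4, 8, 16)
v = torch.rand(2, 4, 8, 16)

# attention_mask=None and is_causal=None -> causal SDPA, no data-dependent branch.
out, weights = patched_sdpa_attention_forward(
    DummyAttention(), q, k, v, attention_mask=None
)
print(out.shape, weights)  # torch.Size([2, 8, 4, 16]) None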