Commit b0e8322

patch

1 parent c789e60
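
Summary (from the diff below): the commit hardens the sdpa_attention_forward patching in torch_export_patches. transformers.modeling_utils is now imported inside a try/except and bound to None when unavailable, and both the patch and restore sites check that the module and its AttentionInterface attribute exist before monkey-patching, so transformers versions without AttentionInterface no longer fail.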

1 file changed: 19 additions, 10 deletions

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 19 additions & 10 deletions
@@ -464,6 +464,11 @@ def torch_export_patches(
     except ImportError:
         sdpa_attention = None
 
+    try:
+        import transformers.modeling_utils as modeling_utils
+    except ImportError:
+        modeling_utils = None
+
     if verbose:
         import transformers
 
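The guarded import above makes modeling_utils optional: if transformers.modeling_utils cannot be imported, the name is bound to None and the patch sites below skip it.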
@@ -558,8 +563,11 @@ def torch_export_patches(
                 patch_transformers_list.patched_sdpa_mask_recent_torch
             )
 
-        if sdpa_attention is not None and hasattr(  # sdpa_attention_forward
-            sdpa_attention, "sdpa_attention_forward"
+        if (  # sdpa_attention_forward
+            sdpa_attention is not None
+            and modeling_utils is not None
+            and hasattr(sdpa_attention, "sdpa_attention_forward")
+            and hasattr(modeling_utils, "AttentionInterface")
         ):
             if verbose:
                 print(
@@ -570,10 +578,10 @@ def torch_export_patches(
             sdpa_attention.sdpa_attention_forward = (
                 patch_transformers_list.patched_sdpa_attention_forward
             )
-            transformers.modeling_utils.sdpa_attention_forward = (
+            modeling_utils.sdpa_attention_forward = (
                 patch_transformers_list.patched_sdpa_attention_forward
             )
-            transformers.modeling_utils.AttentionInterface._global_mapping["sdpa"] = (
+            modeling_utils.AttentionInterface._global_mapping["sdpa"] = (
                 patch_transformers_list.patched_sdpa_attention_forward
             )
 
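Both assignments now go through the modeling_utils alias from the guarded import instead of resolving transformers.modeling_utils at patch time, keeping the patch site consistent with the new availability check.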
@@ -764,14 +772,15 @@ def torch_export_patches(
                 "in ALL_MASK_ATTENTION_FUNCTIONS"
            )
 
-        if sdpa_attention is not None and hasattr(  # sdpa_attention_forward
-            sdpa_attention, "sdpa_attention_forward"
+        if (  # sdpa_attention_forward
+            sdpa_attention is not None
+            and modeling_utils is not None
+            and hasattr(sdpa_attention, "sdpa_attention_forward")
+            and hasattr(modeling_utils, "AttentionInterface")
         ):
             sdpa_attention.sdpa_attention_forward = f_sdpa_attention_forward
-            transformers.modeling_utils.sdpa_attention_forward = (
-                f_sdpa_attention_forward
-            )
-            transformers.modeling_utils.AttentionInterface._global_mapping["sdpa"] = (
+            modeling_utils.sdpa_attention_forward = f_sdpa_attention_forward
+            modeling_utils.AttentionInterface._global_mapping["sdpa"] = (
                 f_sdpa_attention_forward
             )
             if verbose:
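
Taken together, the hunks apply a defensive monkey-patching pattern: optional import, hasattr checks, then swapping the function in both the source module and the AttentionInterface registry. Below is a minimal, self-contained sketch of that pattern, assuming a recent transformers layout. patched_forward is a hypothetical stand-in for patched_sdpa_attention_forward, and the sdpa_attention import path is an assumption, since the diff only shows the matching except branch.

try:
    import transformers.modeling_utils as modeling_utils
except ImportError:  # transformers missing or module layout changed
    modeling_utils = None

try:
    # Assumed import path; the original import line is outside the diff context.
    from transformers.integrations import sdpa_attention
except ImportError:
    sdpa_attention = None


def patched_forward(*args, **kwargs):
    """Hypothetical stand-in for patched_sdpa_attention_forward."""
    raise NotImplementedError


# Patch only when every target exists, so transformers versions that
# lack AttentionInterface are left untouched instead of raising.
if (
    sdpa_attention is not None
    and modeling_utils is not None
    and hasattr(sdpa_attention, "sdpa_attention_forward")
    and hasattr(modeling_utils, "AttentionInterface")
):
    sdpa_attention.sdpa_attention_forward = patched_forward
    modeling_utils.sdpa_attention_forward = patched_forward
    modeling_utils.AttentionInterface._global_mapping["sdpa"] = patched_forward

The restore side (the last hunk) follows the same guards, assigning the previously captured f_sdpa_attention_forward back to the same three locations.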
