@@ -464,6 +464,11 @@ def torch_export_patches(
     except ImportError:
         sdpa_attention = None
 
+    try:
+        import transformers.modeling_utils as modeling_utils
+    except ImportError:
+        modeling_utils = None
+
     if verbose:
         import transformers
 
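The new import mirrors the existing guard around `sdpa_attention`: if `transformers.modeling_utils` cannot be imported (for instance on an older transformers release), `modeling_utils` stays `None` and the later checks skip the patch instead of raising. A minimal sketch of the pattern in isolation (the `can_patch` name is illustrative, not from the diff):

```python
# Guarded optional import: fall back to None instead of failing hard.
try:
    import transformers.modeling_utils as modeling_utils
except ImportError:  # transformers missing or too old
    modeling_utils = None

# Feature detection before patching: both the module and the attribute
# the patch touches must exist on this transformers version.
can_patch = modeling_utils is not None and hasattr(
    modeling_utils, "AttentionInterface"
)
```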
@@ -558,8 +563,11 @@ def torch_export_patches(
             patch_transformers_list.patched_sdpa_mask_recent_torch
         )
 
-        if sdpa_attention is not None and hasattr(  # sdpa_attention_forward
-            sdpa_attention, "sdpa_attention_forward"
+        if (  # sdpa_attention_forward
+            sdpa_attention is not None
+            and modeling_utils is not None
+            and hasattr(sdpa_attention, "sdpa_attention_forward")
+            and hasattr(modeling_utils, "AttentionInterface")
         ):
             if verbose:
                 print(
@@ -570,10 +578,10 @@ def torch_export_patches(
             sdpa_attention.sdpa_attention_forward = (
                 patch_transformers_list.patched_sdpa_attention_forward
             )
-            transformers.modeling_utils.sdpa_attention_forward = (
+            modeling_utils.sdpa_attention_forward = (
                 patch_transformers_list.patched_sdpa_attention_forward
             )
-            transformers.modeling_utils.AttentionInterface._global_mapping["sdpa"] = (
+            modeling_utils.AttentionInterface._global_mapping["sdpa"] = (
                 patch_transformers_list.patched_sdpa_attention_forward
             )
 
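The same patched kernel is installed everywhere transformers may resolve `"sdpa"` attention: the `sdpa_attention` integration module, the copy re-exported on `modeling_utils`, and the `AttentionInterface` registry that models query by name. A hedged sketch of that step, with `apply_sdpa_patch` and `patched_fn` as illustrative names (the diff uses `patch_transformers_list.patched_sdpa_attention_forward`):

```python
def apply_sdpa_patch(patched_fn):
    """Install patched_fn on every lookup path for "sdpa" attention.

    Assumes sdpa_attention and modeling_utils were imported with the
    guards above and passed the hasattr checks.
    """
    # 1. The integration module that defines the kernel.
    sdpa_attention.sdpa_attention_forward = patched_fn
    # 2. The re-export on transformers.modeling_utils.
    modeling_utils.sdpa_attention_forward = patched_fn
    # 3. The registry consulted when config._attn_implementation == "sdpa".
    # _global_mapping is a private attribute, which is one more reason the
    # hasattr(modeling_utils, "AttentionInterface") guard matters.
    modeling_utils.AttentionInterface._global_mapping["sdpa"] = patched_fn
```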
@@ -764,14 +772,15 @@ def torch_export_patches(
                 "in ALL_MASK_ATTENTION_FUNCTIONS"
             )
 
-        if sdpa_attention is not None and hasattr(  # sdpa_attention_forward
-            sdpa_attention, "sdpa_attention_forward"
+        if (  # sdpa_attention_forward
+            sdpa_attention is not None
+            and modeling_utils is not None
+            and hasattr(sdpa_attention, "sdpa_attention_forward")
+            and hasattr(modeling_utils, "AttentionInterface")
         ):
             sdpa_attention.sdpa_attention_forward = f_sdpa_attention_forward
-            transformers.modeling_utils.sdpa_attention_forward = (
-                f_sdpa_attention_forward
-            )
-            transformers.modeling_utils.AttentionInterface._global_mapping["sdpa"] = (
+            modeling_utils.sdpa_attention_forward = f_sdpa_attention_forward
+            modeling_utils.AttentionInterface._global_mapping["sdpa"] = (
                 f_sdpa_attention_forward
             )
             if verbose:
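This second hunk is the symmetric restore path run when the patches are undone: `f_sdpa_attention_forward` is the original callable captured before patching, written back to the same three locations. A condensed, hypothetical sketch of the whole lifecycle as a context manager (`_patch_sdpa` is not a name from the diff, and the import path for `sdpa_attention` is assumed to be `transformers.integrations.sdpa_attention`, as in recent transformers):

```python
import contextlib

try:
    import transformers.integrations.sdpa_attention as sdpa_attention
except ImportError:
    sdpa_attention = None
try:
    import transformers.modeling_utils as modeling_utils
except ImportError:
    modeling_utils = None


@contextlib.contextmanager
def _patch_sdpa(patched_fn):
    """Hypothetical condensed view of the patch/unpatch lifecycle above."""
    if (
        sdpa_attention is None
        or modeling_utils is None
        or not hasattr(sdpa_attention, "sdpa_attention_forward")
        or not hasattr(modeling_utils, "AttentionInterface")
    ):
        yield  # nothing to patch on this transformers version
        return
    f_sdpa_attention_forward = sdpa_attention.sdpa_attention_forward  # original
    sdpa_attention.sdpa_attention_forward = patched_fn
    modeling_utils.sdpa_attention_forward = patched_fn
    modeling_utils.AttentionInterface._global_mapping["sdpa"] = patched_fn
    try:
        yield
    finally:  # restore the original in all three locations
        sdpa_attention.sdpa_attention_forward = f_sdpa_attention_forward
        modeling_utils.sdpa_attention_forward = f_sdpa_attention_forward
        modeling_utils.AttentionInterface._global_mapping["sdpa"] = (
            f_sdpa_attention_forward
        )
```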