@@ -439,6 +439,28 @@ def torch_export_patches(
439439 f_transformers__vmap_for_bhqkv = masking_utils ._vmap_for_bhqkv
440440 masking_utils ._vmap_for_bhqkv = patch_transformers_list .patched__vmap_for_bhqkv
441441
442+ if verbose :
443+ print (
444+ "[torch_export_patches] patches "
445+ "transformers.masking_utils.sdpa_mask_recent_torch"
446+ )
447+ f_transformers_sdpa_mask_recent_torch = masking_utils .sdpa_mask_recent_torch
448+ masking_utils .sdpa_mask_recent_torch = (
449+ patch_transformers_list .patched_sdpa_mask_recent_torch
450+ )
451+ if masking_utils .sdpa_mask == f_transformers_sdpa_mask_recent_torch :
452+ if verbose :
453+ print (
454+ "[torch_export_patches] patches "
455+ "transformers.masking_utils.sdpa_mask"
456+ )
457+ f_transformers_sdpa_mask = masking_utils .sdpa_mask
458+ masking_utils .sdpa_mask = (
459+ patch_transformers_list .patched_sdpa_mask_recent_torch
460+ )
461+ else :
462+ f_transformers_sdpa_mask = None
463+
442464 if (
443465 masking_utils
444466 and patch_transformers_list .patch_masking_utils
@@ -456,10 +478,37 @@ def torch_export_patches(
456478 and masking_utils .ALL_MASK_ATTENTION_FUNCTIONS ["eager" ]
457479 == f_transformers_eager_mask
458480 ):
481+ if verbose :
482+ print (
483+ "[torch_export_patches] patches "
484+ "transformers.masking_utils.eager_mask "
485+ "in ALL_MASK_ATTENTION_FUNCTIONS"
486+ )
459487 masking_utils .ALL_MASK_ATTENTION_FUNCTIONS ["eager" ] = (
460488 patch_transformers_list .patched_eager_mask
461489 )
462490
491+ if (
492+ masking_utils
493+ and patch_transformers_list .patch_masking_utils
494+ and hasattr (masking_utils , "sdpa_mask" )
495+ and f_transformers_sdpa_mask is not None
496+ ):
497+ if verbose :
498+ print (
499+ "[torch_export_patches] patches "
500+ "transformers.masking_utils.sdpa_mask "
501+ "in ALL_MASK_ATTENTION_FUNCTIONS"
502+ )
503+ if (
504+ "sdpa" in masking_utils .ALL_MASK_ATTENTION_FUNCTIONS
505+ and masking_utils .ALL_MASK_ATTENTION_FUNCTIONS ["sdpa" ]
506+ == f_transformers_sdpa_mask
507+ ):
508+ masking_utils .ALL_MASK_ATTENTION_FUNCTIONS ["sdpa" ] = (
509+ patch_transformers_list .patched_sdpa_mask_recent_torch
510+ )
511+
463512 if custom_patches :
464513 if verbose :
465514 print ("[torch_export_patches] applies custom patches" )
@@ -568,19 +617,43 @@ def torch_export_patches(
568617 and hasattr (masking_utils , "_vmap_for_bhqkv" )
569618 ):
570619 masking_utils ._vmap_for_bhqkv = f_transformers__vmap_for_bhqkv
620+
571621 if verbose :
572622 print (
573623 "[torch_export_patches] restored "
574624 "transformers.masking_utils._vmap_for_bhqkv"
575625 )
576626
627+ masking_utils .sdpa_mask_recent_torch = (
628+ f_transformers_sdpa_mask_recent_torch
629+ )
630+
631+ if verbose :
632+ print (
633+ "[torch_export_patches] restored "
634+ "transformers.masking_utils.sdpa_mask_recent_torch"
635+ )
636+
637+ if f_transformers_sdpa_mask is not None :
638+ masking_utils .sdpa_mask = f_transformers_sdpa_mask
639+ if verbose :
640+ print (
641+ "[torch_export_patches] restored "
642+ "transformers.masking_utils.sdpa_mask"
643+ )
644+
577645 if (
578646 masking_utils
579647 and patch_transformers_list .patch_masking_utils
580648 and hasattr (masking_utils , "eager_mask" )
581649 ):
582650 f_transformers_eager_mask = masking_utils .eager_mask
583651 masking_utils .eager_mask = f_transformers_eager_mask
652+ if verbose :
653+ print (
654+ "[torch_export_patches] restored "
655+ "transformers.masking_utils.eager_mask"
656+ )
584657 if (
585658 "eager" in masking_utils .ALL_MASK_ATTENTION_FUNCTIONS
586659 and masking_utils .ALL_MASK_ATTENTION_FUNCTIONS ["eager" ]
@@ -589,11 +662,32 @@ def torch_export_patches(
589662 masking_utils .ALL_MASK_ATTENTION_FUNCTIONS ["eager" ] = (
590663 f_transformers_eager_mask
591664 )
592- if verbose :
593- print (
594- "[torch_export_patches] restored "
595- "transformers.masking_utils.eager_mask"
665+ if verbose :
666+ print (
667+ "[torch_export_patches] restored "
668+ "transformers.masking_utils.eager_mask "
669+ "in ALL_MASK_ATTENTION_FUNCTIONS"
670+ )
671+
672+ if (
673+ masking_utils
674+ and patch_transformers_list .patch_masking_utils
675+ and hasattr (masking_utils , "sdpa_mask" )
676+ ):
677+ if (
678+ "sdpa" in masking_utils .ALL_MASK_ATTENTION_FUNCTIONS
679+ and masking_utils .ALL_MASK_ATTENTION_FUNCTIONS ["sdpa" ]
680+ == patch_transformers_list .patched_sdpa_mask_recent_torch
681+ ):
682+ masking_utils .ALL_MASK_ATTENTION_FUNCTIONS ["sdpa" ] = (
683+ f_transformers_sdpa_mask
596684 )
685+ if verbose :
686+ print (
687+ "[torch_export_patches] restored "
688+ "transformers.masking_utils.sdpa_mask "
689+ "in ALL_MASK_ATTENTION_FUNCTIONS"
690+ )
597691
598692 ########
599693 # caches
0 commit comments