@@ -180,7 +180,7 @@ def register_additional_serialization_functions(
180180 unregister_cache_serialization (done , verbose = verbose )
181181
182182
183- def _patch_sympy (verbose : int , patch_details : PatchDetails ) -> Tuple [Callable , ...]:
183+ def _patch_sympy (verbose : int , patch_details : PatchDetails ) -> Tuple [Optional [ Callable ] , ...]:
184184 import sympy
185185
186186 f_sympy_name = getattr (sympy .core .numbers .IntegerConstant , "name" , None )
@@ -199,7 +199,7 @@ def _patch_sympy(verbose: int, patch_details: PatchDetails) -> Tuple[Callable, .
199199 return (f_sympy_name ,)
200200
201201
202- def _unpatch_sympy (verbose : int , f_sympy_name : Callable ):
202+ def _unpatch_sympy (verbose : int , f_sympy_name : Optional [ Callable ] ):
203203 # tracked by https://github.com/pytorch/pytorch/issues/143494
204204 import sympy
205205
@@ -218,7 +218,7 @@ def _patch_torch(
218218 patch_torch : int ,
219219 catch_constraints : bool ,
220220 stop_if_static : int ,
221- ) -> Tuple [Callable , ...]:
221+ ) -> Tuple [Optional [ Callable ] , ...]:
222222 import torch
223223 import torch .jit
224224 import torch ._export .non_strict_utils # produce_guards_and_solve_constraints
@@ -399,21 +399,21 @@ def _unpatch_torch(
399399 patch_torch : int ,
400400 catch_constraints : bool ,
401401 stop_if_static : int ,
402- f___constrain_user_specified_dimhint_range : Callable ,
403- f__broadcast_in_dim_meta : Callable ,
404- f__broadcast_shapes : Callable ,
405- f__check_input_constraints_for_graph : Callable ,
406- f__maybe_broadcast : Callable ,
407- f_broadcast_in_dim : Callable ,
408- f_infer_size : Callable ,
409- f_jit_isinstance : Callable ,
410- f_mark_static_address : Callable ,
411- f_produce_guards_and_solve_constraints : Callable ,
412- f_shape_env__check_frozen : Callable ,
413- f_shape_env__evaluate_expr : Callable ,
414- f_shape_env__log_guard : Callable ,
415- f_shape_env__set_replacement : Callable ,
416- f_vmap : Callable ,
402+ f___constrain_user_specified_dimhint_range : Optional [ Callable ] ,
403+ f__broadcast_in_dim_meta : Optional [ Callable ] ,
404+ f__broadcast_shapes : Optional [ Callable ] ,
405+ f__check_input_constraints_for_graph : Optional [ Callable ] ,
406+ f__maybe_broadcast : Optional [ Callable ] ,
407+ f_broadcast_in_dim : Optional [ Callable ] ,
408+ f_infer_size : Optional [ Callable ] ,
409+ f_jit_isinstance : Optional [ Callable ] ,
410+ f_mark_static_address : Optional [ Callable ] ,
411+ f_produce_guards_and_solve_constraints : Optional [ Callable ] ,
412+ f_shape_env__check_frozen : Optional [ Callable ] ,
413+ f_shape_env__evaluate_expr : Optional [ Callable ] ,
414+ f_shape_env__log_guard : Optional [ Callable ] ,
415+ f_shape_env__set_replacement : Optional [ Callable ] ,
416+ f_vmap : Optional [ Callable ] ,
417417):
418418 import torch
419419 import torch .jit
@@ -467,7 +467,9 @@ def _unpatch_torch(
467467 print ("[torch_export_patches] restored shape constraints" )
468468
469469
470- def _patch_transformers (verbose : int , patch_details : PatchDetails ) -> Tuple [Callable , ...]:
470+ def _patch_transformers (
471+ verbose : int , patch_details : PatchDetails
472+ ) -> Tuple [Optional [Callable ], ...]:
471473 import transformers
472474
473475 try :
@@ -504,6 +506,12 @@ def _patch_transformers(verbose: int, patch_details: PatchDetails) -> Tuple[Call
504506 f"sdpa_attention.sdpa_attention_forward={ sdpa_attention .sdpa_attention_forward } "
505507 )
506508
509+ f_transformers__vmap_for_bhqkv = None
510+ f_transformers_eager_mask = None
511+ f_transformers_sdpa_attention_forward = None
512+ f_transformers_sdpa_mask = None
513+ f_transformers_sdpa_mask_recent_torch = None
514+
507515 if ( # vmap
508516 masking_utils
509517 and patch_transformers_list .patch_masking_utils
@@ -649,12 +657,12 @@ def _patch_transformers(verbose: int, patch_details: PatchDetails) -> Tuple[Call
649657def _unpatch_transformers (
650658 verbose : int ,
651659 _patch_details : PatchDetails ,
652- f_transformers__vmap_for_bhqkv : Callable ,
653- f_transformers_eager_mask : Callable ,
654- f_transformers_sdpa_attention_forward : Callable ,
655- f_transformers_sdpa_mask : Callable ,
656- f_transformers_sdpa_mask_recent_torch : Callable ,
657- revert_patches_info : Callable ,
660+ f_transformers__vmap_for_bhqkv : Optional [ Callable ] ,
661+ f_transformers_eager_mask : Optional [ Callable ] ,
662+ f_transformers_sdpa_attention_forward : Optional [ Callable ] ,
663+ f_transformers_sdpa_mask : Optional [ Callable ] ,
664+ f_transformers_sdpa_mask_recent_torch : Optional [ Callable ] ,
665+ revert_patches_info : Optional [ Callable ] ,
658666):
659667
660668 try :
0 commit comments