Commit e579e50: fix
Parent: efa6880

1 file changed, 33 additions (+), 25 deletions (-):
onnx_diagnostic/torch_export_patches/onnx_export_errors.py
@@ -180,7 +180,7 @@ def register_additional_serialization_functions(
         unregister_cache_serialization(done, verbose=verbose)


-def _patch_sympy(verbose: int, patch_details: PatchDetails) -> Tuple[Callable, ...]:
+def _patch_sympy(verbose: int, patch_details: PatchDetails) -> Tuple[Optional[Callable], ...]:
     import sympy

     f_sympy_name = getattr(sympy.core.numbers.IntegerConstant, "name", None)
@@ -199,7 +199,7 @@ def _patch_sympy(verbose: int, patch_details: PatchDetails) -> Tuple[Callable, .
     return (f_sympy_name,)


-def _unpatch_sympy(verbose: int, f_sympy_name: Callable):
+def _unpatch_sympy(verbose: int, f_sympy_name: Optional[Callable]):
     # tracked by https://github.com/pytorch/pytorch/issues/143494
     import sympy

@@ -218,7 +218,7 @@ def _patch_torch(
     patch_torch: int,
     catch_constraints: bool,
     stop_if_static: int,
-) -> Tuple[Callable, ...]:
+) -> Tuple[Optional[Callable], ...]:
     import torch
     import torch.jit
     import torch._export.non_strict_utils  # produce_guards_and_solve_constraints
@@ -399,21 +399,21 @@ def _unpatch_torch(
     patch_torch: int,
     catch_constraints: bool,
     stop_if_static: int,
-    f___constrain_user_specified_dimhint_range: Callable,
-    f__broadcast_in_dim_meta: Callable,
-    f__broadcast_shapes: Callable,
-    f__check_input_constraints_for_graph: Callable,
-    f__maybe_broadcast: Callable,
-    f_broadcast_in_dim: Callable,
-    f_infer_size: Callable,
-    f_jit_isinstance: Callable,
-    f_mark_static_address: Callable,
-    f_produce_guards_and_solve_constraints: Callable,
-    f_shape_env__check_frozen: Callable,
-    f_shape_env__evaluate_expr: Callable,
-    f_shape_env__log_guard: Callable,
-    f_shape_env__set_replacement: Callable,
-    f_vmap: Callable,
+    f___constrain_user_specified_dimhint_range: Optional[Callable],
+    f__broadcast_in_dim_meta: Optional[Callable],
+    f__broadcast_shapes: Optional[Callable],
+    f__check_input_constraints_for_graph: Optional[Callable],
+    f__maybe_broadcast: Optional[Callable],
+    f_broadcast_in_dim: Optional[Callable],
+    f_infer_size: Optional[Callable],
+    f_jit_isinstance: Optional[Callable],
+    f_mark_static_address: Optional[Callable],
+    f_produce_guards_and_solve_constraints: Optional[Callable],
+    f_shape_env__check_frozen: Optional[Callable],
+    f_shape_env__evaluate_expr: Optional[Callable],
+    f_shape_env__log_guard: Optional[Callable],
+    f_shape_env__set_replacement: Optional[Callable],
+    f_vmap: Optional[Callable],
 ):
     import torch
     import torch.jit
@@ -467,7 +467,9 @@ def _unpatch_torch(
         print("[torch_export_patches] restored shape constraints")


-def _patch_transformers(verbose: int, patch_details: PatchDetails) -> Tuple[Callable, ...]:
+def _patch_transformers(
+    verbose: int, patch_details: PatchDetails
+) -> Tuple[Optional[Callable], ...]:
     import transformers

     try:
@@ -504,6 +506,12 @@ def _patch_transformers(verbose: int, patch_details: PatchDetails) -> Tuple[Call
             f"sdpa_attention.sdpa_attention_forward={sdpa_attention.sdpa_attention_forward}"
         )

+    f_transformers__vmap_for_bhqkv = None
+    f_transformers_eager_mask = None
+    f_transformers_sdpa_attention_forward = None
+    f_transformers_sdpa_mask = None
+    f_transformers_sdpa_mask_recent_torch = None
+
     if (  # vmap
         masking_utils
         and patch_transformers_list.patch_masking_utils
@@ -649,12 +657,12 @@ def _patch_transformers(verbose: int, patch_details: PatchDetails) -> Tuple[Call
 def _unpatch_transformers(
     verbose: int,
     _patch_details: PatchDetails,
-    f_transformers__vmap_for_bhqkv: Callable,
-    f_transformers_eager_mask: Callable,
-    f_transformers_sdpa_attention_forward: Callable,
-    f_transformers_sdpa_mask: Callable,
-    f_transformers_sdpa_mask_recent_torch: Callable,
-    revert_patches_info: Callable,
+    f_transformers__vmap_for_bhqkv: Optional[Callable],
+    f_transformers_eager_mask: Optional[Callable],
+    f_transformers_sdpa_attention_forward: Optional[Callable],
+    f_transformers_sdpa_mask: Optional[Callable],
+    f_transformers_sdpa_mask_recent_torch: Optional[Callable],
+    revert_patches_info: Optional[Callable],
 ):

     try:
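The pattern behind every hunk above: a `_patch_*` helper may end up not installing a given patch (the sympy hunk already fetches the original with `getattr(..., "name", None)`), so the handle it returns for the original function can be `None`. The signatures therefore change from `Callable` to `Optional[Callable]`, the five `f_transformers_*` handles are pre-initialized to `None` before the conditional patching blocks, and each `_unpatch_*` helper has to treat a `None` handle as "nothing to restore". A minimal sketch of that pattern, using hypothetical names (`_patch_example`, `_unpatch_example`, `target`) that are not part of this repository:

from typing import Any, Callable, Optional

def _patch_example(module: Any, enabled: bool) -> Optional[Callable]:
    # Pre-initialize the handle so it is defined on every code path,
    # mirroring the new `f_transformers_* = None` block above.
    f_original: Optional[Callable] = None
    if enabled and hasattr(module, "target"):
        f_original = module.target  # keep the original so it can be restored
        module.target = lambda *args, **kwargs: f_original(*args, **kwargs)
    return f_original

def _unpatch_example(module: Any, f_original: Optional[Callable]) -> None:
    # A None handle means the patch was never installed: nothing to restore.
    if f_original is None:
        return
    module.target = f_original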
