11import contextlib
22import pprint
3- from typing import Any , Callable , Dict , Set
3+ from typing import Any , Callable , Dict , List , Optional , Set
44from .onnx_export_serialization import (
55 flatten_with_keys_dynamic_cache ,
66 flatten_dynamic_cache ,
1212from .patches import patch_transformers as patch_transformers_list
1313
1414
15- def patch_module (mod , verbose : int = 0 ) -> Dict [type , Dict [type , Callable ]]:
15+ def patch_module_or_classes (mod , verbose : int = 0 ) -> Dict [type , Dict [type , Callable ]]:
1616 """
1717 Applies all patches defined in classes prefixed by ``patched_``
1818 ``cls._PATCHED_CLASS_`` defines the class to patch,
1919 ``cls._PATCHES_`` defines the method to patch.
20- The returns information needs to be sent to :func:`unpatch_module `
20+ The returns information needs to be sent to :func:`unpatch_module_or_classes `
2121 to revert the changes.
22+
23+ :param mod: module of list of clsses to patch
24+ :param verbose: verbosity
25+ :return: patch info
2226 """
23- to_patch = []
24- for k in dir (mod ):
25- if k .startswith ("patched_" ):
26- v = getattr (mod , k )
27- if hasattr (v , "_PATCHED_CLASS_" ) and hasattr (v , "_PATCHES_" ):
28- to_patch .append (v )
27+ if isinstance (mod , list ):
28+ to_patch = mod
29+ name = "list"
30+ else :
31+ to_patch = []
32+ for k in dir (mod ):
33+ if k .startswith ("patched_" ):
34+ v = getattr (mod , k )
35+ if hasattr (v , "_PATCHED_CLASS_" ) and hasattr (v , "_PATCHES_" ):
36+ to_patch .append (v )
37+ name = mod .__name__
2938
3039 res = {}
3140 for cls in to_patch :
3241 original = cls ._PATCHED_CLASS_
3342 methods = cls ._PATCHES_
3443 if verbose :
35- print (f"[patch_module ] { mod . __name__ } - { cls .__name__ } : { ', ' .join (methods )} " )
44+ print (f"[patch_module_or_classes ] { name } - { cls .__name__ } : { ', ' .join (methods )} " )
3645
3746 keep = {n : getattr (original , n , None ) for n in methods }
3847 for n in methods :
@@ -42,20 +51,30 @@ def patch_module(mod, verbose: int = 0) -> Dict[type, Dict[type, Callable]]:
4251 return res
4352
4453
45- def unpatch_module (mod , info : Dict [type , Dict [type , Callable ]], verbose : int = 0 ):
46- """Reverts modification made by :func:`patch_module`."""
47- to_patch = []
48- for k in dir (mod ):
49- if k .startswith ("patched_" ):
50- v = getattr (mod , k )
51- if hasattr (v , "_PATCHED_CLASS_" ) and hasattr (v , "_PATCHES_" ):
52- to_patch .append (v )
54+ def unpatch_module_or_classes (mod , info : Dict [type , Dict [type , Callable ]], verbose : int = 0 ):
55+ """
56+ Reverts modification made by :func:`patch_module_or_classes`.
57+
58+ :param mod: module of list of clsses to patch
59+ :param verbose: verbosity
60+ """
61+ if isinstance (mod , list ):
62+ to_patch = mod
63+ name = "list"
64+ else :
65+ to_patch = []
66+ for k in dir (mod ):
67+ if k .startswith ("patched_" ):
68+ v = getattr (mod , k )
69+ if hasattr (v , "_PATCHED_CLASS_" ) and hasattr (v , "_PATCHES_" ):
70+ to_patch .append (v )
71+ name = mod .__name__
5372 set_patch = set (to_patch )
5473
5574 for cls , methods in info .items ():
5675 assert cls in set_patch , f"No patch registered for { cls } in { mod } (found { set_patch } )"
5776 if verbose :
58- print (f"[unpatch_module ] { mod . __name__ } - { cls .__name__ } : { ', ' .join (methods )} " )
77+ print (f"[unpatch_module_or_classes ] { name } - { cls .__name__ } : { ', ' .join (methods )} " )
5978 original = cls ._PATCHED_CLASS_
6079 for n , v in methods .items ():
6180 if v is None :
@@ -237,6 +256,7 @@ def bypass_export_some_errors(
237256 stop_if_static : int = 0 ,
238257 verbose : int = 0 ,
239258 patch : bool = True ,
259+ custom_patches : Optional [List [type ["torch.nn.Module" ]]] = None , # noqa: F821
240260) -> Callable :
241261 """
242262 Tries to bypass some situations :func:`torch.export.export` does not support.
@@ -255,6 +275,9 @@ def bypass_export_some_errors(
255275 issues
256276 :param patch: if False, disable all patches except the registration of
257277 serialization function
278+ :param custom_patches: to apply custom patches,
279+ every patched class must define static attributes
280+ ``_PATCHES_``, ``_PATCHED_CLASS_``
258281 :param verbose: to show which patches is applied
259282
260283 The list of available patches.
@@ -433,7 +456,16 @@ def bypass_export_some_errors(
433456 f"[bypass_export_some_errors] transformers.__version__="
434457 f"{ transformers .__version__ !r} "
435458 )
436- revert_patches_info = patch_module (patch_transformers_list , verbose = verbose )
459+ revert_patches_info = patch_module_or_classes (
460+ patch_transformers_list , verbose = verbose
461+ )
462+
463+ if custom_patches :
464+ if verbose :
465+ print ("[bypass_export_some_errors] applies custom patches" )
466+ revert_custom_patches_info = patch_module_or_classes (
467+ custom_patches , verbose = verbose
468+ )
437469
438470 ########
439471 # export
@@ -455,7 +487,6 @@ def bypass_export_some_errors(
455487 print ("[bypass_export_some_errors] remove patches" )
456488
457489 if patch_sympy :
458-
459490 # tracked by https://github.com/pytorch/pytorch/issues/143494
460491 if f_sympy_name :
461492 sympy .core .numbers .IntegerConstant .name = f_sympy_name
@@ -502,12 +533,23 @@ def bypass_export_some_errors(
502533 if verbose :
503534 print ("[bypass_export_some_errors] restored shape constraints" )
504535
536+ if custom_patches :
537+ if verbose :
538+ print ("[bypass_export_some_errors] unpatch custom patches" )
539+ unpatch_module_or_classes (
540+ custom_patches , revert_custom_patches_info , verbose = verbose
541+ )
542+
505543 ##############
506544 # transformers
507545 ##############
508546
509547 if patch_transformers :
510- unpatch_module (patch_transformers_list , revert_patches_info , verbose = verbose )
548+ if verbose :
549+ print ("[bypass_export_some_errors] unpatch transformers" )
550+ unpatch_module_or_classes (
551+ patch_transformers_list , revert_patches_info , verbose = verbose
552+ )
511553
512554 ########
513555 # caches
0 commit comments