1+ import functools
2+ import importlib
13import contextlib
2- from typing import Any , Callable , Dict , List , Optional
4+ import re
5+ from typing import Any , Callable , Dict , List , Optional , Tuple
36from .onnx_export_serialization import (
47 register_cache_serialization ,
58 unregister_cache_serialization ,
69)
710from .patches import patch_transformers as patch_transformers_list
811
912
13+ def get_function (name : str ) -> Tuple [type , Callable ]:
14+ """Returns the module and the function based on its name."""
15+ spl = name .split ("." )
16+ module_name = "." .join (spl [:- 1 ])
17+ fname = spl [- 1 ]
18+ mod = importlib .import_module (module_name )
19+ return mod , getattr (mod , fname )
20+
21+
22+ @functools .lru_cache
23+ def get_patches (mod , verbose : int = 0 ) -> Tuple [str , List [Any ]]:
24+ """Returns the list of patches to make for a specific module."""
25+ to_patch = []
26+ for k in dir (mod ):
27+ if k .startswith ("patched_" ):
28+ v = getattr (mod , k )
29+ if hasattr (v , "_PATCHED_CLASS_" ) and hasattr (v , "_PATCHES_" ):
30+ to_patch .append (v )
31+ else :
32+ # a function
33+ doc = v .__doc__ .lstrip ()
34+ if doc .startswith ("manual patch" ):
35+ continue
36+ reg = re .compile ("[[]patch:([a-z_A-Z.]+)[]]" )
37+ fall = reg .findall (doc )
38+ assert (
39+ len (fall ) == 1
40+ ), f"Unable to find patching information for { v } in \n { doc } "
41+ fmod , f = get_function (fall [0 ])
42+ to_patch .append ({"module" : fmod , "function" : f , "patch" : v })
43+
44+ name = mod .__name__
45+ return name , to_patch
46+
47+
1048def patch_module_or_classes (mod , verbose : int = 0 ) -> Dict [type , Dict [type , Callable ]]:
1149 """
1250 Applies all patches defined in classes prefixed by ``patched_``
@@ -23,16 +61,21 @@ def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Call
2361 to_patch = mod
2462 name = "list"
2563 else :
26- to_patch = []
27- for k in dir (mod ):
28- if k .startswith ("patched_" ):
29- v = getattr (mod , k )
30- if hasattr (v , "_PATCHED_CLASS_" ) and hasattr (v , "_PATCHES_" ):
31- to_patch .append (v )
32- name = mod .__name__
64+ name , to_patch = get_patches (mod , verbose )
3365
3466 res = {}
3567 for cls in to_patch :
68+ if isinstance (cls , dict ):
69+ # a function
70+ keep = {}
71+ original = cls ["module" ]
72+ f = cls ["function" ]
73+ res [f ] = f
74+ if verbose :
75+ print (f"[patch_module_or_classes] function: { original .__name__ } .{ f .__name__ } " )
76+ setattr (original , f .__name__ , cls ["patch" ])
77+ continue
78+
3679 original = cls ._PATCHED_CLASS_
3780 methods = cls ._PATCHES_
3881 if verbose :
@@ -57,26 +100,36 @@ def unpatch_module_or_classes(mod, info: Dict[type, Dict[type, Callable]], verbo
57100 to_patch = mod
58101 name = "list"
59102 else :
60- to_patch = []
61- for k in dir (mod ):
62- if k .startswith ("patched_" ):
63- v = getattr (mod , k )
64- if hasattr (v , "_PATCHED_CLASS_" ) and hasattr (v , "_PATCHES_" ):
65- to_patch .append (v )
66- name = mod .__name__
67- set_patch = set (to_patch )
103+ name , to_patch = get_patches (mod , verbose )
104+
105+ set_patch_cls = {i for i in to_patch if not isinstance (i , dict )}
106+ dict_patch_fct = {i ["function" ]: i for i in to_patch if isinstance (i , dict )}
68107
69108 for cls , methods in info .items ():
70- assert cls in set_patch , f"No patch registered for { cls } in { mod } (found { set_patch } )"
109+ if cls in set_patch_cls :
110+ if verbose :
111+ print (
112+ f"[unpatch_module_or_classes] { name } .{ cls .__name__ } : { ', ' .join (methods )} "
113+ )
114+ original = cls ._PATCHED_CLASS_
115+ for n , v in methods .items ():
116+ if v is None :
117+ # The method did not exist. We remove it.
118+ delattr (original , n )
119+ else :
120+ setattr (original , n , v )
121+ continue
122+ assert cls in dict_patch_fct , (
123+ f"No patch registered for { cls } in { mod } "
124+ f"(found { set_patch_cls } and { set (dict_patch_fct )} )"
125+ )
126+ patch = dict_patch_fct [cls ]
71127 if verbose :
72- print (f"[unpatch_module_or_classes] { name } .{ cls .__name__ } : { ', ' .join (methods )} " )
73- original = cls ._PATCHED_CLASS_
74- for n , v in methods .items ():
75- if v is None :
76- # The method did not exist. We remove it.
77- delattr (original , n )
78- else :
79- setattr (original , n , v )
128+ print (
129+ f"[unpatch_module_or_classes] function "
130+ f"{ patch ['module' ].__name__ } .{ cls .__name__ } "
131+ )
132+ setattr (patch ["module" ], cls .__name__ , patch ["function" ])
80133
81134
82135@contextlib .contextmanager
0 commit comments