1+ import functools
2+ import importlib
13import contextlib
2- from typing import Any , Callable , Dict , List , Optional
4+ import re
5+ from typing import Any , Callable , Dict , List , Optional , Tuple
36from .onnx_export_serialization import (
47 register_cache_serialization ,
58 unregister_cache_serialization ,
69)
710from .patches import patch_transformers as patch_transformers_list
811
912
10- def patch_module_or_classes ( mod , verbose : int = 0 ) -> Dict [ type , Dict [ type , Callable ]]:
13+ def get_function ( name : str ) -> Tuple [ "module" , "function" ]: # noqa: F821
1114 """
12- Applies all patches defined in classes prefixed by ``patched_``
13- ``cls._PATCHED_CLASS_`` defines the class to patch,
14- ``cls._PATCHES_`` defines the method to patch.
15- The returns information needs to be sent to :func:`unpatch_module_or_classes`
16- to revert the changes.
17-
18- :param mod: module of list of clsses to patch
19- :param verbose: verbosity
20- :return: patch info
15+ Returns the module and the function based on its name.
2116 """
17+ spl = name .split ("." )
18+ module_name = "." .join (spl [:- 1 ])
19+ fname = spl [- 1 ]
20+ mod = importlib .import_module (module_name )
21+ return mod , getattr (mod , fname )
22+
23+
24+ @functools .lru_cache
25+ def get_patches (mod , verbose : int = 0 ) -> Tuple [str , List [Any ]]:
26+ """Returns the list of patches to make for a specific module."""
2227 if isinstance (mod , list ):
2328 to_patch = mod
2429 name = "list"
@@ -29,10 +34,50 @@ def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Call
2934 v = getattr (mod , k )
3035 if hasattr (v , "_PATCHED_CLASS_" ) and hasattr (v , "_PATCHES_" ):
3136 to_patch .append (v )
37+ else :
38+ # a function
39+ doc = v .__doc__
40+ if doc .startswith ("manual patch" ):
41+ continue
42+ reg = re .compile ("[[]patch:([a-z_A-Z.]+)[]]" )
43+ fall = reg .findall (doc )
44+ assert (
45+ len (fall ) == 1
46+ ), f"Unable to find patching information for { v } in \n { doc } "
47+ fmod , f = get_function (fall [0 ])
48+ to_patch .append ({"module" : fmod , "function" : f , "patch" : v })
49+
3250 name = mod .__name__
51+ return name , to_patch
52+
53+
54+ def patch_module_or_classes (mod , verbose : int = 0 ) -> Dict [type , Dict [type , Callable ]]:
55+ """
56+ Applies all patches defined in classes prefixed by ``patched_``
57+ ``cls._PATCHED_CLASS_`` defines the class to patch,
58+ ``cls._PATCHES_`` defines the method to patch.
59+ The returns information needs to be sent to :func:`unpatch_module_or_classes`
60+ to revert the changes.
61+
62+ :param mod: module of list of clsses to patch
63+ :param verbose: verbosity
64+ :return: patch info
65+ """
66+ name , to_patch = get_patches (mod , verbose )
3367
3468 res = {}
3569 for cls in to_patch :
70+ if isinstance (cls , dict ):
71+ # a function
72+ keep = {}
73+ original = cls ["module" ]
74+ f = cls ["function" ]
75+ res [f ] = f
76+ if verbose :
77+ print (f"[patch_module_or_classes] function: { original .__name__ } .{ f .__name__ } " )
78+ setattr (original , f .__name__ , cls ["patch" ])
79+ continue
80+
3681 original = cls ._PATCHED_CLASS_
3782 methods = cls ._PATCHES_
3883 if verbose :
@@ -53,30 +98,35 @@ def unpatch_module_or_classes(mod, info: Dict[type, Dict[type, Callable]], verbo
5398 :param mod: module of list of clsses to patch
5499 :param verbose: verbosity
55100 """
56- if isinstance (mod , list ):
57- to_patch = mod
58- name = "list"
59- else :
60- to_patch = []
61- for k in dir (mod ):
62- if k .startswith ("patched_" ):
63- v = getattr (mod , k )
64- if hasattr (v , "_PATCHED_CLASS_" ) and hasattr (v , "_PATCHES_" ):
65- to_patch .append (v )
66- name = mod .__name__
67- set_patch = set (to_patch )
101+ name , to_patch = get_patches (mod , verbose )
102+ set_patch_cls = {i for i in to_patch if not isinstance (i , dict )}
103+ dict_patch_fct = {i ["function" ]: i for i in to_patch if isinstance (i , dict )}
68104
69105 for cls , methods in info .items ():
70- assert cls in set_patch , f"No patch registered for { cls } in { mod } (found { set_patch } )"
106+ if cls in set_patch_cls :
107+ if verbose :
108+ print (
109+ f"[unpatch_module_or_classes] { name } .{ cls .__name__ } : { ', ' .join (methods )} "
110+ )
111+ original = cls ._PATCHED_CLASS_
112+ for n , v in methods .items ():
113+ if v is None :
114+ # The method did not exist. We remove it.
115+ delattr (original , n )
116+ else :
117+ setattr (original , n , v )
118+ continue
119+ assert cls in dict_patch_fct , (
120+ f"No patch registered for { cls } in { mod } "
121+ f"(found { set_patch_cls } and { set (dict_patch_fct )} )"
122+ )
123+ patch = dict_patch_fct [cls ]
71124 if verbose :
72- print (f"[unpatch_module_or_classes] { name } .{ cls .__name__ } : { ', ' .join (methods )} " )
73- original = cls ._PATCHED_CLASS_
74- for n , v in methods .items ():
75- if v is None :
76- # The method did not exist. We remove it.
77- delattr (original , n )
78- else :
79- setattr (original , n , v )
125+ print (
126+ f"[unpatch_module_or_classes] function "
127+ f"{ patch ['module' ].__name__ } .{ cls .__name__ } "
128+ )
129+ setattr (patch ["module" ], cls .__name__ , patch ["function" ])
80130
81131
82132@contextlib .contextmanager
0 commit comments