@@ -24,30 +24,26 @@ def get_function(name: str) -> Tuple["module", "function"]: # noqa: F821
2424@functools .lru_cache
2525def get_patches (mod , verbose : int = 0 ) -> Tuple [str , List [Any ]]:
2626 """Returns the list of patches to make for a specific module."""
27- if isinstance (mod , list ):
28- to_patch = mod
29- name = "list"
30- else :
31- to_patch = []
32- for k in dir (mod ):
33- if k .startswith ("patched_" ):
34- v = getattr (mod , k )
35- if hasattr (v , "_PATCHED_CLASS_" ) and hasattr (v , "_PATCHES_" ):
36- to_patch .append (v )
37- else :
38- # a function
39- doc = v .__doc__
40- if doc .startswith ("manual patch" ):
41- continue
42- reg = re .compile ("[[]patch:([a-z_A-Z.]+)[]]" )
43- fall = reg .findall (doc )
44- assert (
45- len (fall ) == 1
46- ), f"Unable to find patching information for { v } in \n { doc } "
47- fmod , f = get_function (fall [0 ])
48- to_patch .append ({"module" : fmod , "function" : f , "patch" : v })
49-
50- name = mod .__name__
27+ to_patch = []
28+ for k in dir (mod ):
29+ if k .startswith ("patched_" ):
30+ v = getattr (mod , k )
31+ if hasattr (v , "_PATCHED_CLASS_" ) and hasattr (v , "_PATCHES_" ):
32+ to_patch .append (v )
33+ else :
34+ # a function
35+ doc = v .__doc__
36+ if doc .startswith ("manual patch" ):
37+ continue
38+ reg = re .compile ("[[]patch:([a-z_A-Z.]+)[]]" )
39+ fall = reg .findall (doc )
40+ assert (
41+ len (fall ) == 1
42+ ), f"Unable to find patching information for { v } in \n { doc } "
43+ fmod , f = get_function (fall [0 ])
44+ to_patch .append ({"module" : fmod , "function" : f , "patch" : v })
45+
46+ name = mod .__name__
5147 return name , to_patch
5248
5349
@@ -63,7 +59,11 @@ def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Call
6359 :param verbose: verbosity
6460 :return: patch info
6561 """
66- name , to_patch = get_patches (mod , verbose )
62+ if isinstance (mod , list ):
63+ to_patch = mod
64+ name = "list"
65+ else :
66+ name , to_patch = get_patches (mod , verbose )
6767
6868 res = {}
6969 for cls in to_patch :
@@ -98,7 +98,12 @@ def unpatch_module_or_classes(mod, info: Dict[type, Dict[type, Callable]], verbo
9898 :param mod: module of list of clsses to patch
9999 :param verbose: verbosity
100100 """
101- name , to_patch = get_patches (mod , verbose )
101+ if isinstance (mod , list ):
102+ to_patch = mod
103+ name = "list"
104+ else :
105+ name , to_patch = get_patches (mod , verbose )
106+
102107 set_patch_cls = {i for i in to_patch if not isinstance (i , dict )}
103108 dict_patch_fct = {i ["function" ]: i for i in to_patch if isinstance (i , dict )}
104109
0 commit comments