Skip to content

Commit 905cd25

Browse files
committed
custom patch
1 parent b187624 commit 905cd25

File tree

1 file changed

+31
-26
lines changed

1 file changed

+31
-26
lines changed

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 31 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -24,30 +24,26 @@ def get_function(name: str) -> Tuple["module", "function"]: # noqa: F821
2424
@functools.lru_cache
2525
def get_patches(mod, verbose: int = 0) -> Tuple[str, List[Any]]:
2626
"""Returns the list of patches to make for a specific module."""
27-
if isinstance(mod, list):
28-
to_patch = mod
29-
name = "list"
30-
else:
31-
to_patch = []
32-
for k in dir(mod):
33-
if k.startswith("patched_"):
34-
v = getattr(mod, k)
35-
if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
36-
to_patch.append(v)
37-
else:
38-
# a function
39-
doc = v.__doc__
40-
if doc.startswith("manual patch"):
41-
continue
42-
reg = re.compile("[[]patch:([a-z_A-Z.]+)[]]")
43-
fall = reg.findall(doc)
44-
assert (
45-
len(fall) == 1
46-
), f"Unable to find patching information for {v} in \n{doc}"
47-
fmod, f = get_function(fall[0])
48-
to_patch.append({"module": fmod, "function": f, "patch": v})
49-
50-
name = mod.__name__
27+
to_patch = []
28+
for k in dir(mod):
29+
if k.startswith("patched_"):
30+
v = getattr(mod, k)
31+
if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
32+
to_patch.append(v)
33+
else:
34+
# a function
35+
doc = v.__doc__
36+
if doc.startswith("manual patch"):
37+
continue
38+
reg = re.compile("[[]patch:([a-z_A-Z.]+)[]]")
39+
fall = reg.findall(doc)
40+
assert (
41+
len(fall) == 1
42+
), f"Unable to find patching information for {v} in \n{doc}"
43+
fmod, f = get_function(fall[0])
44+
to_patch.append({"module": fmod, "function": f, "patch": v})
45+
46+
name = mod.__name__
5147
return name, to_patch
5248

5349

@@ -63,7 +59,11 @@ def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Call
6359
:param verbose: verbosity
6460
:return: patch info
6561
"""
66-
name, to_patch = get_patches(mod, verbose)
62+
if isinstance(mod, list):
63+
to_patch = mod
64+
name = "list"
65+
else:
66+
name, to_patch = get_patches(mod, verbose)
6767

6868
res = {}
6969
for cls in to_patch:
@@ -98,7 +98,12 @@ def unpatch_module_or_classes(mod, info: Dict[type, Dict[type, Callable]], verbo
9898
:param mod: module of list of clsses to patch
9999
:param verbose: verbosity
100100
"""
101-
name, to_patch = get_patches(mod, verbose)
101+
if isinstance(mod, list):
102+
to_patch = mod
103+
name = "list"
104+
else:
105+
name, to_patch = get_patches(mod, verbose)
106+
102107
set_patch_cls = {i for i in to_patch if not isinstance(i, dict)}
103108
dict_patch_fct = {i["function"]: i for i in to_patch if isinstance(i, dict)}
104109

0 commit comments

Comments
 (0)