Skip to content

Commit 84817ab

Browse files
authored
Support custom patches (#43)
* custom patches * CHANGELOGS.rst
1 parent 80eab86 commit 84817ab

File tree

3 files changed

+91
-24
lines changed

3 files changed

+91
-24
lines changed

CHANGELOGS.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ Change Logs
44
0.3.0
55
+++++
66

7-
* :pr:`38`, uses the registered serialization functions when it is available
7+
* :pr:`43`: uses custom patches
8+
* :pr:`38`: uses the registered serialization functions when it is available
89
* :pr:`30`, :pr:`31`: adds command to test a model id, validate the export
910
* :pr:`29`: adds helpers to measure the memory peak and run benchmark
1011
on different processes

_unittests/ut_torch_export_patches/test_patch_base_class.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import unittest
2-
from onnx_diagnostic.ext_test_case import ExtTestCase
2+
import torch
3+
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
4+
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
35

46

57
class TestPatchBaseClass(ExtTestCase):
@@ -52,6 +54,28 @@ def ret(self, a):
5254
self.assertEqual(a.ret(4), 14)
5355
self.assertEqual(a.ok(), 13)
5456

57+
@hide_stdout()
58+
def test_custom_patches(self):
59+
class Model(torch.nn.Module):
60+
def m1(self, x):
61+
return x * x
62+
63+
def forward(self, x):
64+
return self.m1(x)
65+
66+
class patched_Model:
67+
_PATCHED_CLASS_ = Model
68+
_PATCHES_ = ["m1"]
69+
70+
def m1(self, x):
71+
return x**3
72+
73+
model = Model()
74+
x = torch.arange(4)
75+
self.assertEqualArray(x * x, model(x))
76+
with bypass_export_some_errors(custom_patches=[patched_Model], verbose=10):
77+
self.assertEqualArray(x**3, model(x))
78+
5579

5680
if __name__ == "__main__":
5781
unittest.main(verbosity=2)

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 64 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import contextlib
22
import pprint
3-
from typing import Any, Callable, Dict, Set
3+
from typing import Any, Callable, Dict, List, Optional, Set
44
from .onnx_export_serialization import (
55
flatten_with_keys_dynamic_cache,
66
flatten_dynamic_cache,
@@ -12,27 +12,36 @@
1212
from .patches import patch_transformers as patch_transformers_list
1313

1414

15-
def patch_module(mod, verbose: int = 0) -> Dict[type, Dict[type, Callable]]:
15+
def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Callable]]:
1616
"""
1717
Applies all patches defined in classes prefixed by ``patched_``
1818
``cls._PATCHED_CLASS_`` defines the class to patch,
1919
``cls._PATCHES_`` defines the method to patch.
20-
The returns information needs to be sent to :func:`unpatch_module`
20+
The returns information needs to be sent to :func:`unpatch_module_or_classes`
2121
to revert the changes.
22+
23+
:param mod: module of list of clsses to patch
24+
:param verbose: verbosity
25+
:return: patch info
2226
"""
23-
to_patch = []
24-
for k in dir(mod):
25-
if k.startswith("patched_"):
26-
v = getattr(mod, k)
27-
if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
28-
to_patch.append(v)
27+
if isinstance(mod, list):
28+
to_patch = mod
29+
name = "list"
30+
else:
31+
to_patch = []
32+
for k in dir(mod):
33+
if k.startswith("patched_"):
34+
v = getattr(mod, k)
35+
if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
36+
to_patch.append(v)
37+
name = mod.__name__
2938

3039
res = {}
3140
for cls in to_patch:
3241
original = cls._PATCHED_CLASS_
3342
methods = cls._PATCHES_
3443
if verbose:
35-
print(f"[patch_module] {mod.__name__} - {cls.__name__}: {', '.join(methods)}")
44+
print(f"[patch_module_or_classes] {name} - {cls.__name__}: {', '.join(methods)}")
3645

3746
keep = {n: getattr(original, n, None) for n in methods}
3847
for n in methods:
@@ -42,20 +51,30 @@ def patch_module(mod, verbose: int = 0) -> Dict[type, Dict[type, Callable]]:
4251
return res
4352

4453

45-
def unpatch_module(mod, info: Dict[type, Dict[type, Callable]], verbose: int = 0):
46-
"""Reverts modification made by :func:`patch_module`."""
47-
to_patch = []
48-
for k in dir(mod):
49-
if k.startswith("patched_"):
50-
v = getattr(mod, k)
51-
if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
52-
to_patch.append(v)
54+
def unpatch_module_or_classes(mod, info: Dict[type, Dict[type, Callable]], verbose: int = 0):
55+
"""
56+
Reverts modification made by :func:`patch_module_or_classes`.
57+
58+
:param mod: module of list of clsses to patch
59+
:param verbose: verbosity
60+
"""
61+
if isinstance(mod, list):
62+
to_patch = mod
63+
name = "list"
64+
else:
65+
to_patch = []
66+
for k in dir(mod):
67+
if k.startswith("patched_"):
68+
v = getattr(mod, k)
69+
if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
70+
to_patch.append(v)
71+
name = mod.__name__
5372
set_patch = set(to_patch)
5473

5574
for cls, methods in info.items():
5675
assert cls in set_patch, f"No patch registered for {cls} in {mod} (found {set_patch})"
5776
if verbose:
58-
print(f"[unpatch_module] {mod.__name__} - {cls.__name__}: {', '.join(methods)}")
77+
print(f"[unpatch_module_or_classes] {name} - {cls.__name__}: {', '.join(methods)}")
5978
original = cls._PATCHED_CLASS_
6079
for n, v in methods.items():
6180
if v is None:
@@ -237,6 +256,7 @@ def bypass_export_some_errors(
237256
stop_if_static: int = 0,
238257
verbose: int = 0,
239258
patch: bool = True,
259+
custom_patches: Optional[List[type["torch.nn.Module"]]] = None, # noqa: F821
240260
) -> Callable:
241261
"""
242262
Tries to bypass some situations :func:`torch.export.export` does not support.
@@ -255,6 +275,9 @@ def bypass_export_some_errors(
255275
issues
256276
:param patch: if False, disable all patches except the registration of
257277
serialization function
278+
:param custom_patches: to apply custom patches,
279+
every patched class must define static attributes
280+
``_PATCHES_``, ``_PATCHED_CLASS_``
258281
:param verbose: to show which patches is applied
259282
260283
The list of available patches.
@@ -433,7 +456,16 @@ def bypass_export_some_errors(
433456
f"[bypass_export_some_errors] transformers.__version__="
434457
f"{transformers.__version__!r}"
435458
)
436-
revert_patches_info = patch_module(patch_transformers_list, verbose=verbose)
459+
revert_patches_info = patch_module_or_classes(
460+
patch_transformers_list, verbose=verbose
461+
)
462+
463+
if custom_patches:
464+
if verbose:
465+
print("[bypass_export_some_errors] applies custom patches")
466+
revert_custom_patches_info = patch_module_or_classes(
467+
custom_patches, verbose=verbose
468+
)
437469

438470
########
439471
# export
@@ -455,7 +487,6 @@ def bypass_export_some_errors(
455487
print("[bypass_export_some_errors] remove patches")
456488

457489
if patch_sympy:
458-
459490
# tracked by https://github.com/pytorch/pytorch/issues/143494
460491
if f_sympy_name:
461492
sympy.core.numbers.IntegerConstant.name = f_sympy_name
@@ -502,12 +533,23 @@ def bypass_export_some_errors(
502533
if verbose:
503534
print("[bypass_export_some_errors] restored shape constraints")
504535

536+
if custom_patches:
537+
if verbose:
538+
print("[bypass_export_some_errors] unpatch custom patches")
539+
unpatch_module_or_classes(
540+
custom_patches, revert_custom_patches_info, verbose=verbose
541+
)
542+
505543
##############
506544
# transformers
507545
##############
508546

509547
if patch_transformers:
510-
unpatch_module(patch_transformers_list, revert_patches_info, verbose=verbose)
548+
if verbose:
549+
print("[bypass_export_some_errors] unpatch transformers")
550+
unpatch_module_or_classes(
551+
patch_transformers_list, revert_patches_info, verbose=verbose
552+
)
511553

512554
########
513555
# caches

0 commit comments

Comments
 (0)