3 changes: 2 additions & 1 deletion CHANGELOGS.rst
@@ -4,7 +4,8 @@ Change Logs
0.3.0
+++++

* :pr:`38`, uses the registered serialization functions when it is available
* :pr:`43`: adds support for custom patches
* :pr:`38`: uses the registered serialization functions when it is available
* :pr:`30`, :pr:`31`: adds command to test a model id, validate the export
* :pr:`29`: adds helpers to measure the memory peak and run benchmark
on different processes
26 changes: 25 additions & 1 deletion _unittests/ut_torch_export_patches/test_patch_base_class.py
@@ -1,5 +1,7 @@
import unittest
from onnx_diagnostic.ext_test_case import ExtTestCase
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors


class TestPatchBaseClass(ExtTestCase):
@@ -52,6 +54,28 @@ def ret(self, a):
self.assertEqual(a.ret(4), 14)
self.assertEqual(a.ok(), 13)

@hide_stdout()
def test_custom_patches(self):
class Model(torch.nn.Module):
def m1(self, x):
return x * x

def forward(self, x):
return self.m1(x)

class patched_Model:
_PATCHED_CLASS_ = Model
_PATCHES_ = ["m1"]

def m1(self, x):
return x**3

model = Model()
x = torch.arange(4)
self.assertEqualArray(x * x, model(x))
with bypass_export_some_errors(custom_patches=[patched_Model], verbose=10):
self.assertEqualArray(x**3, model(x))


if __name__ == "__main__":
unittest.main(verbosity=2)
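
The new test above exercises the patching mechanics but never calls the exporter itself. As a rough sketch of how the new ``custom_patches`` argument is meant to be used in practice (assuming a recent ``torch`` with ``torch.export.export``; the ``Model``/``patched_Model`` names simply mirror the test above):

import torch
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors

class Model(torch.nn.Module):
    def m1(self, x):
        return x * x

    def forward(self, x):
        return self.m1(x)

class patched_Model:
    _PATCHED_CLASS_ = Model  # class whose methods get replaced
    _PATCHES_ = ["m1"]       # names of the methods to swap in

    def m1(self, x):
        return x**3

model, x = Model(), torch.arange(4)
with bypass_export_some_errors(custom_patches=[patched_Model]):
    # while the context manager is active, Model.m1 is patched_Model.m1,
    # so the exported graph computes x**3 instead of x * x
    ep = torch.export.export(model, (x,))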
86 changes: 64 additions & 22 deletions onnx_diagnostic/torch_export_patches/onnx_export_errors.py
@@ -1,6 +1,6 @@
import contextlib
import pprint
from typing import Any, Callable, Dict, Set
from typing import Any, Callable, Dict, List, Optional, Set
from .onnx_export_serialization import (
flatten_with_keys_dynamic_cache,
flatten_dynamic_cache,
@@ -12,27 +12,36 @@
from .patches import patch_transformers as patch_transformers_list


def patch_module(mod, verbose: int = 0) -> Dict[type, Dict[type, Callable]]:
def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Callable]]:
"""
Applies all patches defined in classes prefixed by ``patched_``.
``cls._PATCHED_CLASS_`` defines the class to patch,
``cls._PATCHES_`` defines the methods to patch.
The returns information needs to be sent to :func:`unpatch_module`
The returned information needs to be sent to :func:`unpatch_module_or_classes`
to revert the changes.
:param mod: module or list of classes to patch
:param verbose: verbosity
:return: patch info
"""
to_patch = []
for k in dir(mod):
if k.startswith("patched_"):
v = getattr(mod, k)
if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
to_patch.append(v)
if isinstance(mod, list):
to_patch = mod
name = "list"
else:
to_patch = []
for k in dir(mod):
if k.startswith("patched_"):
v = getattr(mod, k)
if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
to_patch.append(v)
name = mod.__name__

res = {}
for cls in to_patch:
original = cls._PATCHED_CLASS_
methods = cls._PATCHES_
if verbose:
print(f"[patch_module] {mod.__name__} - {cls.__name__}: {', '.join(methods)}")
print(f"[patch_module_or_classes] {name} - {cls.__name__}: {', '.join(methods)}")

keep = {n: getattr(original, n, None) for n in methods}
for n in methods:
Expand All @@ -42,20 +51,30 @@ def patch_module(mod, verbose: int = 0) -> Dict[type, Dict[type, Callable]]:
return res


def unpatch_module(mod, info: Dict[type, Dict[type, Callable]], verbose: int = 0):
"""Reverts modification made by :func:`patch_module`."""
to_patch = []
for k in dir(mod):
if k.startswith("patched_"):
v = getattr(mod, k)
if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
to_patch.append(v)
def unpatch_module_or_classes(mod, info: Dict[type, Dict[type, Callable]], verbose: int = 0):
"""
Reverts the modifications made by :func:`patch_module_or_classes`.
:param mod: module or list of classes to patch
:param verbose: verbosity
"""
if isinstance(mod, list):
to_patch = mod
name = "list"
else:
to_patch = []
for k in dir(mod):
if k.startswith("patched_"):
v = getattr(mod, k)
if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
to_patch.append(v)
name = mod.__name__
set_patch = set(to_patch)

for cls, methods in info.items():
assert cls in set_patch, f"No patch registered for {cls} in {mod} (found {set_patch})"
if verbose:
print(f"[unpatch_module] {mod.__name__} - {cls.__name__}: {', '.join(methods)}")
print(f"[unpatch_module_or_classes] {name} - {cls.__name__}: {', '.join(methods)}")
original = cls._PATCHED_CLASS_
for n, v in methods.items():
if v is None:
@@ -237,6 +256,7 @@ def bypass_export_some_errors(
stop_if_static: int = 0,
verbose: int = 0,
patch: bool = True,
custom_patches: Optional[List[type["torch.nn.Module"]]] = None, # noqa: F821
) -> Callable:
"""
Tries to bypass some situations :func:`torch.export.export` does not support.
@@ -255,6 +275,9 @@ def bypass_export_some_errors(
issues
:param patch: if False, disable all patches except the registration of
serialization function
:param custom_patches: to apply custom patches,
every patched class must define the static attributes
``_PATCHES_`` and ``_PATCHED_CLASS_``
:param verbose: to show which patches are applied
The list of available patches.
@@ -433,7 +456,16 @@ def bypass_export_some_errors(
f"[bypass_export_some_errors] transformers.__version__="
f"{transformers.__version__!r}"
)
revert_patches_info = patch_module(patch_transformers_list, verbose=verbose)
revert_patches_info = patch_module_or_classes(
patch_transformers_list, verbose=verbose
)

if custom_patches:
if verbose:
print("[bypass_export_some_errors] applies custom patches")
revert_custom_patches_info = patch_module_or_classes(
custom_patches, verbose=verbose
)

########
# export
@@ -455,7 +487,6 @@ def bypass_export_some_errors(
print("[bypass_export_some_errors] remove patches")

if patch_sympy:

# tracked by https://github.com/pytorch/pytorch/issues/143494
if f_sympy_name:
sympy.core.numbers.IntegerConstant.name = f_sympy_name
@@ -502,12 +533,23 @@ def bypass_export_some_errors(
if verbose:
print("[bypass_export_some_errors] restored shape constraints")

if custom_patches:
if verbose:
print("[bypass_export_some_errors] unpatch custom patches")
unpatch_module_or_classes(
custom_patches, revert_custom_patches_info, verbose=verbose
)

##############
# transformers
##############

if patch_transformers:
unpatch_module(patch_transformers_list, revert_patches_info, verbose=verbose)
if verbose:
print("[bypass_export_some_errors] unpatch transformers")
unpatch_module_or_classes(
patch_transformers_list, revert_patches_info, verbose=verbose
)

########
# caches
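
For completeness, the two renamed helpers can also be driven directly with a list of patch classes, which is the new code path introduced in this PR. A minimal sketch, assuming ``patch_module_or_classes`` and ``unpatch_module_or_classes`` are importable from ``onnx_diagnostic.torch_export_patches.onnx_export_errors`` (the module changed above); ``Counter`` and ``patched_Counter`` are illustrative names only:

from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
    patch_module_or_classes,
    unpatch_module_or_classes,
)

class Counter:
    def step(self, x):
        return x + 1

class patched_Counter:
    _PATCHED_CLASS_ = Counter  # class to patch
    _PATCHES_ = ["step"]       # methods to replace

    def step(self, x):
        return x + 10

c = Counter()
info = patch_module_or_classes([patched_Counter])  # a list of classes, not a module
try:
    assert c.step(1) == 11  # patched method is active
finally:
    unpatch_module_or_classes([patched_Counter], info)
assert c.step(1) == 2  # original method restored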