diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index f145e583..4388d69e 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,7 +4,8 @@ Change Logs 0.3.0 +++++ -* :pr:`38`, uses the registered serialization functions when it is available +* :pr:`43`: uses custom patches +* :pr:`38`: uses the registered serialization functions when it is available * :pr:`30`, :pr:`31`: adds command to test a model id, validate the export * :pr:`29`: adds helpers to measure the memory peak and run benchmark on different processes diff --git a/_unittests/ut_torch_export_patches/test_patch_base_class.py b/_unittests/ut_torch_export_patches/test_patch_base_class.py index a2c0d91a..7dc4c708 100644 --- a/_unittests/ut_torch_export_patches/test_patch_base_class.py +++ b/_unittests/ut_torch_export_patches/test_patch_base_class.py @@ -1,5 +1,7 @@ import unittest -from onnx_diagnostic.ext_test_case import ExtTestCase +import torch +from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout +from onnx_diagnostic.torch_export_patches import bypass_export_some_errors class TestPatchBaseClass(ExtTestCase): @@ -52,6 +54,28 @@ def ret(self, a): self.assertEqual(a.ret(4), 14) self.assertEqual(a.ok(), 13) + @hide_stdout() + def test_custom_patches(self): + class Model(torch.nn.Module): + def m1(self, x): + return x * x + + def forward(self, x): + return self.m1(x) + + class patched_Model: + _PATCHED_CLASS_ = Model + _PATCHES_ = ["m1"] + + def m1(self, x): + return x**3 + + model = Model() + x = torch.arange(4) + self.assertEqualArray(x * x, model(x)) + with bypass_export_some_errors(custom_patches=[patched_Model], verbose=10): + self.assertEqualArray(x**3, model(x)) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py index 67137994..9964a4da 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py @@ -1,6 +1,6 @@ import contextlib import pprint -from typing import Any, Callable, Dict, Set +from typing import Any, Callable, Dict, List, Optional, Set from .onnx_export_serialization import ( flatten_with_keys_dynamic_cache, flatten_dynamic_cache, @@ -12,27 +12,36 @@ from .patches import patch_transformers as patch_transformers_list -def patch_module(mod, verbose: int = 0) -> Dict[type, Dict[type, Callable]]: +def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Callable]]: """ Applies all patches defined in classes prefixed by ``patched_`` ``cls._PATCHED_CLASS_`` defines the class to patch, ``cls._PATCHES_`` defines the method to patch. - The returns information needs to be sent to :func:`unpatch_module` + The returns information needs to be sent to :func:`unpatch_module_or_classes` to revert the changes. + + :param mod: module of list of clsses to patch + :param verbose: verbosity + :return: patch info """ - to_patch = [] - for k in dir(mod): - if k.startswith("patched_"): - v = getattr(mod, k) - if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"): - to_patch.append(v) + if isinstance(mod, list): + to_patch = mod + name = "list" + else: + to_patch = [] + for k in dir(mod): + if k.startswith("patched_"): + v = getattr(mod, k) + if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"): + to_patch.append(v) + name = mod.__name__ res = {} for cls in to_patch: original = cls._PATCHED_CLASS_ methods = cls._PATCHES_ if verbose: - print(f"[patch_module] {mod.__name__} - {cls.__name__}: {', '.join(methods)}") + print(f"[patch_module_or_classes] {name} - {cls.__name__}: {', '.join(methods)}") keep = {n: getattr(original, n, None) for n in methods} for n in methods: @@ -42,20 +51,30 @@ def patch_module(mod, verbose: int = 0) -> Dict[type, Dict[type, Callable]]: return res -def unpatch_module(mod, info: Dict[type, Dict[type, Callable]], verbose: int = 0): - """Reverts modification made by :func:`patch_module`.""" - to_patch = [] - for k in dir(mod): - if k.startswith("patched_"): - v = getattr(mod, k) - if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"): - to_patch.append(v) +def unpatch_module_or_classes(mod, info: Dict[type, Dict[type, Callable]], verbose: int = 0): + """ + Reverts modification made by :func:`patch_module_or_classes`. + + :param mod: module of list of clsses to patch + :param verbose: verbosity + """ + if isinstance(mod, list): + to_patch = mod + name = "list" + else: + to_patch = [] + for k in dir(mod): + if k.startswith("patched_"): + v = getattr(mod, k) + if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"): + to_patch.append(v) + name = mod.__name__ set_patch = set(to_patch) for cls, methods in info.items(): assert cls in set_patch, f"No patch registered for {cls} in {mod} (found {set_patch})" if verbose: - print(f"[unpatch_module] {mod.__name__} - {cls.__name__}: {', '.join(methods)}") + print(f"[unpatch_module_or_classes] {name} - {cls.__name__}: {', '.join(methods)}") original = cls._PATCHED_CLASS_ for n, v in methods.items(): if v is None: @@ -237,6 +256,7 @@ def bypass_export_some_errors( stop_if_static: int = 0, verbose: int = 0, patch: bool = True, + custom_patches: Optional[List[type["torch.nn.Module"]]] = None, # noqa: F821 ) -> Callable: """ Tries to bypass some situations :func:`torch.export.export` does not support. @@ -255,6 +275,9 @@ def bypass_export_some_errors( issues :param patch: if False, disable all patches except the registration of serialization function + :param custom_patches: to apply custom patches, + every patched class must define static attributes + ``_PATCHES_``, ``_PATCHED_CLASS_`` :param verbose: to show which patches is applied The list of available patches. @@ -433,7 +456,16 @@ def bypass_export_some_errors( f"[bypass_export_some_errors] transformers.__version__=" f"{transformers.__version__!r}" ) - revert_patches_info = patch_module(patch_transformers_list, verbose=verbose) + revert_patches_info = patch_module_or_classes( + patch_transformers_list, verbose=verbose + ) + + if custom_patches: + if verbose: + print("[bypass_export_some_errors] applies custom patches") + revert_custom_patches_info = patch_module_or_classes( + custom_patches, verbose=verbose + ) ######## # export @@ -455,7 +487,6 @@ def bypass_export_some_errors( print("[bypass_export_some_errors] remove patches") if patch_sympy: - # tracked by https://github.com/pytorch/pytorch/issues/143494 if f_sympy_name: sympy.core.numbers.IntegerConstant.name = f_sympy_name @@ -502,12 +533,23 @@ def bypass_export_some_errors( if verbose: print("[bypass_export_some_errors] restored shape constraints") + if custom_patches: + if verbose: + print("[bypass_export_some_errors] unpatch custom patches") + unpatch_module_or_classes( + custom_patches, revert_custom_patches_info, verbose=verbose + ) + ############## # transformers ############## if patch_transformers: - unpatch_module(patch_transformers_list, revert_patches_info, verbose=verbose) + if verbose: + print("[bypass_export_some_errors] unpatch transformers") + unpatch_module_or_classes( + patch_transformers_list, revert_patches_info, verbose=verbose + ) ######## # caches