From 2ab34d77fdd82a6893dfdf1804b65e6a262d2a19 Mon Sep 17 00:00:00 2001 From: xadupre Date: Thu, 15 May 2025 12:26:48 +0200 Subject: [PATCH 1/3] Implement a context to rewrite method or functions --- _doc/api/torch_export_patches/index.rst | 2 + .../api/torch_export_patches/patch_module.rst | 1 + .../test_patch_module.py | 83 ++++++++++++++++++- .../torch_export_patches/__init__.py | 1 + .../onnx_export_errors.py | 35 ++++++-- .../torch_export_patches/patch_inputs.py | 3 +- .../torch_export_patches/patch_module.py | 79 +++++++++++++++++- 7 files changed, 195 insertions(+), 9 deletions(-) diff --git a/_doc/api/torch_export_patches/index.rst b/_doc/api/torch_export_patches/index.rst index 60c51d1d..df5d59de 100644 --- a/_doc/api/torch_export_patches/index.rst +++ b/_doc/api/torch_export_patches/index.rst @@ -15,6 +15,8 @@ onnx_diagnostic.torch_export_patches :members: :no-undoc-members: +.. autofunction:: onnx_diagnostic.torch_export_patches.torch_export_rewrite + .. autofunction:: onnx_diagnostic.torch_export_patches.torch_export_patches .. autofunction:: onnx_diagnostic.torch_export_patches.register_additional_serialization_functions diff --git a/_doc/api/torch_export_patches/patch_module.rst b/_doc/api/torch_export_patches/patch_module.rst index f3759993..ec306da0 100644 --- a/_doc/api/torch_export_patches/patch_module.rst +++ b/_doc/api/torch_export_patches/patch_module.rst @@ -5,3 +5,4 @@ onnx_diagnostic.torch_export_patches.patch_module .. automodule:: onnx_diagnostic.torch_export_patches.patch_module :members: :no-undoc-members: + :exclude: torch_export_rewrite diff --git a/_unittests/ut_torch_export_patches/test_patch_module.py b/_unittests/ut_torch_export_patches/test_patch_module.py index 63e40591..3f073193 100644 --- a/_unittests/ut_torch_export_patches/test_patch_module.py +++ b/_unittests/ut_torch_export_patches/test_patch_module.py @@ -4,12 +4,28 @@ import unittest import torch from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout +from onnx_diagnostic.torch_export_patches import torch_export_patches, torch_export_rewrite from onnx_diagnostic.torch_export_patches.patch_module import ( transform_method, inplace_add_parent, ) +class _ModelForATest(torch.nn.Module): + def forward(self, x, y): + if x.sum() > 0: + return x + y + else: + return torch.abs(x) + y + 1 + + +def _single_forward(x, y): + if x.sum() > 0: + return x + y + else: + return torch.abs(x) + y + 1 + + class TestPatchModule(ExtTestCase): def test_parent(self): class Model(torch.nn.Module): @@ -361,8 +377,71 @@ def test_rewrite_PLBartEncoderLayer(self): ), rewritten.code, ) - print() - print(rewritten.code) + + @hide_stdout() + def test_torch_export_patch_method_tuple(self): + class Model(torch.nn.Module): + def forward(self, x, y): + if x.sum() > 0: + return x + y + else: + return torch.abs(x) + y + 1 + + model = Model() + x, y = torch.rand((4, 5)), torch.rand((4, 5)) + expected = model(x, y) + DYN = torch.export.Dim.DYNAMIC + ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN}) + with torch_export_patches(rewrite_methods=[(Model, "forward")], verbose=2): + ep = torch.export.export(model, (x, y), dynamic_shapes=ds) + got = ep.module()(x, y) + self.assertEqualArray(expected, got) + + @hide_stdout() + def test_torch_export_rewrite_method_tuple(self): + class Model(torch.nn.Module): + def forward(self, x, y): + if x.sum() > 0: + return x + y + else: + return torch.abs(x) + y + 1 + + model = Model() + x, y = torch.rand((4, 5)), torch.rand((4, 5)) + expected = model(x, y) + DYN = torch.export.Dim.DYNAMIC + ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN}) + with torch_export_rewrite(rewrite_methods=[(Model, "forward")], verbose=1): + ep = torch.export.export(model, (x, y), dynamic_shapes=ds) + got = ep.module()(x, y) + self.assertEqualArray(expected, got) + + def test_torch_export_rewrite_method_only(self): + model = _ModelForATest() + x, y = torch.rand((4, 5)), torch.rand((4, 5)) + expected = model(x, y) + DYN = torch.export.Dim.DYNAMIC + ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN}) + with torch_export_rewrite(rewrite_methods=[_ModelForATest.forward], verbose=0): + ep = torch.export.export(model, (x, y), dynamic_shapes=ds) + got = ep.module()(x, y) + self.assertEqualArray(expected, got) + + @hide_stdout() + def test_torch_export_rewrite_function(self): + class Model(torch.nn.Module): + def forward(self, x, y): + return _single_forward(x, y) + + model = Model() + x, y = torch.rand((4, 5)), torch.rand((4, 5)) + expected = model(x, y) + DYN = torch.export.Dim.DYNAMIC + ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN}) + with torch_export_rewrite(rewrite_methods=[_single_forward], verbose=1): + ep = torch.export.export(model, (x, y), dynamic_shapes=ds) + got = ep.module()(x, y) + self.assertEqualArray(expected, got) if __name__ == "__main__": diff --git a/onnx_diagnostic/torch_export_patches/__init__.py b/onnx_diagnostic/torch_export_patches/__init__.py index 4ac9ae63..ea06e03b 100644 --- a/onnx_diagnostic/torch_export_patches/__init__.py +++ b/onnx_diagnostic/torch_export_patches/__init__.py @@ -2,6 +2,7 @@ torch_export_patches, register_additional_serialization_functions, ) +from .patch_module import torch_export_rewrite # bypass_export_some_errors is the first name given to the patches. diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py index 27cbfe14..f1078e73 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py @@ -102,6 +102,7 @@ def torch_export_patches( verbose: int = 0, patch: bool = True, custom_patches: Optional[List[type["torch.nn.Module"]]] = None, # noqa: F821 + rewrite_methods: Optional[List[Callable]] = None, ) -> Callable: """ Tries to bypass some situations :func:`torch.export.export` does not support. @@ -123,6 +124,11 @@ def torch_export_patches( :param custom_patches: to apply custom patches, every patched class must define static attributes ``_PATCHES_``, ``_PATCHED_CLASS_`` + :param rewrite_methods: list of methods to automatically rewrite + before exporting, methods with control flow need to be rewritten + before being exported if the execution path depends on the inputs, + this is done by function :func:`transform_method + ` :param verbose: to show which patches is applied The list of available patches. @@ -143,13 +149,13 @@ def torch_export_patches( Examples: - :: + .. code-block:: python with torch_export_patches(patch_transformers=True) as modificator: inputs = modificator(inputs) onx = to_onnx(..., inputs, ...) - :: + .. code-block:: python with torch_export_patches(patch_transformers=True) as modificator: inputs = modificator(inputs) @@ -157,7 +163,7 @@ def torch_export_patches( It can be used as well to fix the torch export: - :: + .. code-block:: python with torch_export_patches(patch_transformers=True) as modificator: inputs = modificator(inputs) @@ -166,7 +172,7 @@ def torch_export_patches( When running the model through the exported program, only the serialization functions need to be restored: - :: + .. code-block:: python with register_additional_serialization_functions() as modificator: inputs = modificator(inputs) @@ -176,7 +182,26 @@ def torch_export_patches( may appear ``AssertionError: Mutating module attribute _seen_tokens during export.``. It can be avoided by setting ``strict=False`` when call :func:`torch.export.export`. """ - if not patch: + if rewrite_methods: + from .patch_module import torch_export_rewrite + + with torch_export_rewrite( + rewrite_methods=rewrite_methods, verbose=verbose + ), torch_export_patches( + patch_sympy=patch_sympy, + patch_torch=patch_torch, + patch_transformers=patch_transformers, + catch_constraints=catch_constraints, + stop_if_static=stop_if_static, + verbose=verbose, + patch=patch, + custom_patches=custom_patches, + ): + try: + yield + finally: + pass + elif not patch: fct_callable = lambda x: x # noqa: E731 done = _register_cache_serialization(verbose=verbose) try: diff --git a/onnx_diagnostic/torch_export_patches/patch_inputs.py b/onnx_diagnostic/torch_export_patches/patch_inputs.py index 2c46a0d3..82f5e806 100644 --- a/onnx_diagnostic/torch_export_patches/patch_inputs.py +++ b/onnx_diagnostic/torch_export_patches/patch_inputs.py @@ -3,7 +3,6 @@ import torch import transformers from ..helpers import string_type -from ..helpers.cache_helper import make_dynamic_cache def _process_cache(k: str, v): @@ -16,6 +15,8 @@ def _process_cache(k: str, v): and set(len(t) for t in v) == {2} ): # A dynamicCache + from ..helpers.cache_helper import make_dynamic_cache + cache = make_dynamic_cache(v) return cache if isinstance(v, torch.Tensor): diff --git a/onnx_diagnostic/torch_export_patches/patch_module.py b/onnx_diagnostic/torch_export_patches/patch_module.py index 7912ebf6..b15b9a7a 100644 --- a/onnx_diagnostic/torch_export_patches/patch_module.py +++ b/onnx_diagnostic/torch_export_patches/patch_module.py @@ -1,9 +1,11 @@ import ast import copy +import contextlib import inspect import types import textwrap -from typing import Callable, Dict, List, Set, Optional +import sys +from typing import Callable, Dict, List, Set, Optional, Tuple, Union NODE_TYPES = tuple( getattr(ast, k) @@ -515,3 +517,78 @@ def forward(self, x, y): if not isinstance(new_func, types.FunctionType): raise RuntimeError("Transformed function not found") return RewrittenMethod(new_tree, new_func) + + +@contextlib.contextmanager +def torch_export_rewrite( + rewrite_methods: Optional[List[Union[Tuple[type, str], Callable]]] = None, verbose: int = 0 +): + """ + Automatically rewrite the methods given in `rewrite_methods` to export + control flows (test and loops). + + :param rewrite_methods: methods to rewrite, if not empty, the function may try + to discover them, a method is defined by its class (a type) and its name + if the class is local, by itself otherwise + :param verbose: verbosity, up to 10, 10 shows the rewritten code + """ + assert ( + rewrite_methods + ), "rewrite_methods is empty, automated discovery is not implemented yet" + keep = {} + for me in rewrite_methods: + if isinstance(me, tuple): + assert len(me) == 2, f"Unexpected value for a rewritten method or function {me}" + cls, name = me + to_rewrite = getattr(cls, name) + kind = "method" + else: + name = me.__qualname__ + spl = name.split(".") + if len(spl) == 1: + # This a function + module = me.__module__ + if module in me.__globals__: + mod = me.__globals__[module] + else: + assert module in sys.modules, ( + f"Cannot find module name {module!r} in sys.modules or " + f"__globals__={sorted(me.__globals__)}" + ) + mod = sys.modules[module] + cls = mod + name = name + to_rewrite = me + kind = "function" + else: + kind = "method" + # This is a method + assert len(spl) >= 2, ( + f"{me} is not method, its name {name!r} does not contain a class name, " + f"dir(me)={dir(me)}" + ) + cls_name = spl[-2] + assert cls_name in me.__globals__, ( + f"Class name {cls_name!r} from method {name!r} " + f"could not be found in set(me.__globals__)={sorted(me.__globals__)}" + ) + cls = me.__globals__[cls_name] + name = me.__name__ + to_rewrite = me + assert hasattr( + cls, name + ), f"Method {name!r} inferred form {me} was not found in class {cls}." + assert (cls, name) not in keep, f"{kind} {me} cannot be rewritten twice." + if verbose: + print(f"[torch_export_rewrite] rewrites {kind} {cls.__name__}.{name}") + keep[cls, name] = to_rewrite + rewr = transform_method(to_rewrite, verbose=max(verbose - 1, 0)) + setattr(cls, name, rewr.func) + + try: + yield + finally: + for (cls, name), me in keep.items(): + if verbose: + print(f"[torch_export_rewrite] restored {kind} {cls.__name__}.{name}") + setattr(cls, name, me) From 9739f78eca5e4b6b57b264a6244d8146cb7d0696 Mon Sep 17 00:00:00 2001 From: xadupre Date: Thu, 15 May 2025 12:39:58 +0200 Subject: [PATCH 2/3] fix issues --- CHANGELOGS.rst | 9 ++-- .../api/torch_export_patches/patch_module.rst | 2 +- .../onnx_export_errors.py | 3 +- .../torch_export_patches/patch_module.py | 50 ++++++++++++++++++- 4 files changed, 57 insertions(+), 7 deletions(-) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index cb22f19b..dbb1c850 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,11 +4,12 @@ Change Logs 0.5.0 +++++ +* :pr:`100`: implements a context to automatically rewrite methods or function with control flows * :pr:`96`: implements ``is_stealing``, ``steal_append`` to complement ``steal_forward`` -* :pr:`95`: fix Scan implementation for ``OnnxruntimeEvaluator`` -* :pr:`93`: introduce patched expression to get around annoying export issues -* :pr:`92`: support errors distribution in max_diff -* :pr:`91`: enable strings in ``guess_dynamic_shapes`` +* :pr:`95`: fixzq Scan implementation for ``OnnxruntimeEvaluator`` +* :pr:`93`: introduces patched expressions to get around annoying export issues +* :pr:`92`: supports errors distribution in max_diff +* :pr:`91`: enables strings in ``guess_dynamic_shapes`` * :pr:`88`, :pr:`89`: extends ``steal_forward`` to dump input, outputs in onnx models * :pr:`83`, :pr:`85`: improves the automated rewriting of control flow (test) diff --git a/_doc/api/torch_export_patches/patch_module.rst b/_doc/api/torch_export_patches/patch_module.rst index ec306da0..9fe6f9af 100644 --- a/_doc/api/torch_export_patches/patch_module.rst +++ b/_doc/api/torch_export_patches/patch_module.rst @@ -5,4 +5,4 @@ onnx_diagnostic.torch_export_patches.patch_module .. automodule:: onnx_diagnostic.torch_export_patches.patch_module :members: :no-undoc-members: - :exclude: torch_export_rewrite + :exclude-members: torch_export_rewrite diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py index f1078e73..31672497 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py @@ -128,7 +128,8 @@ def torch_export_patches( before exporting, methods with control flow need to be rewritten before being exported if the execution path depends on the inputs, this is done by function :func:`transform_method - ` + `, + its documentation provides possible values :param verbose: to show which patches is applied The list of available patches. diff --git a/onnx_diagnostic/torch_export_patches/patch_module.py b/onnx_diagnostic/torch_export_patches/patch_module.py index b15b9a7a..a4e0d709 100644 --- a/onnx_diagnostic/torch_export_patches/patch_module.py +++ b/onnx_diagnostic/torch_export_patches/patch_module.py @@ -530,7 +530,55 @@ def torch_export_rewrite( :param rewrite_methods: methods to rewrite, if not empty, the function may try to discover them, a method is defined by its class (a type) and its name if the class is local, by itself otherwise - :param verbose: verbosity, up to 10, 10 shows the rewritten code + :param verbose: verbosity, up to 10, 10 shows the rewritten code, + ``verbose=1`` shows the rewritten function, + ``verbose=2`` shows the rewritten code as well + + Example: + + .. code-block:: python + + class Model(torch.nn.Module): + def forward(self, x, y): + if x.sum() > 0: + return x + y + else: + return torch.abs(x) + y + 1 + + model = Model() + x, y = torch.rand((4, 5)), torch.rand((4, 5)) + DYN = torch.export.Dim.DYNAMIC + ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN}) + with torch_export_rewrite(rewrite_methods=[(Model, "forward")]): + ep = torch.export.export(model, (x, y), dynamic_shapes=ds) + + If the method to rewrite is not local, then the following can be used: + + .. code-block:: python + + with torch_export_rewrite(rewrite_methods=[Model.forward]): + ep = torch.export.export(model, (x, y), dynamic_shapes=ds) + + Functions (if not local) can also be rewritten: + + .. code-block:: python + + def outside(x, y): + if x.sum() > 0: + return x + y + else: + return torch.abs(x) + y + 1 + + class Model(torch.nn.Module): + def forward(self, x, y): + return outside(x, y) + + model = Model() + x, y = torch.rand((4, 5)), torch.rand((4, 5)) + DYN = torch.export.Dim.DYNAMIC + ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN}) + with torch_export_rewrite(rewrite_methods=[outside]): + ep = torch.export.export(model, (x, y), dynamic_shapes=ds) """ assert ( rewrite_methods From e495c6b3d98ed68e2358c19f1a1a467eed031328 Mon Sep 17 00:00:00 2001 From: xadupre Date: Thu, 15 May 2025 14:31:24 +0200 Subject: [PATCH 3/3] rename --- README.rst | 16 ++++++++++++++ _doc/index.rst | 22 +++++++++++++++++++ .../test_patch_module.py | 8 +++---- .../onnx_export_errors.py | 10 ++++----- .../torch_export_patches/patch_module.py | 18 +++++++-------- 5 files changed, 54 insertions(+), 20 deletions(-) diff --git a/README.rst b/README.rst index 23fde899..f9cdf089 100644 --- a/README.rst +++ b/README.rst @@ -86,6 +86,22 @@ Enlightening Examples Snapshot of usefuls tools +++++++++++++++++++++++++ +**torch_export_patches** + +.. code-block:: python + + with torch_export_patches(patch_transformers=True) as f: + ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes) + # ... + +**torch_export_rewrite** + +.. code-block:: python + + with torch_export_rewrite(rewrite=[Model.forward]) as f: + ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes) + # ... + **string_type** .. code-block:: python diff --git a/_doc/index.rst b/_doc/index.rst index bafeae27..7ebfd89a 100644 --- a/_doc/index.rst +++ b/_doc/index.rst @@ -89,6 +89,28 @@ Enlightening Examples Some Usefuls Tools ================== +torch_export_patches +++++++++++++++++++++ + +See :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`. + +.. code-block:: python + + with torch_export_patches(patch_transformers=True) as f: + ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes) + # ... + +torch_export_rewrite +++++++++++++++++++++ + +See :func:`onnx_diagnostic.torch_export_patches.torch_export_rewrite`. + +.. code-block:: python + + with torch_export_rewrite(rewrite=[Model.forward]) as f: + ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes) + # ... + string_type +++++++++++ diff --git a/_unittests/ut_torch_export_patches/test_patch_module.py b/_unittests/ut_torch_export_patches/test_patch_module.py index 3f073193..a5bbb953 100644 --- a/_unittests/ut_torch_export_patches/test_patch_module.py +++ b/_unittests/ut_torch_export_patches/test_patch_module.py @@ -392,7 +392,7 @@ def forward(self, x, y): expected = model(x, y) DYN = torch.export.Dim.DYNAMIC ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN}) - with torch_export_patches(rewrite_methods=[(Model, "forward")], verbose=2): + with torch_export_patches(rewrite=[(Model, "forward")], verbose=2): ep = torch.export.export(model, (x, y), dynamic_shapes=ds) got = ep.module()(x, y) self.assertEqualArray(expected, got) @@ -411,7 +411,7 @@ def forward(self, x, y): expected = model(x, y) DYN = torch.export.Dim.DYNAMIC ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN}) - with torch_export_rewrite(rewrite_methods=[(Model, "forward")], verbose=1): + with torch_export_rewrite(rewrite=[(Model, "forward")], verbose=1): ep = torch.export.export(model, (x, y), dynamic_shapes=ds) got = ep.module()(x, y) self.assertEqualArray(expected, got) @@ -422,7 +422,7 @@ def test_torch_export_rewrite_method_only(self): expected = model(x, y) DYN = torch.export.Dim.DYNAMIC ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN}) - with torch_export_rewrite(rewrite_methods=[_ModelForATest.forward], verbose=0): + with torch_export_rewrite(rewrite=[_ModelForATest.forward], verbose=0): ep = torch.export.export(model, (x, y), dynamic_shapes=ds) got = ep.module()(x, y) self.assertEqualArray(expected, got) @@ -438,7 +438,7 @@ def forward(self, x, y): expected = model(x, y) DYN = torch.export.Dim.DYNAMIC ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN}) - with torch_export_rewrite(rewrite_methods=[_single_forward], verbose=1): + with torch_export_rewrite(rewrite=[_single_forward], verbose=1): ep = torch.export.export(model, (x, y), dynamic_shapes=ds) got = ep.module()(x, y) self.assertEqualArray(expected, got) diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py index 31672497..cc74e09e 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py @@ -102,7 +102,7 @@ def torch_export_patches( verbose: int = 0, patch: bool = True, custom_patches: Optional[List[type["torch.nn.Module"]]] = None, # noqa: F821 - rewrite_methods: Optional[List[Callable]] = None, + rewrite: Optional[List[Callable]] = None, ) -> Callable: """ Tries to bypass some situations :func:`torch.export.export` does not support. @@ -124,7 +124,7 @@ def torch_export_patches( :param custom_patches: to apply custom patches, every patched class must define static attributes ``_PATCHES_``, ``_PATCHED_CLASS_`` - :param rewrite_methods: list of methods to automatically rewrite + :param rewrite: list of methods to automatically rewrite before exporting, methods with control flow need to be rewritten before being exported if the execution path depends on the inputs, this is done by function :func:`transform_method @@ -183,12 +183,10 @@ def torch_export_patches( may appear ``AssertionError: Mutating module attribute _seen_tokens during export.``. It can be avoided by setting ``strict=False`` when call :func:`torch.export.export`. """ - if rewrite_methods: + if rewrite: from .patch_module import torch_export_rewrite - with torch_export_rewrite( - rewrite_methods=rewrite_methods, verbose=verbose - ), torch_export_patches( + with torch_export_rewrite(rewrite=rewrite, verbose=verbose), torch_export_patches( patch_sympy=patch_sympy, patch_torch=patch_torch, patch_transformers=patch_transformers, diff --git a/onnx_diagnostic/torch_export_patches/patch_module.py b/onnx_diagnostic/torch_export_patches/patch_module.py index a4e0d709..55b7429e 100644 --- a/onnx_diagnostic/torch_export_patches/patch_module.py +++ b/onnx_diagnostic/torch_export_patches/patch_module.py @@ -521,13 +521,13 @@ def forward(self, x, y): @contextlib.contextmanager def torch_export_rewrite( - rewrite_methods: Optional[List[Union[Tuple[type, str], Callable]]] = None, verbose: int = 0 + rewrite: Optional[List[Union[Tuple[type, str], Callable]]] = None, verbose: int = 0 ): """ - Automatically rewrite the methods given in `rewrite_methods` to export + Automatically rewrite the methods given in `rewrite` to export control flows (test and loops). - :param rewrite_methods: methods to rewrite, if not empty, the function may try + :param rewrite: methods of functions to rewrite, if not empty, the function may try to discover them, a method is defined by its class (a type) and its name if the class is local, by itself otherwise :param verbose: verbosity, up to 10, 10 shows the rewritten code, @@ -549,14 +549,14 @@ def forward(self, x, y): x, y = torch.rand((4, 5)), torch.rand((4, 5)) DYN = torch.export.Dim.DYNAMIC ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN}) - with torch_export_rewrite(rewrite_methods=[(Model, "forward")]): + with torch_export_rewrite(rewrite=[(Model, "forward")]): ep = torch.export.export(model, (x, y), dynamic_shapes=ds) If the method to rewrite is not local, then the following can be used: .. code-block:: python - with torch_export_rewrite(rewrite_methods=[Model.forward]): + with torch_export_rewrite(rewrite=[Model.forward]): ep = torch.export.export(model, (x, y), dynamic_shapes=ds) Functions (if not local) can also be rewritten: @@ -577,14 +577,12 @@ def forward(self, x, y): x, y = torch.rand((4, 5)), torch.rand((4, 5)) DYN = torch.export.Dim.DYNAMIC ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN}) - with torch_export_rewrite(rewrite_methods=[outside]): + with torch_export_rewrite(rewrite=[outside]): ep = torch.export.export(model, (x, y), dynamic_shapes=ds) """ - assert ( - rewrite_methods - ), "rewrite_methods is empty, automated discovery is not implemented yet" + assert rewrite, "rewrite is empty, automated discovery is not implemented yet" keep = {} - for me in rewrite_methods: + for me in rewrite: if isinstance(me, tuple): assert len(me) == 2, f"Unexpected value for a rewritten method or function {me}" cls, name = me