Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ Change Logs
0.5.0
+++++

* :pr:`100`: implements a context to automatically rewrite methods or function with control flows
* :pr:`96`: implements ``is_stealing``, ``steal_append`` to complement ``steal_forward``
* :pr:`95`: fix Scan implementation for ``OnnxruntimeEvaluator``
* :pr:`93`: introduce patched expression to get around annoying export issues
* :pr:`92`: support errors distribution in max_diff
* :pr:`91`: enable strings in ``guess_dynamic_shapes``
* :pr:`95`: fixzq Scan implementation for ``OnnxruntimeEvaluator``
* :pr:`93`: introduces patched expressions to get around annoying export issues
* :pr:`92`: supports errors distribution in max_diff
* :pr:`91`: enables strings in ``guess_dynamic_shapes``
* :pr:`88`, :pr:`89`: extends ``steal_forward`` to dump input, outputs in onnx models
* :pr:`83`, :pr:`85`: improves the automated rewriting of control flow (test)

Expand Down
16 changes: 16 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,22 @@ Enlightening Examples
Snapshot of usefuls tools
+++++++++++++++++++++++++

**torch_export_patches**

.. code-block:: python

with torch_export_patches(patch_transformers=True) as f:
ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
# ...

**torch_export_rewrite**

.. code-block:: python

with torch_export_rewrite(rewrite=[Model.forward]) as f:
ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
# ...

**string_type**

.. code-block:: python
Expand Down
2 changes: 2 additions & 0 deletions _doc/api/torch_export_patches/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ onnx_diagnostic.torch_export_patches
:members:
:no-undoc-members:

.. autofunction:: onnx_diagnostic.torch_export_patches.torch_export_rewrite

.. autofunction:: onnx_diagnostic.torch_export_patches.torch_export_patches

.. autofunction:: onnx_diagnostic.torch_export_patches.register_additional_serialization_functions
1 change: 1 addition & 0 deletions _doc/api/torch_export_patches/patch_module.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ onnx_diagnostic.torch_export_patches.patch_module
.. automodule:: onnx_diagnostic.torch_export_patches.patch_module
:members:
:no-undoc-members:
:exclude-members: torch_export_rewrite
22 changes: 22 additions & 0 deletions _doc/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,28 @@ Enlightening Examples
Some Usefuls Tools
==================

torch_export_patches
++++++++++++++++++++

See :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`.

.. code-block:: python

with torch_export_patches(patch_transformers=True) as f:
ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
# ...

torch_export_rewrite
++++++++++++++++++++

See :func:`onnx_diagnostic.torch_export_patches.torch_export_rewrite`.

.. code-block:: python

with torch_export_rewrite(rewrite=[Model.forward]) as f:
ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
# ...

string_type
+++++++++++

Expand Down
83 changes: 81 additions & 2 deletions _unittests/ut_torch_export_patches/test_patch_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,28 @@
import unittest
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
from onnx_diagnostic.torch_export_patches import torch_export_patches, torch_export_rewrite
from onnx_diagnostic.torch_export_patches.patch_module import (
transform_method,
inplace_add_parent,
)


class _ModelForATest(torch.nn.Module):
def forward(self, x, y):
if x.sum() > 0:
return x + y
else:
return torch.abs(x) + y + 1


def _single_forward(x, y):
if x.sum() > 0:
return x + y
else:
return torch.abs(x) + y + 1


class TestPatchModule(ExtTestCase):
def test_parent(self):
class Model(torch.nn.Module):
Expand Down Expand Up @@ -361,8 +377,71 @@ def test_rewrite_PLBartEncoderLayer(self):
),
rewritten.code,
)
print()
print(rewritten.code)

@hide_stdout()
def test_torch_export_patch_method_tuple(self):
class Model(torch.nn.Module):
def forward(self, x, y):
if x.sum() > 0:
return x + y
else:
return torch.abs(x) + y + 1

model = Model()
x, y = torch.rand((4, 5)), torch.rand((4, 5))
expected = model(x, y)
DYN = torch.export.Dim.DYNAMIC
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
with torch_export_patches(rewrite=[(Model, "forward")], verbose=2):
ep = torch.export.export(model, (x, y), dynamic_shapes=ds)
got = ep.module()(x, y)
self.assertEqualArray(expected, got)

@hide_stdout()
def test_torch_export_rewrite_method_tuple(self):
class Model(torch.nn.Module):
def forward(self, x, y):
if x.sum() > 0:
return x + y
else:
return torch.abs(x) + y + 1

model = Model()
x, y = torch.rand((4, 5)), torch.rand((4, 5))
expected = model(x, y)
DYN = torch.export.Dim.DYNAMIC
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
with torch_export_rewrite(rewrite=[(Model, "forward")], verbose=1):
ep = torch.export.export(model, (x, y), dynamic_shapes=ds)
got = ep.module()(x, y)
self.assertEqualArray(expected, got)

def test_torch_export_rewrite_method_only(self):
model = _ModelForATest()
x, y = torch.rand((4, 5)), torch.rand((4, 5))
expected = model(x, y)
DYN = torch.export.Dim.DYNAMIC
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
with torch_export_rewrite(rewrite=[_ModelForATest.forward], verbose=0):
ep = torch.export.export(model, (x, y), dynamic_shapes=ds)
got = ep.module()(x, y)
self.assertEqualArray(expected, got)

@hide_stdout()
def test_torch_export_rewrite_function(self):
class Model(torch.nn.Module):
def forward(self, x, y):
return _single_forward(x, y)

model = Model()
x, y = torch.rand((4, 5)), torch.rand((4, 5))
expected = model(x, y)
DYN = torch.export.Dim.DYNAMIC
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
with torch_export_rewrite(rewrite=[_single_forward], verbose=1):
ep = torch.export.export(model, (x, y), dynamic_shapes=ds)
got = ep.module()(x, y)
self.assertEqualArray(expected, got)


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions onnx_diagnostic/torch_export_patches/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
torch_export_patches,
register_additional_serialization_functions,
)
from .patch_module import torch_export_rewrite


# bypass_export_some_errors is the first name given to the patches.
Expand Down
34 changes: 29 additions & 5 deletions onnx_diagnostic/torch_export_patches/onnx_export_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def torch_export_patches(
verbose: int = 0,
patch: bool = True,
custom_patches: Optional[List[type["torch.nn.Module"]]] = None, # noqa: F821
rewrite: Optional[List[Callable]] = None,
) -> Callable:
"""
Tries to bypass some situations :func:`torch.export.export` does not support.
Expand All @@ -123,6 +124,12 @@ def torch_export_patches(
:param custom_patches: to apply custom patches,
every patched class must define static attributes
``_PATCHES_``, ``_PATCHED_CLASS_``
:param rewrite: list of methods to automatically rewrite
before exporting, methods with control flow need to be rewritten
before being exported if the execution path depends on the inputs,
this is done by function :func:`transform_method
<onnx_diagnostic.torch_export_patches.patch_module.transform_method>`,
its documentation provides possible values
:param verbose: to show which patches is applied

The list of available patches.
Expand All @@ -143,21 +150,21 @@ def torch_export_patches(

Examples:

::
.. code-block:: python

with torch_export_patches(patch_transformers=True) as modificator:
inputs = modificator(inputs)
onx = to_onnx(..., inputs, ...)

::
.. code-block:: python

with torch_export_patches(patch_transformers=True) as modificator:
inputs = modificator(inputs)
onx = torch.onnx.export(..., inputs, ...)

It can be used as well to fix the torch export:

::
.. code-block:: python

with torch_export_patches(patch_transformers=True) as modificator:
inputs = modificator(inputs)
Expand All @@ -166,7 +173,7 @@ def torch_export_patches(
When running the model through the exported program, only the
serialization functions need to be restored:

::
.. code-block:: python

with register_additional_serialization_functions() as modificator:
inputs = modificator(inputs)
Expand All @@ -176,7 +183,24 @@ def torch_export_patches(
may appear ``AssertionError: Mutating module attribute _seen_tokens during export.``.
It can be avoided by setting ``strict=False`` when call :func:`torch.export.export`.
"""
if not patch:
if rewrite:
from .patch_module import torch_export_rewrite

with torch_export_rewrite(rewrite=rewrite, verbose=verbose), torch_export_patches(
patch_sympy=patch_sympy,
patch_torch=patch_torch,
patch_transformers=patch_transformers,
catch_constraints=catch_constraints,
stop_if_static=stop_if_static,
verbose=verbose,
patch=patch,
custom_patches=custom_patches,
):
try:
yield
finally:
pass
elif not patch:
fct_callable = lambda x: x # noqa: E731
done = _register_cache_serialization(verbose=verbose)
try:
Expand Down
3 changes: 2 additions & 1 deletion onnx_diagnostic/torch_export_patches/patch_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import torch
import transformers
from ..helpers import string_type
from ..helpers.cache_helper import make_dynamic_cache


def _process_cache(k: str, v):
Expand All @@ -16,6 +15,8 @@ def _process_cache(k: str, v):
and set(len(t) for t in v) == {2}
):
# A dynamicCache
from ..helpers.cache_helper import make_dynamic_cache

cache = make_dynamic_cache(v)
return cache
if isinstance(v, torch.Tensor):
Expand Down
Loading
Loading