Skip to content

Commit 2ab34d7

Browse files
committed
Implement a context to rewrite method or functions
1 parent 1ede022 commit 2ab34d7

File tree

7 files changed

+195
-9
lines changed

7 files changed

+195
-9
lines changed

_doc/api/torch_export_patches/index.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ onnx_diagnostic.torch_export_patches
1515
:members:
1616
:no-undoc-members:
1717

18+
.. autofunction:: onnx_diagnostic.torch_export_patches.torch_export_rewrite
19+
1820
.. autofunction:: onnx_diagnostic.torch_export_patches.torch_export_patches
1921

2022
.. autofunction:: onnx_diagnostic.torch_export_patches.register_additional_serialization_functions

_doc/api/torch_export_patches/patch_module.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ onnx_diagnostic.torch_export_patches.patch_module
55
.. automodule:: onnx_diagnostic.torch_export_patches.patch_module
66
:members:
77
:no-undoc-members:
8+
:exclude: torch_export_rewrite

_unittests/ut_torch_export_patches/test_patch_module.py

Lines changed: 81 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,28 @@
44
import unittest
55
import torch
66
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
7+
from onnx_diagnostic.torch_export_patches import torch_export_patches, torch_export_rewrite
78
from onnx_diagnostic.torch_export_patches.patch_module import (
89
transform_method,
910
inplace_add_parent,
1011
)
1112

1213

14+
class _ModelForATest(torch.nn.Module):
15+
def forward(self, x, y):
16+
if x.sum() > 0:
17+
return x + y
18+
else:
19+
return torch.abs(x) + y + 1
20+
21+
22+
def _single_forward(x, y):
23+
if x.sum() > 0:
24+
return x + y
25+
else:
26+
return torch.abs(x) + y + 1
27+
28+
1329
class TestPatchModule(ExtTestCase):
1430
def test_parent(self):
1531
class Model(torch.nn.Module):
@@ -361,8 +377,71 @@ def test_rewrite_PLBartEncoderLayer(self):
361377
),
362378
rewritten.code,
363379
)
364-
print()
365-
print(rewritten.code)
380+
381+
@hide_stdout()
382+
def test_torch_export_patch_method_tuple(self):
383+
class Model(torch.nn.Module):
384+
def forward(self, x, y):
385+
if x.sum() > 0:
386+
return x + y
387+
else:
388+
return torch.abs(x) + y + 1
389+
390+
model = Model()
391+
x, y = torch.rand((4, 5)), torch.rand((4, 5))
392+
expected = model(x, y)
393+
DYN = torch.export.Dim.DYNAMIC
394+
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
395+
with torch_export_patches(rewrite_methods=[(Model, "forward")], verbose=2):
396+
ep = torch.export.export(model, (x, y), dynamic_shapes=ds)
397+
got = ep.module()(x, y)
398+
self.assertEqualArray(expected, got)
399+
400+
@hide_stdout()
401+
def test_torch_export_rewrite_method_tuple(self):
402+
class Model(torch.nn.Module):
403+
def forward(self, x, y):
404+
if x.sum() > 0:
405+
return x + y
406+
else:
407+
return torch.abs(x) + y + 1
408+
409+
model = Model()
410+
x, y = torch.rand((4, 5)), torch.rand((4, 5))
411+
expected = model(x, y)
412+
DYN = torch.export.Dim.DYNAMIC
413+
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
414+
with torch_export_rewrite(rewrite_methods=[(Model, "forward")], verbose=1):
415+
ep = torch.export.export(model, (x, y), dynamic_shapes=ds)
416+
got = ep.module()(x, y)
417+
self.assertEqualArray(expected, got)
418+
419+
def test_torch_export_rewrite_method_only(self):
420+
model = _ModelForATest()
421+
x, y = torch.rand((4, 5)), torch.rand((4, 5))
422+
expected = model(x, y)
423+
DYN = torch.export.Dim.DYNAMIC
424+
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
425+
with torch_export_rewrite(rewrite_methods=[_ModelForATest.forward], verbose=0):
426+
ep = torch.export.export(model, (x, y), dynamic_shapes=ds)
427+
got = ep.module()(x, y)
428+
self.assertEqualArray(expected, got)
429+
430+
@hide_stdout()
431+
def test_torch_export_rewrite_function(self):
432+
class Model(torch.nn.Module):
433+
def forward(self, x, y):
434+
return _single_forward(x, y)
435+
436+
model = Model()
437+
x, y = torch.rand((4, 5)), torch.rand((4, 5))
438+
expected = model(x, y)
439+
DYN = torch.export.Dim.DYNAMIC
440+
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
441+
with torch_export_rewrite(rewrite_methods=[_single_forward], verbose=1):
442+
ep = torch.export.export(model, (x, y), dynamic_shapes=ds)
443+
got = ep.module()(x, y)
444+
self.assertEqualArray(expected, got)
366445

367446

368447
if __name__ == "__main__":

onnx_diagnostic/torch_export_patches/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
torch_export_patches,
33
register_additional_serialization_functions,
44
)
5+
from .patch_module import torch_export_rewrite
56

67

78
# bypass_export_some_errors is the first name given to the patches.

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def torch_export_patches(
102102
verbose: int = 0,
103103
patch: bool = True,
104104
custom_patches: Optional[List[type["torch.nn.Module"]]] = None, # noqa: F821
105+
rewrite_methods: Optional[List[Callable]] = None,
105106
) -> Callable:
106107
"""
107108
Tries to bypass some situations :func:`torch.export.export` does not support.
@@ -123,6 +124,11 @@ def torch_export_patches(
123124
:param custom_patches: to apply custom patches,
124125
every patched class must define static attributes
125126
``_PATCHES_``, ``_PATCHED_CLASS_``
127+
:param rewrite_methods: list of methods to automatically rewrite
128+
before exporting, methods with control flow need to be rewritten
129+
before being exported if the execution path depends on the inputs,
130+
this is done by function :func:`transform_method
131+
<onnx_diagnostic.torch_export_patches.patch_module.transform_method>`
126132
:param verbose: to show which patches is applied
127133
128134
The list of available patches.
@@ -143,21 +149,21 @@ def torch_export_patches(
143149
144150
Examples:
145151
146-
::
152+
.. code-block:: python
147153
148154
with torch_export_patches(patch_transformers=True) as modificator:
149155
inputs = modificator(inputs)
150156
onx = to_onnx(..., inputs, ...)
151157
152-
::
158+
.. code-block:: python
153159
154160
with torch_export_patches(patch_transformers=True) as modificator:
155161
inputs = modificator(inputs)
156162
onx = torch.onnx.export(..., inputs, ...)
157163
158164
It can be used as well to fix the torch export:
159165
160-
::
166+
.. code-block:: python
161167
162168
with torch_export_patches(patch_transformers=True) as modificator:
163169
inputs = modificator(inputs)
@@ -166,7 +172,7 @@ def torch_export_patches(
166172
When running the model through the exported program, only the
167173
serialization functions need to be restored:
168174
169-
::
175+
.. code-block:: python
170176
171177
with register_additional_serialization_functions() as modificator:
172178
inputs = modificator(inputs)
@@ -176,7 +182,26 @@ def torch_export_patches(
176182
may appear ``AssertionError: Mutating module attribute _seen_tokens during export.``.
177183
It can be avoided by setting ``strict=False`` when call :func:`torch.export.export`.
178184
"""
179-
if not patch:
185+
if rewrite_methods:
186+
from .patch_module import torch_export_rewrite
187+
188+
with torch_export_rewrite(
189+
rewrite_methods=rewrite_methods, verbose=verbose
190+
), torch_export_patches(
191+
patch_sympy=patch_sympy,
192+
patch_torch=patch_torch,
193+
patch_transformers=patch_transformers,
194+
catch_constraints=catch_constraints,
195+
stop_if_static=stop_if_static,
196+
verbose=verbose,
197+
patch=patch,
198+
custom_patches=custom_patches,
199+
):
200+
try:
201+
yield
202+
finally:
203+
pass
204+
elif not patch:
180205
fct_callable = lambda x: x # noqa: E731
181206
done = _register_cache_serialization(verbose=verbose)
182207
try:

onnx_diagnostic/torch_export_patches/patch_inputs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import torch
44
import transformers
55
from ..helpers import string_type
6-
from ..helpers.cache_helper import make_dynamic_cache
76

87

98
def _process_cache(k: str, v):
@@ -16,6 +15,8 @@ def _process_cache(k: str, v):
1615
and set(len(t) for t in v) == {2}
1716
):
1817
# A dynamicCache
18+
from ..helpers.cache_helper import make_dynamic_cache
19+
1920
cache = make_dynamic_cache(v)
2021
return cache
2122
if isinstance(v, torch.Tensor):

onnx_diagnostic/torch_export_patches/patch_module.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import ast
22
import copy
3+
import contextlib
34
import inspect
45
import types
56
import textwrap
6-
from typing import Callable, Dict, List, Set, Optional
7+
import sys
8+
from typing import Callable, Dict, List, Set, Optional, Tuple, Union
79

810
NODE_TYPES = tuple(
911
getattr(ast, k)
@@ -515,3 +517,78 @@ def forward(self, x, y):
515517
if not isinstance(new_func, types.FunctionType):
516518
raise RuntimeError("Transformed function not found")
517519
return RewrittenMethod(new_tree, new_func)
520+
521+
522+
@contextlib.contextmanager
523+
def torch_export_rewrite(
524+
rewrite_methods: Optional[List[Union[Tuple[type, str], Callable]]] = None, verbose: int = 0
525+
):
526+
"""
527+
Automatically rewrite the methods given in `rewrite_methods` to export
528+
control flows (test and loops).
529+
530+
:param rewrite_methods: methods to rewrite, if not empty, the function may try
531+
to discover them, a method is defined by its class (a type) and its name
532+
if the class is local, by itself otherwise
533+
:param verbose: verbosity, up to 10, 10 shows the rewritten code
534+
"""
535+
assert (
536+
rewrite_methods
537+
), "rewrite_methods is empty, automated discovery is not implemented yet"
538+
keep = {}
539+
for me in rewrite_methods:
540+
if isinstance(me, tuple):
541+
assert len(me) == 2, f"Unexpected value for a rewritten method or function {me}"
542+
cls, name = me
543+
to_rewrite = getattr(cls, name)
544+
kind = "method"
545+
else:
546+
name = me.__qualname__
547+
spl = name.split(".")
548+
if len(spl) == 1:
549+
# This a function
550+
module = me.__module__
551+
if module in me.__globals__:
552+
mod = me.__globals__[module]
553+
else:
554+
assert module in sys.modules, (
555+
f"Cannot find module name {module!r} in sys.modules or "
556+
f"__globals__={sorted(me.__globals__)}"
557+
)
558+
mod = sys.modules[module]
559+
cls = mod
560+
name = name
561+
to_rewrite = me
562+
kind = "function"
563+
else:
564+
kind = "method"
565+
# This is a method
566+
assert len(spl) >= 2, (
567+
f"{me} is not method, its name {name!r} does not contain a class name, "
568+
f"dir(me)={dir(me)}"
569+
)
570+
cls_name = spl[-2]
571+
assert cls_name in me.__globals__, (
572+
f"Class name {cls_name!r} from method {name!r} "
573+
f"could not be found in set(me.__globals__)={sorted(me.__globals__)}"
574+
)
575+
cls = me.__globals__[cls_name]
576+
name = me.__name__
577+
to_rewrite = me
578+
assert hasattr(
579+
cls, name
580+
), f"Method {name!r} inferred form {me} was not found in class {cls}."
581+
assert (cls, name) not in keep, f"{kind} {me} cannot be rewritten twice."
582+
if verbose:
583+
print(f"[torch_export_rewrite] rewrites {kind} {cls.__name__}.{name}")
584+
keep[cls, name] = to_rewrite
585+
rewr = transform_method(to_rewrite, verbose=max(verbose - 1, 0))
586+
setattr(cls, name, rewr.func)
587+
588+
try:
589+
yield
590+
finally:
591+
for (cls, name), me in keep.items():
592+
if verbose:
593+
print(f"[torch_export_rewrite] restored {kind} {cls.__name__}.{name}")
594+
setattr(cls, name, me)

0 commit comments

Comments
 (0)