Skip to content

Commit e495c6b

Browse files
committed
rename
1 parent 9739f78 commit e495c6b

File tree

5 files changed

+54
-20
lines changed

5 files changed

+54
-20
lines changed

README.rst

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,22 @@ Enlightening Examples
8686
Snapshot of usefuls tools
8787
+++++++++++++++++++++++++
8888

89+
**torch_export_patches**
90+
91+
.. code-block:: python
92+
93+
with torch_export_patches(patch_transformers=True) as f:
94+
ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
95+
# ...
96+
97+
**torch_export_rewrite**
98+
99+
.. code-block:: python
100+
101+
with torch_export_rewrite(rewrite=[Model.forward]) as f:
102+
ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
103+
# ...
104+
89105
**string_type**
90106

91107
.. code-block:: python

_doc/index.rst

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,28 @@ Enlightening Examples
8989
Some Usefuls Tools
9090
==================
9191

92+
torch_export_patches
93+
++++++++++++++++++++
94+
95+
See :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`.
96+
97+
.. code-block:: python
98+
99+
with torch_export_patches(patch_transformers=True) as f:
100+
ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
101+
# ...
102+
103+
torch_export_rewrite
104+
++++++++++++++++++++
105+
106+
See :func:`onnx_diagnostic.torch_export_patches.torch_export_rewrite`.
107+
108+
.. code-block:: python
109+
110+
with torch_export_rewrite(rewrite=[Model.forward]) as f:
111+
ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
112+
# ...
113+
92114
string_type
93115
+++++++++++
94116

_unittests/ut_torch_export_patches/test_patch_module.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,7 @@ def forward(self, x, y):
392392
expected = model(x, y)
393393
DYN = torch.export.Dim.DYNAMIC
394394
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
395-
with torch_export_patches(rewrite_methods=[(Model, "forward")], verbose=2):
395+
with torch_export_patches(rewrite=[(Model, "forward")], verbose=2):
396396
ep = torch.export.export(model, (x, y), dynamic_shapes=ds)
397397
got = ep.module()(x, y)
398398
self.assertEqualArray(expected, got)
@@ -411,7 +411,7 @@ def forward(self, x, y):
411411
expected = model(x, y)
412412
DYN = torch.export.Dim.DYNAMIC
413413
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
414-
with torch_export_rewrite(rewrite_methods=[(Model, "forward")], verbose=1):
414+
with torch_export_rewrite(rewrite=[(Model, "forward")], verbose=1):
415415
ep = torch.export.export(model, (x, y), dynamic_shapes=ds)
416416
got = ep.module()(x, y)
417417
self.assertEqualArray(expected, got)
@@ -422,7 +422,7 @@ def test_torch_export_rewrite_method_only(self):
422422
expected = model(x, y)
423423
DYN = torch.export.Dim.DYNAMIC
424424
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
425-
with torch_export_rewrite(rewrite_methods=[_ModelForATest.forward], verbose=0):
425+
with torch_export_rewrite(rewrite=[_ModelForATest.forward], verbose=0):
426426
ep = torch.export.export(model, (x, y), dynamic_shapes=ds)
427427
got = ep.module()(x, y)
428428
self.assertEqualArray(expected, got)
@@ -438,7 +438,7 @@ def forward(self, x, y):
438438
expected = model(x, y)
439439
DYN = torch.export.Dim.DYNAMIC
440440
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
441-
with torch_export_rewrite(rewrite_methods=[_single_forward], verbose=1):
441+
with torch_export_rewrite(rewrite=[_single_forward], verbose=1):
442442
ep = torch.export.export(model, (x, y), dynamic_shapes=ds)
443443
got = ep.module()(x, y)
444444
self.assertEqualArray(expected, got)

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def torch_export_patches(
102102
verbose: int = 0,
103103
patch: bool = True,
104104
custom_patches: Optional[List[type["torch.nn.Module"]]] = None, # noqa: F821
105-
rewrite_methods: Optional[List[Callable]] = None,
105+
rewrite: Optional[List[Callable]] = None,
106106
) -> Callable:
107107
"""
108108
Tries to bypass some situations :func:`torch.export.export` does not support.
@@ -124,7 +124,7 @@ def torch_export_patches(
124124
:param custom_patches: to apply custom patches,
125125
every patched class must define static attributes
126126
``_PATCHES_``, ``_PATCHED_CLASS_``
127-
:param rewrite_methods: list of methods to automatically rewrite
127+
:param rewrite: list of methods to automatically rewrite
128128
before exporting, methods with control flow need to be rewritten
129129
before being exported if the execution path depends on the inputs,
130130
this is done by function :func:`transform_method
@@ -183,12 +183,10 @@ def torch_export_patches(
183183
may appear ``AssertionError: Mutating module attribute _seen_tokens during export.``.
184184
It can be avoided by setting ``strict=False`` when call :func:`torch.export.export`.
185185
"""
186-
if rewrite_methods:
186+
if rewrite:
187187
from .patch_module import torch_export_rewrite
188188

189-
with torch_export_rewrite(
190-
rewrite_methods=rewrite_methods, verbose=verbose
191-
), torch_export_patches(
189+
with torch_export_rewrite(rewrite=rewrite, verbose=verbose), torch_export_patches(
192190
patch_sympy=patch_sympy,
193191
patch_torch=patch_torch,
194192
patch_transformers=patch_transformers,

onnx_diagnostic/torch_export_patches/patch_module.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -521,13 +521,13 @@ def forward(self, x, y):
521521

522522
@contextlib.contextmanager
523523
def torch_export_rewrite(
524-
rewrite_methods: Optional[List[Union[Tuple[type, str], Callable]]] = None, verbose: int = 0
524+
rewrite: Optional[List[Union[Tuple[type, str], Callable]]] = None, verbose: int = 0
525525
):
526526
"""
527-
Automatically rewrite the methods given in `rewrite_methods` to export
527+
Automatically rewrite the methods given in `rewrite` to export
528528
control flows (test and loops).
529529
530-
:param rewrite_methods: methods to rewrite, if not empty, the function may try
530+
:param rewrite: methods of functions to rewrite, if not empty, the function may try
531531
to discover them, a method is defined by its class (a type) and its name
532532
if the class is local, by itself otherwise
533533
:param verbose: verbosity, up to 10, 10 shows the rewritten code,
@@ -549,14 +549,14 @@ def forward(self, x, y):
549549
x, y = torch.rand((4, 5)), torch.rand((4, 5))
550550
DYN = torch.export.Dim.DYNAMIC
551551
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
552-
with torch_export_rewrite(rewrite_methods=[(Model, "forward")]):
552+
with torch_export_rewrite(rewrite=[(Model, "forward")]):
553553
ep = torch.export.export(model, (x, y), dynamic_shapes=ds)
554554
555555
If the method to rewrite is not local, then the following can be used:
556556
557557
.. code-block:: python
558558
559-
with torch_export_rewrite(rewrite_methods=[Model.forward]):
559+
with torch_export_rewrite(rewrite=[Model.forward]):
560560
ep = torch.export.export(model, (x, y), dynamic_shapes=ds)
561561
562562
Functions (if not local) can also be rewritten:
@@ -577,14 +577,12 @@ def forward(self, x, y):
577577
x, y = torch.rand((4, 5)), torch.rand((4, 5))
578578
DYN = torch.export.Dim.DYNAMIC
579579
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
580-
with torch_export_rewrite(rewrite_methods=[outside]):
580+
with torch_export_rewrite(rewrite=[outside]):
581581
ep = torch.export.export(model, (x, y), dynamic_shapes=ds)
582582
"""
583-
assert (
584-
rewrite_methods
585-
), "rewrite_methods is empty, automated discovery is not implemented yet"
583+
assert rewrite, "rewrite is empty, automated discovery is not implemented yet"
586584
keep = {}
587-
for me in rewrite_methods:
585+
for me in rewrite:
588586
if isinstance(me, tuple):
589587
assert len(me) == 2, f"Unexpected value for a rewritten method or function {me}"
590588
cls, name = me

0 commit comments

Comments
 (0)