Skip to content

Commit 9739f78

Browse files
committed
fix issues
1 parent 2ab34d7 commit 9739f78

File tree

4 files changed

+57
-7
lines changed

4 files changed

+57
-7
lines changed

CHANGELOGS.rst

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@ Change Logs
44
0.5.0
55
+++++
66

7+
* :pr:`100`: implements a context to automatically rewrite methods or function with control flows
78
* :pr:`96`: implements ``is_stealing``, ``steal_append`` to complement ``steal_forward``
8-
* :pr:`95`: fix Scan implementation for ``OnnxruntimeEvaluator``
9-
* :pr:`93`: introduce patched expression to get around annoying export issues
10-
* :pr:`92`: support errors distribution in max_diff
11-
* :pr:`91`: enable strings in ``guess_dynamic_shapes``
9+
* :pr:`95`: fixzq Scan implementation for ``OnnxruntimeEvaluator``
10+
* :pr:`93`: introduces patched expressions to get around annoying export issues
11+
* :pr:`92`: supports errors distribution in max_diff
12+
* :pr:`91`: enables strings in ``guess_dynamic_shapes``
1213
* :pr:`88`, :pr:`89`: extends ``steal_forward`` to dump input, outputs in onnx models
1314
* :pr:`83`, :pr:`85`: improves the automated rewriting of control flow (test)
1415

_doc/api/torch_export_patches/patch_module.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@ onnx_diagnostic.torch_export_patches.patch_module
55
.. automodule:: onnx_diagnostic.torch_export_patches.patch_module
66
:members:
77
:no-undoc-members:
8-
:exclude: torch_export_rewrite
8+
:exclude-members: torch_export_rewrite

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,8 @@ def torch_export_patches(
128128
before exporting, methods with control flow need to be rewritten
129129
before being exported if the execution path depends on the inputs,
130130
this is done by function :func:`transform_method
131-
<onnx_diagnostic.torch_export_patches.patch_module.transform_method>`
131+
<onnx_diagnostic.torch_export_patches.patch_module.transform_method>`,
132+
its documentation provides possible values
132133
:param verbose: to show which patches is applied
133134
134135
The list of available patches.

onnx_diagnostic/torch_export_patches/patch_module.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -530,7 +530,55 @@ def torch_export_rewrite(
530530
:param rewrite_methods: methods to rewrite, if not empty, the function may try
531531
to discover them, a method is defined by its class (a type) and its name
532532
if the class is local, by itself otherwise
533-
:param verbose: verbosity, up to 10, 10 shows the rewritten code
533+
:param verbose: verbosity, up to 10, 10 shows the rewritten code,
534+
``verbose=1`` shows the rewritten function,
535+
``verbose=2`` shows the rewritten code as well
536+
537+
Example:
538+
539+
.. code-block:: python
540+
541+
class Model(torch.nn.Module):
542+
def forward(self, x, y):
543+
if x.sum() > 0:
544+
return x + y
545+
else:
546+
return torch.abs(x) + y + 1
547+
548+
model = Model()
549+
x, y = torch.rand((4, 5)), torch.rand((4, 5))
550+
DYN = torch.export.Dim.DYNAMIC
551+
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
552+
with torch_export_rewrite(rewrite_methods=[(Model, "forward")]):
553+
ep = torch.export.export(model, (x, y), dynamic_shapes=ds)
554+
555+
If the method to rewrite is not local, then the following can be used:
556+
557+
.. code-block:: python
558+
559+
with torch_export_rewrite(rewrite_methods=[Model.forward]):
560+
ep = torch.export.export(model, (x, y), dynamic_shapes=ds)
561+
562+
Functions (if not local) can also be rewritten:
563+
564+
.. code-block:: python
565+
566+
def outside(x, y):
567+
if x.sum() > 0:
568+
return x + y
569+
else:
570+
return torch.abs(x) + y + 1
571+
572+
class Model(torch.nn.Module):
573+
def forward(self, x, y):
574+
return outside(x, y)
575+
576+
model = Model()
577+
x, y = torch.rand((4, 5)), torch.rand((4, 5))
578+
DYN = torch.export.Dim.DYNAMIC
579+
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
580+
with torch_export_rewrite(rewrite_methods=[outside]):
581+
ep = torch.export.export(model, (x, y), dynamic_shapes=ds)
534582
"""
535583
assert (
536584
rewrite_methods

0 commit comments

Comments
 (0)