Skip to content

Commit 8052e04

Browse files
authored
doc (#106)
1 parent d357b28 commit 8052e04

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

_doc/patches.rst

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ Function :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
88
implements four kinds of patches to make it easier to export a model, usually
99
coming from :epkg:`transformers`.
1010
All patches takes place in :mod:`onnx_diagnostic.torch_export_patches`.
11+
1112
.. code-block:: python
1213
1314
with torch_export_patches(...) as f:
@@ -121,13 +122,19 @@ requires the following value for parameter ``rewrite``:
121122
.. runpython::
122123
:showcode:
123124

125+
import pprint
124126
from onnx_diagnostic.torch_export_patches.patch_module_helper import (
125127
code_needing_rewriting,
126128
)
127129

128-
print(code_needing_rewriting("BartForConditionalGeneration"))
130+
pprint.pprint(code_needing_rewriting("BartForConditionalGeneration"))
129131

130-
And that produces:
132+
This method has two tests. Only the first one needs to be rewritten.
133+
The second one manipulates tuple and the automated rewritten does not handle
134+
that because it cannot detect types. That explains why the parameter
135+
``filter_node`` is filled. Then, the first test includes a condition relying on ``or``
136+
which must be replaced by ``|``. That explains the parameter ``pre_rewriter``.
137+
We finally get:
131138

132139
.. code-block:: diff
133140

onnx_diagnostic/torch_export_patches/patch_module_helper.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,12 @@ def code_needing_rewriting(cls_name: str) -> Optional[List[Any]]:
3030
.. runpython::
3131
:showcode:
3232
33+
import pprint
3334
from onnx_diagnostic.torch_export_patches.patch_module_helper import (
3435
code_needing_rewriting,
3536
)
3637
37-
print(code_needing_rewriting("BartForConditionalGeneration"))
38+
pprint.pprint(code_needing_rewriting("BartForConditionalGeneration"))
3839
"""
3940
if cls_name in {
4041
"BartEncoderLayer",

0 commit comments

Comments
 (0)