Skip to content

Commit 3e9f91b

Browse files
committed
doc
1 parent 764cbf6 commit 3e9f91b

File tree

2 files changed

+58
-6
lines changed

2 files changed

+58
-6
lines changed

_doc/patches.rst

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,17 @@ Every patched class is prefixed with ``patched_``. It contains two class attribu
8181
8282
The packages automatically parses this file to extract the patched methods.
8383
More can be added by populating the argument ``custom_patches``:
84-
``with torch_export_patches(custom_patches=[...])``.
84+
``with torch_export_patches(patch_transformers=True, custom_patches=[...])``.
85+
Here is the list of available patches:
86+
87+
.. runpython::
88+
:showcode:
89+
90+
import onnx_diagnostic.torch_export_patches.patches.patch_transformers as p
91+
92+
for name, cls in p.__dict__.items():
93+
if name.startswith("patched_") and hasattr(cls, "_PATCHES_"):
94+
print(f"{cls._PATCHED_CLASS_.__name__}: {', '.join(cls._PATCHES_)}")
8595

8696
Cache serialization
8797
===================
@@ -96,6 +106,14 @@ This function does one class,
96106
does all known classes.
97107
It can be undone with :func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.unregister`
98108
or :func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.unregister_cache_serialization`.
109+
Here is the list of supported caches:
110+
111+
.. runpython::
112+
:showcode:
113+
114+
import onnx_diagnostic.torch_export_patches.onnx_export_serialization as p
115+
116+
print("\n".join(sorted(p.serialization_functions())))
99117

100118
.. _l-control-flow-rewriting:
101119

@@ -184,3 +202,25 @@ We finally get:
184202
-
185203
+ outputs = outputs + (attn_weights,)
186204
return outputs
205+
206+
The locations where it has to be done:
207+
208+
.. runpython::
209+
:showcode:
210+
211+
import pprint
212+
from onnx_diagnostic.torch_export_patches.patch_module_helper import (
213+
known_transformers_rewritings_clamp_float16,
214+
)
215+
216+
pprint.pprint(known_transformers_rewritings_clamp_float16())
217+
218+
.. runpython::
219+
:showcode:
220+
221+
import pprint
222+
from onnx_diagnostic.torch_export_patches.patch_module_helper import (
223+
_rewrite_forward_clamp_float16,
224+
)
225+
226+
pprint.pprint(_rewrite_forward_clamp_float16())

_doc/status/patches_coverage.rst

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,25 @@ Half Automated Rewrites for Control Flows
3535
=========================================
3636

3737
The following script shows the list of methods automatically rewritten
38-
due to control flows.
38+
due to control flows. The same code is duplicated in many model classes.
39+
The number of fixes if much less than the number of classes to fix.
3940

4041
.. runpython::
4142
:showcode:
4243

43-
import onnx_diagnostic.torch_export_patches.patch_module_helper as p
44+
import pprint
45+
from onnx_diagnostic.torch_export_patches.patch_module_helper import (
46+
known_transformers_rewritings_clamp_float16,
47+
)
4448

45-
for name, f in p.__dict__.items():
46-
if name.startswith("_rewrite_"):
47-
print(f.__doc__)
49+
pprint.pprint(known_transformers_rewritings_clamp_float16())
50+
51+
.. runpython::
52+
:showcode:
53+
54+
import pprint
55+
from onnx_diagnostic.torch_export_patches.patch_module_helper import (
56+
_rewrite_forward_clamp_float16,
57+
)
58+
59+
pprint.pprint(_rewrite_forward_clamp_float16())

0 commit comments

Comments
 (0)