@@ -81,7 +81,17 @@ Every patched class is prefixed with ``patched_``. It contains two class attribu
8181
8282 The packages automatically parses this file to extract the patched methods.
8383More can be added by populating the argument ``custom_patches ``:
84- ``with torch_export_patches(custom_patches=[...]) ``.
84+ ``with torch_export_patches(patch_transformers=True, custom_patches=[...]) ``.
85+ Here is the list of available patches:
86+
87+ .. runpython ::
88+ :showcode:
89+
90+ import onnx_diagnostic.torch_export_patches.patches.patch_transformers as p
91+
92+ for name, cls in p.__dict__.items():
93+ if name.startswith("patched _") and hasattr(cls, "_PATCHES_"):
94+ print(f"{cls._PATCHED_CLASS_.__name__}: {', '.join(cls._PATCHES_)}")
8595
8696Cache serialization
8797===================
@@ -96,6 +106,14 @@ This function does one class,
96106does all known classes.
97107It can be undone with :func: `onnx_diagnostic.torch_export_patches.onnx_export_serialization.unregister `
98108or :func: `onnx_diagnostic.torch_export_patches.onnx_export_serialization.unregister_cache_serialization `.
109+ Here is the list of supported caches:
110+
111+ .. runpython ::
112+ :showcode:
113+
114+ import onnx_diagnostic.torch_export_patches.onnx_export_serialization as p
115+
116+ print("\n ".join(sorted(p.serialization_functions())))
99117
100118.. _l-control-flow-rewriting :
101119
@@ -184,3 +202,25 @@ We finally get:
184202 -
185203 + outputs = outputs + (attn_weights,)
186204 return outputs
205+
206+ The locations where it has to be done:
207+
208+ .. runpython ::
209+ :showcode:
210+
211+ import pprint
212+ from onnx_diagnostic.torch_export_patches.patch_module_helper import (
213+ known_transformers_rewritings_clamp_float16,
214+ )
215+
216+ pprint.pprint(known_transformers_rewritings_clamp_float16())
217+
218+ .. runpython ::
219+ :showcode:
220+
221+ import pprint
222+ from onnx_diagnostic.torch_export_patches.patch_module_helper import (
223+ _rewrite_forward_clamp_float16,
224+ )
225+
226+ pprint.pprint(_rewrite_forward_clamp_float16())
0 commit comments