Commit 4ed5d8c

better doc
1 parent 1d7e449 commit 4ed5d8c

13 files changed, +247 -33 lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ Change Logs
 0.5.0
 +++++
 
+* :pr:`105`: more options to tune control flow rewriting
 * :pr:`104`: add summarization task, add rewrite to command line validate
 * :pr:`101`: first draft to rewrite loops
 * :pr:`100`: implements a context to automatically rewrite methods or function with control flows

_doc/api/torch_export_patches/index.rst

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,8 @@ onnx_diagnostic.torch_export_patches
     :maxdepth: 1
     :caption: submodules
 
+    onnx_export_errors
+    onnx_export_serialization
     patches/index
     patch_expressions
     patch_inputs
Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+
+onnx_diagnostic.torch_export_patches.onnx_export_errors
+=======================================================
+
+.. automodule:: onnx_diagnostic.torch_export_patches.onnx_export_errors
+    :members:
+    :no-undoc-members:
+    :exclude-members: torch_export_patches, register_additional_serialization_functions
Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+
+onnx_diagnostic.torch_export_patches.onnx_export_serialization
+==============================================================
+
+.. automodule:: onnx_diagnostic.torch_export_patches.onnx_export_serialization
+    :members:
+    :no-undoc-members:
Lines changed: 3 additions & 4 deletions
@@ -1,8 +1,7 @@
 
-onnx_diagnostic.torch_export_patches.patch_module
-=================================================
+onnx_diagnostic.torch_export_patches.patch_module_helper
+========================================================
 
-.. automodule:: onnx_diagnostic.torch_export_patches.patch_module
+.. automodule:: onnx_diagnostic.torch_export_patches.patch_module_helper
     :members:
     :no-undoc-members:
-    :exclude-members: torch_export_rewrite

_doc/conf.py

Lines changed: 3 additions & 0 deletions
@@ -125,19 +125,22 @@ def linkcode_resolve(domain, info):
     ("py:class", "torch.utils._pytree.Context"),
     ("py:class", "torch.utils._pytree.KeyEntry"),
     ("py:class", "torch.utils._pytree.TreeSpec"),
+    ("py:class", "transformers.BartForConditionalGeneration"),
     ("py:class", "transformers.LlamaConfig"),
     ("py:class", "transformers.cache_utils.Cache"),
     ("py:class", "transformers.cache_utils.DynamicCache"),
     ("py:class", "transformers.cache_utils.EncoderDecoderCache"),
     ("py:class", "transformers.cache_utils.MambaCache"),
     ("py:class", "transformers.cache_utils.SlidingWindowCache"),
     ("py:class", "transformers.configuration_utils.PretrainedConfig"),
+    ("py:class", "transformers.modeling_outputs.BaseModelOutput"),
     ("py:func", "torch.export._draft_export.draft_export"),
     ("py:func", "torch._export.tools.report_exportability"),
     ("py:func", "torch.utils._pytree.register_pytree_node"),
     ("py:meth", "huggingface_hub.HfApi.list_models"),
     ("py:meth", "transformers.AutoConfig.from_pretrained"),
     ("py:meth", "transformers.GenerationMixin.generate"),
+    ("py:meth", "transformers.models.bart.modeling_bart.BartEncoderLayer.forward"),
     ("py:meth", "unittests.TestCase.subTest"),
 ]
 

_doc/index.rst

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@ It also implements tools to investigate, validate exported models (ExportedProgr
     :maxdepth: 1
     :caption: Contents
 
+    patches
     api/index
     cmds/index
     auto_examples/index

_doc/patches.rst

Lines changed: 172 additions & 0 deletions
@@ -0,0 +1,172 @@
+.. _l-patches-explained:
+
+=================
+Patches Explained
+=================
+
+Function :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
+implements four kinds of patches to make it easier to export a model, usually
+coming from :epkg:`transformers`.
+All patches take place in :mod:`onnx_diagnostic.torch_export_patches`.
+.. code-block:: python
+
+    with torch_export_patches(...) as f:
+        ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
+
+1. **torch fixes**:
+   it disables some exceptions or improves some functions related to dynamic shapes
+   until :epkg:`torch` addresses the issues
+   (see `mostly exporter issues
+   <https://github.com/pytorch/pytorch/issues?q=is%3Aissue%20state%3Aopen%20author%3Axadupre>`_)
+2. **transformers rewriting**:
+   some methods are replaced with a version :func:`torch.export.export` can understand,
+   some rewriting may migrate to :epkg:`transformers`, others are applied only
+   at export time because it would make the implementation less efficient
+3. **cache serialization**: :func:`torch.export.export` needs to know how to
+   serialize custom classes such as :class:`transformers.cache_utils.DynamicCache`
+4. **control flow rewriting**: control flow (if, for) cannot be exported as is,
+   there is still some work to be done to automatically process them,
+   this package offers some automated rewriting, but it is far from being perfect.
+
+All of them are triggered by :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`.
+
+.. code-block:: bash
+
+    python -m onnx_diagnostic validate \
+        -m hf-tiny-model-private/tiny-random-PLBartForConditionalGeneration \
+        --run -v 1 --export onnx-dynamo -o dump_test --dtype float16 --device cuda
+
+
+All patches can be disabled with ``with torch_export_patches(patch=False)``.
+
+torch fixes
+===========
+
+Implemented in :mod:`onnx_diagnostic.torch_export_patches.patches.patch_torch` and triggered with
+``with torch_export_patches(patch_sympy=True, patch_torch=True, catch_constraints=True, stop_if_static=1...)``.
+
+It fixes some issues found while exporting a model. Some of them might not be needed anymore.
+It improves shape broadcasting or inserts an exception every time a dynamic dimension
+becomes static (``stop_if_static=1``).
+
+transformers rewriting
+======================
+
+Implemented in :mod:`onnx_diagnostic.torch_export_patches.patches.patch_transformers` and triggered with
+``with torch_export_patches(patch_transformers=True)``.
+
+Every patched class is prefixed with ``patched_``. It contains two class attributes.
+``_PATCHES_`` contains the list of methods to replace.
+``_PATCHED_CLASS_`` is the class patched by this one.
+
+.. code-block:: python
+
+    class patched_AttentionMaskConverter:
+        """
+        Patches
+        ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``.
+        """
+
+        # This method was fixed in 4.51 at least.
+        _PATCHES_ = ["_make_causal_mask"] if not has_transformers("4.48.3") else []
+        _PATCHED_CLASS_ = AttentionMaskConverter
+
+The package automatically parses this file to extract the patched methods.
+More can be added by populating the argument ``custom_patches``:
+``with torch_export_patches(custom_patches=[...])``.
+
+Cache serialization
+===================
+
+Implemented in :mod:`onnx_diagnostic.torch_export_patches.onnx_export_serialization`.
+Any custom class manipulated by a model needs to be registered through
+``torch.utils._pytree.register_pytree_node`` or with
+:func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.register_class_serialization`
+and triggered by ``with torch_export_patches(patch_transformers=True)``.
+This function registers one class,
+:func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.register_cache_serialization`
+registers all known classes.
+It can be undone with :func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.unregister`
+or :func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.unregister_cache_serialization`.
+
+.. _l-control-flow-rewriting:
+
+Control flow rewriting
+======================
+
+This is an attempt to automatically rewrite control flow using :mod:`ast`.
+It is implemented in :mod:`onnx_diagnostic.torch_export_patches.patch_module` and
+triggered with ``with torch_export_patches(rewrite=<instance of torch.nn.Module>)``.
+Option ``dump_rewriting=<folder>`` tells the function to dump all applied
+rewritings.
+
+The following example contains the rewriting of method
+:meth:`transformers.models.bart.modeling_bart.BartEncoderLayer.forward`.
+The list of known rewritings to apply is returned by function
+:func:`onnx_diagnostic.torch_export_patches.patch_module_helper.code_needing_rewriting`
+and applied by function :func:`onnx_diagnostic.torch_export_patches.patch_module.transform_method`.
+
+While parsing the code, type information is missing, although it is known by
+:func:`torch.export.export`. Due to that, the automation usually needs manual tuning
+to filter out some tests (argument ``filter_node``) or pre/post processing
+(arguments ``pre_rewriter``, ``post_rewriter``) of function
+:func:`onnx_diagnostic.torch_export_patches.patch_module.transform_method`.
+
+The main entry point is the context
+:func:`onnx_diagnostic.torch_export_patches.torch_export_rewrite`
+which rewrites and undoes the rewriting.
+For example, the model :class:`transformers.BartForConditionalGeneration`
+requires the following value for parameter ``rewrite``:
+
+.. runpython::
+    :showcode:
+
+    from onnx_diagnostic.torch_export_patches.patch_module_helper import (
+        code_needing_rewriting,
+    )
+
+    print(code_needing_rewriting("BartForConditionalGeneration"))
+
+And that produces:
+
+.. code-block:: diff
+
+    --- original
+    +++ rewritten
+    @@ -26,7 +26,6 @@
+             hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+             hidden_states = residual + hidden_states
+             hidden_states = self.self_attn_layer_norm(hidden_states)
+    -
+             residual = hidden_states
+             hidden_states = self.activation_fn(self.fc1(hidden_states))
+             hidden_states = nn.functional.dropout(
+    @@ -37,15 +36,22 @@
+             hidden_states = residual + hidden_states
+             hidden_states = self.final_layer_norm(hidden_states)
+
+    -        if hidden_states.dtype == torch.float16 and (
+    -            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
+    -        ):
+    +        def branch_cond_then_1(hidden_states):
+                 clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+                 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+    +            return hidden_states.clone()
+
+    +        def branch_cond_else_1(hidden_states):
+    +            return hidden_states.clone()
+    +
+    +        hidden_states = torch.cond(
+    +            hidden_states.dtype == torch.float16
+    +            and torch.isinf(hidden_states).any() | torch.isnan(hidden_states).any(),
+    +            branch_cond_then_1,
+    +            branch_cond_else_1,
+    +            [hidden_states],
+    +        )
+             outputs = (hidden_states,)
+    -
+             if output_attentions:
+    -            outputs += (attn_weights,)
+    -
+    +            outputs = outputs + (attn_weights,)
+             return outputs
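
Editor's sketch (not part of the commit): the "Control flow rewriting" section of `_doc/patches.rst` says the rewriting relies on `ast`. The snippet below only illustrates the first step such a rewriter would plausibly need, which is locating `if` and `for` statements in a function's source. It is an assumption-laden illustration, not the code in `patch_module`; `SOURCE`, `ControlFlowFinder` and `find_control_flow` are hypothetical names.

```python
import ast
import textwrap

# Hypothetical function containing the two kinds of control flow
# mentioned in the documentation (if, for).
SOURCE = textwrap.dedent(
    """
    def forward(x, flag):
        if flag:
            x = x + 1
        for i in range(3):
            x = x * 2
        return x
    """
)


class ControlFlowFinder(ast.NodeVisitor):
    """Collects (kind, line number) for every `if` and `for` statement."""

    def __init__(self):
        self.found = []

    def visit_If(self, node):
        self.found.append(("if", node.lineno))
        self.generic_visit(node)  # keep walking to catch nested statements

    def visit_For(self, node):
        self.found.append(("for", node.lineno))
        self.generic_visit(node)


def find_control_flow(source: str):
    """Parses the source and returns the control flow statements it contains."""
    finder = ControlFlowFinder()
    finder.visit(ast.parse(source))
    return finder.found


print(find_control_flow(SOURCE))
```

A real rewriter would then replace each located node, for instance turning an `if` into a call to `torch.cond` as in the diff above. As the documentation notes, that step lacks type information and usually needs manual tuning.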

onnx_diagnostic/torch_export_patches/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -16,6 +16,6 @@ def register_flattening_functions(verbose: int = 0):
     This is needed whenever a model must be exported through
     :func:`torch.export.export`.
     """
-    from .onnx_export_serialization import _register_cache_serialization
+    from .onnx_export_serialization import register_cache_serialization
 
-    return _register_cache_serialization(verbose=verbose)
+    return register_cache_serialization(verbose=verbose)
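
Editor's sketch (not part of the commit): `register_flattening_functions` exists because `torch.export.export` must know how to flatten a custom class into plain values and rebuild it afterwards. The pure-Python snippet below shows what such a flatten/unflatten pair might look like. `TinyCache`, `flatten_tiny_cache` and `unflatten_tiny_cache` are hypothetical names, and the pair is only an illustration of the idea, not the functions this package registers.

```python
from typing import Any, Dict, List, Tuple


class TinyCache:
    """Hypothetical stand-in for a cache class holding per-layer lists."""

    def __init__(self, key_cache: List[Any], value_cache: List[Any]):
        self.key_cache = key_cache
        self.value_cache = value_cache


def flatten_tiny_cache(cache: TinyCache) -> Tuple[List[Any], List[str]]:
    """Returns the children and a context naming each child."""
    data: Dict[str, Any] = {
        "key_cache": cache.key_cache,
        "value_cache": cache.value_cache,
    }
    return list(data.values()), list(data.keys())


def unflatten_tiny_cache(values: List[Any], context: List[str]) -> TinyCache:
    """Rebuilds the cache from the children and the context saved by flatten."""
    return TinyCache(**dict(zip(context, values)))


# Round trip: flattening then unflattening restores the same content.
cache = TinyCache([1, 2], [3, 4])
values, context = flatten_tiny_cache(cache)
restored = unflatten_tiny_cache(values, context)
print(restored.key_cache, restored.value_cache)
```

With torch installed, a pair of this shape is what one would presumably pass to `torch.utils._pytree.register_pytree_node` together with the class; that call is left out here so the sketch stays dependency-free.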

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 9 additions & 8 deletions
@@ -1,8 +1,8 @@
 import contextlib
 from typing import Any, Callable, Dict, List, Optional
 from .onnx_export_serialization import (
-    _register_cache_serialization,
-    _unregister_cache_serialization,
+    register_cache_serialization,
+    unregister_cache_serialization,
 )
 from .patches import patch_transformers as patch_transformers_list
 
@@ -85,11 +85,11 @@ def register_additional_serialization_functions(
 ) -> Callable:
     """The necessary modifications to run the fx Graph."""
     fct_callable = replacement_before_exporting if patch_transformers else (lambda x: x)
-    done = _register_cache_serialization(verbose=verbose)
+    done = register_cache_serialization(verbose=verbose)
     try:
         yield fct_callable
     finally:
-        _unregister_cache_serialization(done, verbose=verbose)
+        unregister_cache_serialization(done, verbose=verbose)
 
 
 @contextlib.contextmanager
@@ -107,6 +107,7 @@ def torch_export_patches(
 ) -> Callable:
     """
     Tries to bypass some situations :func:`torch.export.export` does not support.
+    See also :ref:`l-patches-explained`.
 
     :param patch_sympy: fix missing method ``name`` for IntegerConstant
     :param patch_torch: patches :epkg:`torch` with supported implementation
@@ -206,11 +207,11 @@ def torch_export_patches(
                 pass
     elif not patch:
         fct_callable = lambda x: x  # noqa: E731
-        done = _register_cache_serialization(verbose=verbose)
+        done = register_cache_serialization(verbose=verbose)
         try:
             yield fct_callable
         finally:
-            _unregister_cache_serialization(done, verbose=verbose)
+            unregister_cache_serialization(done, verbose=verbose)
     else:
         import torch
         import torch._export.non_strict_utils  # produce_guards_and_solve_constraints
@@ -226,7 +227,7 @@ def torch_export_patches(
         # caches
         ########
 
-        cache_done = _register_cache_serialization(verbose=verbose)
+        cache_done = register_cache_serialization(verbose=verbose)
 
         #############
         # patch sympy
@@ -439,7 +440,7 @@ def torch_export_patches(
             # caches
             ########
 
-            _unregister_cache_serialization(cache_done, verbose=verbose)
+            unregister_cache_serialization(cache_done, verbose=verbose)
 
 
 def replacement_before_exporting(args: Any) -> Any:
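
Editor's sketch (not part of the commit): the hunks above register on entry and undo in a `finally` clause. The same apply-then-restore shape can be sketched for method patches, using the `_PATCHES_` / `_PATCHED_CLASS_` convention described in `_doc/patches.rst`. `Greeter`, `patched_Greeter` and `apply_patches` are hypothetical names, and this is a simplified illustration assuming plain instance methods, not the logic inside `torch_export_patches`.

```python
import contextlib


class Greeter:
    """Hypothetical class standing in for a class from transformers."""

    def greet(self) -> str:
        return "original"


class patched_Greeter:
    """Follows the ``patched_`` naming convention from the documentation."""

    _PATCHES_ = ["greet"]  # names of the methods to replace
    _PATCHED_CLASS_ = Greeter  # class receiving the replacements

    def greet(self) -> str:
        return "patched"


@contextlib.contextmanager
def apply_patches(*patch_classes):
    """Replaces the listed methods, then restores the originals on exit."""
    saved = []
    for patch in patch_classes:
        target = patch._PATCHED_CLASS_
        for name in patch._PATCHES_:
            # Remember the original so it can be put back later.
            saved.append((target, name, getattr(target, name)))
            setattr(target, name, getattr(patch, name))
    try:
        yield
    finally:
        # Undo in reverse order, even if the body raised an exception.
        for target, name, original in reversed(saved):
            setattr(target, name, original)


g = Greeter()
with apply_patches(patched_Greeter):
    print(g.greet())
print(g.greet())
```

Restoring in `finally` is the design point: a failed export inside the `with` block must not leave the patched methods in place.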
