Skip to content

Commit d357b28

Browse files
authored
Improves rewriting (#105)
* Improves rewriting * mypy * refacto * fixes * better doc * spelling * urls
1 parent 54e8373 commit d357b28

File tree

20 files changed

+498
-76
lines changed

20 files changed

+498
-76
lines changed

.github/workflows/check-urls.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ jobs:
3131
timeout: 2
3232
retry_count# : 2
3333
exclude_urls: https://github.com/pytorch/pytorch/pull/117009,https://github.com/huggingface/transformers/pull/29285,https://github.com/pytorch/pytorch/blob/a44f8894fa6d973693aab44a3dda079a168b05c1/torch/_decomp/decompositions.py#L1475,https://github.com/huggingface/transformers/pull/36652
34-
exclude_patterns: https://dumps.wikimedia.org/,https://github.com/pytorch/pytorch/pull/,https://github.com/pytorch/pytorch/blob/a44f8894fa6d973693aab44a3dda079a168b05c1/torch/_decomp/decompositions.py#L1475,https://huggingface.co/,https://huggingface.co/,https://github.com/huggingface/transformers/
34+
exclude_patterns: https://dumps.wikimedia.org/,https://github.com/,https://huggingface.co/,https://huggingface.co/
3535
# force_pass : true
3636

3737
- name: urls-checker-docs
@@ -43,5 +43,5 @@ jobs:
4343
timeout: 2
4444
retry_count# : 2
4545
exclude_urls: https://hal.archives-,ouvertes.fr/hal-00990252/document,http://badge.fury.io/py/onnx-diagnostic,https://azure.microsoft.com/en-us/products/devops/pipelines,https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670,https://github.com/NVIDIA/TransformerEngine.git@6a9edc38bf9b941b7d369af5103fa8fe0b121d61,https://medium.com/@msouza.os/llm-from-scratch-with-pytorch-9f21808c6319,https://github.com/pytorch/pytorch/blob/main/torch/fx/experimental/symbolic_shapes.py#L5965,https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-04.html,https://badge.fury.io/py/onnx-diagnostic.svg,https://github.com/huggingface/transformers/pull/36311
46-
exclude_patterns: https://www.data.gouv.fr/fr/datasets/r/e3d83ab3-dc52-4c99-abaf-8a38050cc68c,https://dev.azure.com/,https://azure.microsoft.com/en-us/products/devops/pipelines,https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670,https://github.com/NVIDIA/TransformerEngine.git@6a9edc38bf9b941b7d369af5103fa8fe0b121d61,https://github.com/pytorch/pytorch/blob/main/torch/,https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-04.html,https://badge.fury.io/py/onnx-diagnostic.svg,https://github.com/huggingface/transformers/pull/36311,https://codecov.io/,https://huggingface.co/
46+
exclude_patterns: https://www.data.gouv.fr/fr/datasets/r/e3d83ab3-dc52-4c99-abaf-8a38050cc68c,https://dev.azure.com/,https://azure.microsoft.com/en-us/products/devops/pipelines,https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670,https://github.com/NVIDIA/TransformerEngine.git@6a9edc38bf9b941b7d369af5103fa8fe0b121d61,https://github.com/pytorch/pytorch/blob/main/torch/,https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-04.html,https://badge.fury.io/py/onnx-diagnostic.svg,https://github.com/,https://codecov.io/,https://huggingface.co/
4747
# force_pass : true

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.5.0
55
+++++
66

7+
* :pr:`105`: more options to tune control flow rewriting
78
* :pr:`104`: add summarization task, add rewrite to command line validate
89
* :pr:`101`: first draft to rewrite loops
910
* :pr:`100`: implements a context to automatically rewrite methods or functions with control flows

_doc/api/torch_export_patches/index.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@ onnx_diagnostic.torch_export_patches
55
:maxdepth: 1
66
:caption: submodules
77

8+
onnx_export_errors
9+
onnx_export_serialization
810
patches/index
911
patch_expressions
1012
patch_inputs
1113
patch_module
12-
14+
patch_module_helper
1315

1416
.. automodule:: onnx_diagnostic.torch_export_patches
1517
:members:
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
2+
onnx_diagnostic.torch_export_patches.onnx_export_errors
3+
=======================================================
4+
5+
.. automodule:: onnx_diagnostic.torch_export_patches.onnx_export_errors
6+
:members:
7+
:no-undoc-members:
8+
:exclude-members: torch_export_patches, register_additional_serialization_functions
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
2+
onnx_diagnostic.torch_export_patches.onnx_export_serialization
3+
==============================================================
4+
5+
.. automodule:: onnx_diagnostic.torch_export_patches.onnx_export_serialization
6+
:members:
7+
:no-undoc-members:
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
2+
onnx_diagnostic.torch_export_patches.patch_module_helper
3+
========================================================
4+
5+
.. automodule:: onnx_diagnostic.torch_export_patches.patch_module_helper
6+
:members:
7+
:no-undoc-members:

_doc/conf.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,19 +125,22 @@ def linkcode_resolve(domain, info):
125125
("py:class", "torch.utils._pytree.Context"),
126126
("py:class", "torch.utils._pytree.KeyEntry"),
127127
("py:class", "torch.utils._pytree.TreeSpec"),
128+
("py:class", "transformers.BartForConditionalGeneration"),
128129
("py:class", "transformers.LlamaConfig"),
129130
("py:class", "transformers.cache_utils.Cache"),
130131
("py:class", "transformers.cache_utils.DynamicCache"),
131132
("py:class", "transformers.cache_utils.EncoderDecoderCache"),
132133
("py:class", "transformers.cache_utils.MambaCache"),
133134
("py:class", "transformers.cache_utils.SlidingWindowCache"),
134135
("py:class", "transformers.configuration_utils.PretrainedConfig"),
136+
("py:class", "transformers.modeling_outputs.BaseModelOutput"),
135137
("py:func", "torch.export._draft_export.draft_export"),
136138
("py:func", "torch._export.tools.report_exportability"),
137139
("py:func", "torch.utils._pytree.register_pytree_node"),
138140
("py:meth", "huggingface_hub.HfApi.list_models"),
139141
("py:meth", "transformers.AutoConfig.from_pretrained"),
140142
("py:meth", "transformers.GenerationMixin.generate"),
143+
("py:meth", "transformers.models.bart.modeling_bart.BartEncoderLayer.forward"),
141144
("py:meth", "unittests.TestCase.subTest"),
142145
]
143146

_doc/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ It also implements tools to investigate, validate exported models (ExportedProgr
3535
:maxdepth: 1
3636
:caption: Contents
3737

38+
patches
3839
api/index
3940
cmds/index
4041
auto_examples/index

_doc/patches.rst

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
.. _l-patches-explained:
2+
3+
=================
4+
Patches Explained
5+
=================
6+
7+
Function :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
8+
implements four kinds of patches to make it easier to export a model, usually
9+
coming from :epkg:`transformers`.
10+
All patches take place in :mod:`onnx_diagnostic.torch_export_patches`.
11+
.. code-block:: python
12+
13+
with torch_export_patches(...) as f:
14+
ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
15+
16+
1. **torch fixes**:
17+
it disables some exceptions or improves some functions related to dynamic shapes
18+
until :epkg:`torch` addresses the issues
19+
(see `mostly exporter issues
20+
<https://github.com/pytorch/pytorch/issues?q=is%3Aissue%20state%3Aopen%20author%3Axadupre>`_)
21+
2. **transformers rewriting**:
22+
some methods are replaced with a version :func:`torch.export.export` can understand,
23+
some rewriting may migrate to :epkg:`transformers`, others are applied only
24+
at export time because it would make the implementation less efficient
25+
3. **cache serialization**: :func:`torch.export.export` needs to know how to
26+
serialize custom classes such as :class:`transformers.cache_utils.DynamicCache`
27+
4. **control flow rewriting**: control flow (if, for) cannot be exported as is,
28+
there is still some work to be done to automatically process them,
29+
this package offers some automated rewriting, but it is far from being perfect.
30+
31+
All of them are triggered by :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`.
32+
33+
.. code-block:: bash
34+
35+
python -m onnx_diagnostic validate \
36+
-m hf-tiny-model-private/tiny-random-PLBartForConditionalGeneration \
37+
--run -v 1 --export onnx-dynamo -o dump_test --dtype float16 --device cuda
38+
39+
40+
All patches can be disabled with ``with torch_export_patches(patch=False)``.
41+
42+
torch fixes
43+
===========
44+
45+
Implemented in :mod:`onnx_diagnostic.torch_export_patches.patches.patch_torch` and triggered with
46+
``with torch_export_patches(patch_sympy=True, patch_torch=True, catch_constraints=True, stop_if_static=1...)``.
47+
48+
It fixes some issues found while exporting models. Some of these fixes might not be needed anymore.
49+
It improves shape broadcasting or inserts an exception every time a dynamic dimension
50+
becomes static (``stop_if_static=1``).
51+
52+
transformers rewriting
53+
======================
54+
55+
Implemented in :mod:`onnx_diagnostic.torch_export_patches.patches.patch_transformers` and triggered with
56+
``with torch_export_patches(patch_transformers=True)``.
57+
58+
Every patched class is prefixed with ``patched_``. It contains two class attributes.
59+
``_PATCHES_`` contains the list of methods to replace.
60+
``_PATCHED_CLASS_`` is the class patched by this one.
61+
62+
.. code-block:: python
63+
64+
class patched_AttentionMaskConverter:
65+
"""
66+
Patches
67+
``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``.
68+
"""
69+
70+
# This method was fixed in 4.51 at least.
71+
_PATCHES_ = ["_make_causal_mask"] if not has_transformers("4.48.3") else []
72+
_PATCHED_CLASS_ = AttentionMaskConverter
73+
74+
The package automatically parses this file to extract the patched methods.
75+
More can be added by populating the argument ``custom_patches``:
76+
``with torch_export_patches(custom_patches=[...])``.
77+
78+
Cache serialization
79+
===================
80+
81+
Implemented in :mod:`onnx_diagnostic.torch_export_patches.onnx_export_serialization`.
82+
Any custom class manipulated by a model needs to be registered through
83+
``torch.utils._pytree.register_pytree_node`` or with
84+
:func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.register_class_serialization`
85+
and triggered by ``with torch_export_patches(patch_transformers=True)``.
86+
This function registers a single class,
87+
:func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.register_cache_serialization`
88+
registers all known classes.
89+
It can be undone with :func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.unregister`
90+
or :func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.unregister_cache_serialization`.
91+
92+
.. _l-control-flow-rewriting:
93+
94+
Control flow rewriting
95+
======================
96+
97+
This is an attempt to automatically rewrite control flow using :mod:`ast`.
98+
It is implemented in :mod:`onnx_diagnostic.torch_export_patches.patch_module` and
99+
triggered with ``with torch_export_patches(rewrite=<instance of torch.nn.Module>)``.
100+
Option ``dump_rewriting=<folder>`` tells the function to dump all applied
101+
rewritings.
102+
103+
The following example contains the rewriting of method
104+
:meth:`transformers.models.bart.modeling_bart.BartEncoderLayer.forward`.
105+
The list of known rewritings to apply is returned by function
106+
:func:`onnx_diagnostic.torch_export_patches.patch_module_helper.code_needing_rewriting`
107+
and applied by function :func:`onnx_diagnostic.torch_export_patches.patch_module.transform_method`.
108+
109+
Type information is not available while parsing the code, even though it is known by
110+
:func:`torch.export.export`. Due to that, the automation usually needs manual tuning
111+
to filter out some tests (argument ``filter_node``) or pre/post processing
112+
(arguments ``pre_rewriter``, ``post_rewriter``) of function
113+
:func:`onnx_diagnostic.torch_export_patches.patch_module.transform_method`.
114+
115+
The main entry point is the context
116+
:func:`onnx_diagnostic.torch_export_patches.torch_export_rewrite`
117+
which rewrites and undoes the rewriting.
118+
For example, the model :class:`transformers.BartForConditionalGeneration`
119+
requires the following value for parameter ``rewrite``:
120+
121+
.. runpython::
122+
:showcode:
123+
124+
from onnx_diagnostic.torch_export_patches.patch_module_helper import (
125+
code_needing_rewriting,
126+
)
127+
128+
print(code_needing_rewriting("BartForConditionalGeneration"))
129+
130+
And that produces:
131+
132+
.. code-block:: diff
133+
134+
--- original
135+
+++ rewritten
136+
@@ -26,7 +26,6 @@
137+
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
138+
hidden_states = residual + hidden_states
139+
hidden_states = self.self_attn_layer_norm(hidden_states)
140+
-
141+
residual = hidden_states
142+
hidden_states = self.activation_fn(self.fc1(hidden_states))
143+
hidden_states = nn.functional.dropout(
144+
@@ -37,15 +36,22 @@
145+
hidden_states = residual + hidden_states
146+
hidden_states = self.final_layer_norm(hidden_states)
147+
148+
- if hidden_states.dtype == torch.float16 and (
149+
- torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
150+
- ):
151+
+ def branch_cond_then_1(hidden_states):
152+
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
153+
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
154+
+ return hidden_states.clone()
155+
156+
+ def branch_cond_else_1(hidden_states):
157+
+ return hidden_states.clone()
158+
+
159+
+ hidden_states = torch.cond(
160+
+ hidden_states.dtype == torch.float16
161+
+ and torch.isinf(hidden_states).any() | torch.isnan(hidden_states).any(),
162+
+ branch_cond_then_1,
163+
+ branch_cond_else_1,
164+
+ [hidden_states],
165+
+ )
166+
outputs = (hidden_states,)
167+
-
168+
if output_attentions:
169+
- outputs += (attn_weights,)
170+
-
171+
+ outputs = outputs + (attn_weights,)
172+
return outputs

_unittests/ut_torch_export_patches/test_patch_module.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
ShapeFinder,
1414
RewriteControlFlow,
1515
)
16+
from onnx_diagnostic.torch_export_patches.patch_module_helper import ast_or_into_bitor
1617

1718

1819
class _ModelForATest(torch.nn.Module):
@@ -396,15 +397,25 @@ def forward(self, x, y):
396397
def test_rewrite_test_in_PLBartEncoderLayer(self):
397398
from transformers.models.plbart.modeling_plbart import PLBartEncoderLayer
398399

399-
rewritten = transform_method(PLBartEncoderLayer.forward, verbose=self.verbose)
400+
def filter_node(node) -> bool:
401+
return isinstance(node, ast.If) and not isinstance(node.test, ast.Name)
402+
403+
rewritten = transform_method(
404+
PLBartEncoderLayer.forward,
405+
verbose=self.verbose,
406+
filter_node=filter_node,
407+
pre_rewriter=ast_or_into_bitor,
408+
)
400409
self.assertIn(
401410
(
402411
"torch.cond(hidden_states.dtype == torch.float16 and "
403-
"(torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()), "
412+
"torch.isinf(hidden_states).any()"
413+
" | torch.isnan(hidden_states).any(), "
404414
"branch_cond_then_1, branch_cond_else_1, [hidden_states])"
405415
),
406416
rewritten.code,
407417
)
418+
self.assertNotIn("torch.cond(output_attentions", rewritten.code)
408419

409420
@hide_stdout()
410421
def test_torch_export_patch_method_tuple(self):

0 commit comments

Comments
 (0)