Skip to content

Commit 977654c

Browse files
committed
fix test
1 parent 319ec3a commit 977654c

File tree

4 files changed

+96
-4
lines changed

4 files changed

+96
-4
lines changed

_doc/status/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ what works and what does not with :func:`torch.export.export`.
1111
exported_program_dynamic
1212
exporter_dynamic
1313
patches_coverage
14+
patches_diff
1415

1516
Examples checking about dynamic dimensions:
1617

_doc/status/patches_diff.rst

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
.. _l-patch-diff:
2+
3+
============
4+
Patches Diff
5+
============
6+
7+
Patches are not always needed to export a LLM.
8+
Most of the time, only serialization function are needed to export
9+
a LLM with cache (``DynamicCache``, ...).
10+
Function :func:`register_additional_serialization_functions
11+
<onnx_diagnostic.torch_export_patches.register_additional_serialization_functions>`
12+
is enough in many cases.
13+
14+
.. code-block:: python
15+
16+
import torch
17+
from onnx_diagnostic.torch_export_patches import register_additional_serialization_functions
18+
19+
with register_additional_serialization_functions(patch_transformers=True):
20+
ep = torch.export.export(...)
21+
22+
Function :func:`torch_export_patches
23+
<onnx_diagnostic.torch_export_patches.torch_export_patches>`
24+
helps fixing some issues for many models.
25+
26+
.. code-block:: python
27+
28+
import torch
29+
from onnx_diagnostic.torch_export_patches import torch_export_patches
30+
31+
with torch_export_patches(patch_transformers=True):
32+
ep = torch.export.export(...)
33+
34+
Class :class:`PatchDetails <onnx_diagnostic.torch_export_patches.patch_details.PatchDetails>`
35+
gives an example on how to retrieve the list of involded patches for a specific model.
36+
Those patches belongs to the following list which depends on transformers and
37+
pytorch versions.
38+
39+
.. runpython::
40+
:showcode:
41+
42+
import torch
43+
import transformers
44+
45+
print(torch.__version__, transformers.__version__)
46+
47+
Those two versions leads to the following list of patches.
48+
49+
.. runpython::
50+
:showcode:
51+
:rst:
52+
53+
from onnx_diagnostic.torch_export_patches.patch_details import PatchDetails
54+
from onnx_diagnostic.torch_export_patches import torch_export_patches
55+
56+
details = PatchDetails()
57+
with torch_export_patches(
58+
patch_transformers=True,
59+
patch_torch=True,
60+
patch_diffusers=True,
61+
patch_details=details,
62+
):
63+
pass
64+
for patch in details.patched:
65+
if patch.function_to_patch == patch.patch:
66+
continue
67+
rst = patch.format_diff(format="rst")
68+
print()
69+
print()
70+
print(rst)
71+
print()
72+
print()

_unittests/ut_torch_export_patches/test_patch_details.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from onnx_diagnostic.torch_export_patches import torch_export_patches
66
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
77
from onnx_diagnostic.torch_export_patches.patch_details import PatchDetails, PatchInfo
8-
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import patched_eager_mask
98
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
109

1110

@@ -35,6 +34,10 @@ def test_patch_details(self):
3534

3635
@requires_transformers("4.55")
3736
def test_patch_diff(self):
37+
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
38+
patched_eager_mask,
39+
)
40+
3841
eager_mask = transformers.masking_utils.eager_mask
3942
self.assertEqual(eager_mask.__name__, "eager_mask")
4043
self.assertEqual(patched_eager_mask.__name__, "patched_eager_mask")

onnx_diagnostic/torch_export_patches/patch_details.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import difflib
22
import inspect
3+
import pprint
34
import re
45
import textwrap
56
from typing import Any, Dict, Callable, List, Optional, Tuple, Union
@@ -111,9 +112,13 @@ def format_diff(self, format: str = "raw") -> str:
111112
112113
import transformers
113114
import onnx_diagnostic.torch_export_patches.patches.patch_transformers as ptr
114-
from onnx_diagnostic.torch_export_patches.patch_details import Patchinfo
115+
from onnx_diagnostic.torch_export_patches.patch_details import PatchInfo
116+
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
117+
patched_eager_mask,
118+
)
115119
116-
diff = Patchinfo(eager_mask, patched_eager_mask).format_diff(format="rst")
120+
eager_mask = transformers.masking_utils.eager_mask
121+
diff = PatchInfo(eager_mask, patched_eager_mask).format_diff(format="rst")
117122
print(diff)
118123
"""
119124
diff = self.make_diff()
@@ -146,7 +151,8 @@ class PatchDetails:
146151
"""
147152
This class is used to store patching information.
148153
This helps understanding which rewriting was applied to which
149-
method of functions.
154+
method of functions. Page :ref:`l-patch-diff` contains all the
155+
diff for all the implemented patches.
150156
151157
.. runpython::
152158
:showcode:
@@ -227,6 +233,7 @@ def patches_involded_in_graph(
227233
node_stack.append((node, stack))
228234

229235
patch_node = []
236+
patched_nodes = set()
230237
for patch, _f, source, interval in patches:
231238
exp = 'File "([^"]*?%s[^"]+?)", line (\\d+)' % cst
232239
reg = re.compile(exp)
@@ -246,6 +253,15 @@ def patches_involded_in_graph(
246253
and self.matching_pair(patch, node)
247254
):
248255
patch_node.append((patch, node))
256+
patched_nodes.add(id(node))
257+
258+
# checks all patches were discovered
259+
for node, _ in node_stack:
260+
assert id(node) in patched_nodes, (
261+
f"One node was patched but no patch was found:\n"
262+
f"node: {node.target}({','.join(map(str, node.args))}) -> {node.name}"
263+
f"\n--\n{pprint.pformat(node.meta)}"
264+
)
249265

250266
res = {}
251267
for patch, node in patch_node:

0 commit comments

Comments
 (0)