Skip to content

Commit 200869a

Browse files
authored
Stores patch details into PatchDetails (#282)
* add patch details * patches * patch * better doc * fix test * mypy * fix * refactor patches * fix * mypy * fix * fix sdpa * update CI * fix * tweak ci
1 parent ab5832b commit 200869a

File tree

13 files changed

+1249
-478
lines changed

13 files changed

+1249
-478
lines changed

.github/workflows/ci.yml

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,17 @@ jobs:
149149
- name: run tests
150150
run: |
151151
pip install pytest
152-
PYTHONPATH=. UNITTEST_GOING=1 pytest --durations=10 _unittests --ignore _unittests/ut_reference/test_backend_extended_reference_evaluator.py --ignore _unittests/ut_reference/test_backend_onnxruntime_evaluator.py --ignore _unittests/ut_torch_models/test_tiny_llms_bypassed.py --ignore _unittests/ut_tasks/test_tasks_zero_shot_image_classification.py --ignore _unittests/ut_tasks/test_tasks_image_classification.py
152+
PYTHONPATH=. UNITTEST_GOING=1 pytest --durations=10 _unittests --ignore _unittests/ut_reference/test_backend_extended_reference_evaluator.py --ignore _unittests/ut_reference/test_backend_onnxruntime_evaluator.py --ignore _unittests/ut_torch_models/test_tiny_llms_bypassed.py --ignore _unittests/ut_tasks/test_tasks_zero_shot_image_classification.py --ignore _unittests/ut_tasks/test_tasks_image_classification.py --ignore _unittests/ut_torch_models/test_validate_whole_models*.py
153+
154+
- name: test models
155+
run: |
156+
echo "----"
157+
PYTHONPATH=. UNITTEST_GOING=1 python _unittests/ut_torch_models/test_validate_whole_models1.py
158+
echo "----"
159+
PYTHONPATH=. UNITTEST_GOING=1 python _unittests/ut_torch_models/test_validate_whole_models2.py
160+
echo "----"
161+
PYTHONPATH=. UNITTEST_GOING=1 python _unittests/ut_torch_models/test_validate_whole_models3.py
162+
echo "----"
153163
154164
# - name: run backend tests python
155165
# run: PYTHONPATH=. UNITTEST_GOING=1 pytest --durations=10 _unittests/ut_reference/test_backend_extended_reference_evaluator.py

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.8.0
55
+++++
66

7+
* :pr:`282`: add tools to understand better which functions were patched
78
* :pr:`280`: fixes patches for sdpa_attention_forward for different version of transformers
89
* :pr:`278`: implements ``onnx_generate_with_genai``
910
* :pr:`277`: changes the serialization for all caches to reorder the model outputs (key_1, value_1, key_2, ...)

_doc/api/torch_export_patches/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ onnx_diagnostic.torch_export_patches
99
onnx_export_errors
1010
onnx_export_serialization
1111
patches/index
12+
patch_details
1213
patch_expressions
1314
patch_inputs
1415
patch_module
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
2+
onnx_diagnostic.torch_export_patches.patch_details
3+
==================================================
4+
5+
.. automodule:: onnx_diagnostic.torch_export_patches.patch_details
6+
:members:

_doc/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ def linkcode_resolve(domain, info):
214214
epkg_dictionary = {
215215
"aten functions": "https://pytorch.org/cppdocs/api/namespace_at.html#functions",
216216
"azure pipeline": "https://azure.microsoft.com/en-us/products/devops/pipelines",
217+
"black": "https://github.com/psf/black",
217218
"Custom Backends": "https://docs.pytorch.org/docs/stable/torch.compiler_custom_backends.html",
218219
"diffusers": "https://github.com/huggingface/diffusers",
219220
"DOT": "https://graphviz.org/doc/info/lang.html",

_doc/status/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ what works and what does not with :func:`torch.export.export`.
1111
exported_program_dynamic
1212
exporter_dynamic
1313
patches_coverage
14+
patches_diff
1415

1516
Examples checking about dynamic dimensions:
1617

_doc/status/patches_diff.rst

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
.. _l-patch-diff:
2+
3+
============
4+
Patches Diff
5+
============
6+
7+
Patches are not always needed to export a LLM.
8+
Most of the time, only serialization function are needed to export
9+
a LLM with cache (``DynamicCache``, ...).
10+
Function :func:`register_additional_serialization_functions
11+
<onnx_diagnostic.torch_export_patches.register_additional_serialization_functions>`
12+
is enough in many cases.
13+
14+
.. code-block:: python
15+
16+
import torch
17+
from onnx_diagnostic.torch_export_patches import register_additional_serialization_functions
18+
19+
with register_additional_serialization_functions(patch_transformers=True):
20+
ep = torch.export.export(...)
21+
22+
Function :func:`torch_export_patches
23+
<onnx_diagnostic.torch_export_patches.torch_export_patches>`
24+
helps fixing some issues for many models.
25+
26+
.. code-block:: python
27+
28+
import torch
29+
from onnx_diagnostic.torch_export_patches import torch_export_patches
30+
31+
with torch_export_patches(patch_transformers=True):
32+
ep = torch.export.export(...)
33+
34+
Class :class:`PatchDetails <onnx_diagnostic.torch_export_patches.patch_details.PatchDetails>`
35+
gives an example on how to retrieve the list of involded patches for a specific model.
36+
Those patches belongs to the following list which depends on transformers and
37+
pytorch versions.
38+
39+
.. runpython::
40+
:showcode:
41+
42+
import torch
43+
import transformers
44+
45+
print(torch.__version__, transformers.__version__)
46+
47+
Those two versions leads to the following list of patches.
48+
49+
.. runpython::
50+
:showcode:
51+
:rst:
52+
53+
from onnx_diagnostic.torch_export_patches.patch_details import PatchDetails
54+
from onnx_diagnostic.torch_export_patches import torch_export_patches
55+
56+
details = PatchDetails()
57+
with torch_export_patches(
58+
patch_transformers=True,
59+
patch_torch=True,
60+
patch_diffusers=True,
61+
patch_details=details,
62+
):
63+
pass
64+
for patch in details.patched:
65+
if patch.function_to_patch == patch.patch:
66+
continue
67+
rst = patch.format_diff(format="rst")
68+
print()
69+
print()
70+
print(rst)
71+
print()
72+
print()
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import unittest
2+
import torch
3+
import transformers
4+
from onnx_diagnostic.ext_test_case import (
5+
ExtTestCase,
6+
requires_transformers,
7+
hide_stdout,
8+
has_transformers,
9+
)
10+
from onnx_diagnostic.torch_export_patches import torch_export_patches
11+
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
12+
from onnx_diagnostic.torch_export_patches.patch_details import PatchDetails, PatchInfo
13+
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
14+
15+
16+
class TestPatchDetails(ExtTestCase):
17+
@hide_stdout()
18+
def test_patch_details(self):
19+
details = PatchDetails()
20+
with torch_export_patches(
21+
patch_transformers=True,
22+
verbose=10,
23+
patch_torch=True,
24+
patch_diffusers=True,
25+
patch_details=details,
26+
):
27+
pass
28+
self.assertGreater(details.n_patches, 1)
29+
data = details.data()
30+
self.assertEqual(len(data), details.n_patches)
31+
for patch in details.patched:
32+
_kind, f1, f2 = patch.family, patch.function_to_patch, patch.patch
33+
raw = patch.format_diff(format="raw")
34+
if callable(f1):
35+
self.assertIn(f1.__name__, raw)
36+
self.assertIn(f2.__name__, raw)
37+
rst = patch.format_diff(format="rst")
38+
self.assertIn("====", rst)
39+
40+
# second time to make every patch was removed
41+
with torch_export_patches(
42+
patch_transformers=True,
43+
verbose=10,
44+
patch_torch=True,
45+
patch_diffusers=True,
46+
patch_details=details,
47+
):
48+
pass
49+
50+
@requires_transformers("4.55")
51+
def test_patch_diff(self):
52+
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
53+
patched_eager_mask,
54+
)
55+
56+
eager_mask = transformers.masking_utils.eager_mask
57+
self.assertEqual(eager_mask.__name__, "eager_mask")
58+
self.assertEqual(patched_eager_mask.__name__, "patched_eager_mask")
59+
diff = PatchInfo(eager_mask, patched_eager_mask).format_diff(format="rst")
60+
self.assertIn("+ # PATCHED:", diff)
61+
62+
def test_involved_patches(self):
63+
data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", verbose=0)
64+
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
65+
details = PatchDetails()
66+
with torch_export_patches(
67+
patch_transformers=True, patch_details=details, patch_torch=False
68+
):
69+
ep = torch.export.export(
70+
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
71+
)
72+
patches = details.patches_involded_in_graph(ep.graph)
73+
self.assertNotEmpty(patches)
74+
report = details.make_report(patches, format="rst")
75+
if has_transformers("4.51"):
76+
self.assertIn("====", report)
77+
self.assertIn("def longrope_frequency_update", report)
78+
79+
80+
if __name__ == "__main__":
81+
unittest.main(verbosity=2)

_unittests/ut_torch_models/test_validate_whole_models3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from onnx_diagnostic.torch_models.validate import validate_model
1010

1111

12-
class TestValidateWholeModels1(ExtTestCase):
12+
class TestValidateWholeModels3(ExtTestCase):
1313
@requires_torch("2.7")
1414
@hide_stdout()
1515
@ignore_warnings(FutureWarning)

0 commit comments

Comments
 (0)