Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,17 @@ jobs:
- name: run tests
run: |
pip install pytest
PYTHONPATH=. UNITTEST_GOING=1 pytest --durations=10 _unittests --ignore _unittests/ut_reference/test_backend_extended_reference_evaluator.py --ignore _unittests/ut_reference/test_backend_onnxruntime_evaluator.py --ignore _unittests/ut_torch_models/test_tiny_llms_bypassed.py --ignore _unittests/ut_tasks/test_tasks_zero_shot_image_classification.py --ignore _unittests/ut_tasks/test_tasks_image_classification.py
PYTHONPATH=. UNITTEST_GOING=1 pytest --durations=10 _unittests --ignore _unittests/ut_reference/test_backend_extended_reference_evaluator.py --ignore _unittests/ut_reference/test_backend_onnxruntime_evaluator.py --ignore _unittests/ut_torch_models/test_tiny_llms_bypassed.py --ignore _unittests/ut_tasks/test_tasks_zero_shot_image_classification.py --ignore _unittests/ut_tasks/test_tasks_image_classification.py --ignore _unittests/ut_torch_models/test_validate_whole_models*.py

- name: test models
run: |
echo "----"
PYTHONPATH=. UNITTEST_GOING=1 python _unittests/ut_torch_models/test_validate_whole_models1.py
echo "----"
PYTHONPATH=. UNITTEST_GOING=1 python _unittests/ut_torch_models/test_validate_whole_models2.py
echo "----"
PYTHONPATH=. UNITTEST_GOING=1 python _unittests/ut_torch_models/test_validate_whole_models3.py
echo "----"

# - name: run backend tests python
# run: PYTHONPATH=. UNITTEST_GOING=1 pytest --durations=10 _unittests/ut_reference/test_backend_extended_reference_evaluator.py
Expand Down
1 change: 1 addition & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Change Logs
0.8.0
+++++

* :pr:`282`: add tools to understand better which functions were patched
* :pr:`280`: fixes patches for sdpa_attention_forward for different versions of transformers
* :pr:`278`: implements ``onnx_generate_with_genai``
* :pr:`277`: changes the serialization for all caches to reorder the model outputs (key_1, value_1, key_2, ...)
Expand Down
1 change: 1 addition & 0 deletions _doc/api/torch_export_patches/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ onnx_diagnostic.torch_export_patches
onnx_export_errors
onnx_export_serialization
patches/index
patch_details
patch_expressions
patch_inputs
patch_module
Expand Down
6 changes: 6 additions & 0 deletions _doc/api/torch_export_patches/patch_details.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@

onnx_diagnostic.torch_export_patches.patch_details
==================================================

.. automodule:: onnx_diagnostic.torch_export_patches.patch_details
:members:
1 change: 1 addition & 0 deletions _doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ def linkcode_resolve(domain, info):
epkg_dictionary = {
"aten functions": "https://pytorch.org/cppdocs/api/namespace_at.html#functions",
"azure pipeline": "https://azure.microsoft.com/en-us/products/devops/pipelines",
"black": "https://github.com/psf/black",
"Custom Backends": "https://docs.pytorch.org/docs/stable/torch.compiler_custom_backends.html",
"diffusers": "https://github.com/huggingface/diffusers",
"DOT": "https://graphviz.org/doc/info/lang.html",
Expand Down
1 change: 1 addition & 0 deletions _doc/status/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ what works and what does not with :func:`torch.export.export`.
exported_program_dynamic
exporter_dynamic
patches_coverage
patches_diff

Examples checking about dynamic dimensions:

Expand Down
72 changes: 72 additions & 0 deletions _doc/status/patches_diff.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
.. _l-patch-diff:

============
Patches Diff
============

Patches are not always needed to export an LLM.
Most of the time, only serialization functions are needed to export
an LLM with cache (``DynamicCache``, ...).
Function :func:`register_additional_serialization_functions
<onnx_diagnostic.torch_export_patches.register_additional_serialization_functions>`
is enough in many cases.

.. code-block:: python

import torch
from onnx_diagnostic.torch_export_patches import register_additional_serialization_functions

with register_additional_serialization_functions(patch_transformers=True):
ep = torch.export.export(...)

Function :func:`torch_export_patches
<onnx_diagnostic.torch_export_patches.torch_export_patches>`
helps fix some issues for many models.

.. code-block:: python

import torch
from onnx_diagnostic.torch_export_patches import torch_export_patches

with torch_export_patches(patch_transformers=True):
ep = torch.export.export(...)

Class :class:`PatchDetails <onnx_diagnostic.torch_export_patches.patch_details.PatchDetails>`
gives an example of how to retrieve the list of patches involved for a specific model.
Those patches belong to the following list, which depends on the transformers and
pytorch versions.

.. runpython::
:showcode:

import torch
import transformers

print(torch.__version__, transformers.__version__)

Those two versions lead to the following list of patches.

.. runpython::
:showcode:
:rst:

from onnx_diagnostic.torch_export_patches.patch_details import PatchDetails
from onnx_diagnostic.torch_export_patches import torch_export_patches

details = PatchDetails()
with torch_export_patches(
patch_transformers=True,
patch_torch=True,
patch_diffusers=True,
patch_details=details,
):
pass
for patch in details.patched:
if patch.function_to_patch == patch.patch:
continue
rst = patch.format_diff(format="rst")
print()
print()
print(rst)
print()
print()
81 changes: 81 additions & 0 deletions _unittests/ut_torch_export_patches/test_patch_details.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import unittest
import torch
import transformers
from onnx_diagnostic.ext_test_case import (
ExtTestCase,
requires_transformers,
hide_stdout,
has_transformers,
)
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
from onnx_diagnostic.torch_export_patches.patch_details import PatchDetails, PatchInfo
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs


class TestPatchDetails(ExtTestCase):
    """Unit tests exercising ``PatchDetails`` and ``PatchInfo`` together with
    ``torch_export_patches``.

    NOTE(review): this file was recovered from a flattened diff view, so the
    indentation below is reconstructed; confirm the nesting against the
    repository before relying on it.
    """

    @hide_stdout()
    def test_patch_details(self):
        """Enter ``torch_export_patches`` with all patch families enabled,
        collect the applied patches into a ``PatchDetails`` instance and check
        the diffs produced by ``format_diff`` in both ``raw`` and ``rst``
        formats.
        """
        details = PatchDetails()
        with torch_export_patches(
            patch_transformers=True,
            verbose=10,
            patch_torch=True,
            patch_diffusers=True,
            patch_details=details,
        ):
            pass
        self.assertGreater(details.n_patches, 1)
        data = details.data()
        # data() is expected to expose one entry per recorded patch
        self.assertEqual(len(data), details.n_patches)
        for patch in details.patched:
            _kind, f1, f2 = patch.family, patch.function_to_patch, patch.patch
            raw = patch.format_diff(format="raw")
            # NOTE(review): presumably only the check on f1 is guarded by
            # callable(f1) -- confirm the original nesting
            if callable(f1):
                self.assertIn(f1.__name__, raw)
            self.assertIn(f2.__name__, raw)
            rst = patch.format_diff(format="rst")
            self.assertIn("====", rst)

        # second time to make sure every patch was removed
        with torch_export_patches(
            patch_transformers=True,
            verbose=10,
            patch_torch=True,
            patch_diffusers=True,
            patch_details=details,
        ):
            pass

    @requires_transformers("4.55")
    def test_patch_diff(self):
        """Build a ``PatchInfo`` for ``eager_mask`` and its patched version and
        check that the ``rst`` diff contains a line added with a
        ``# PATCHED:`` marker.
        """
        from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
            patched_eager_mask,
        )

        eager_mask = transformers.masking_utils.eager_mask
        self.assertEqual(eager_mask.__name__, "eager_mask")
        self.assertEqual(patched_eager_mask.__name__, "patched_eager_mask")
        diff = PatchInfo(eager_mask, patched_eager_mask).format_diff(format="rst")
        self.assertIn("+ # PATCHED:", diff)

    def test_involved_patches(self):
        """Export the untrained ``arnir0/Tiny-LLM`` model under
        ``torch_export_patches`` and check that ``PatchDetails`` reports a
        non-empty set of patches for the exported graph.
        """
        data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", verbose=0)
        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
        details = PatchDetails()
        with torch_export_patches(
            patch_transformers=True, patch_details=details, patch_torch=False
        ):
            ep = torch.export.export(
                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
            )
        # NOTE(review): assumed to run after leaving the context manager --
        # confirm the original nesting
        patches = details.patches_involded_in_graph(ep.graph)
        self.assertNotEmpty(patches)
        report = details.make_report(patches, format="rst")
        # NOTE(review): both report checks are assumed to be guarded by the
        # transformers version test -- confirm the original nesting
        if has_transformers("4.51"):
            self.assertIn("====", report)
            self.assertIn("def longrope_frequency_update", report)


# Run the tests with verbose output when this file is executed as a script.
if __name__ == "__main__":
    unittest.main(verbosity=2)
2 changes: 1 addition & 1 deletion _unittests/ut_torch_models/test_validate_whole_models3.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from onnx_diagnostic.torch_models.validate import validate_model


class TestValidateWholeModels1(ExtTestCase):
class TestValidateWholeModels3(ExtTestCase):
@requires_torch("2.7")
@hide_stdout()
@ignore_warnings(FutureWarning)
Expand Down
Loading
Loading