3 changes: 3 additions & 0 deletions CHANGELOGS.rst
@@ -4,6 +4,9 @@ Change Logs
0.4.3
+++++

* :pr:`75`: renames bypass_export_some_errors into torch_export_patches, keeps the old name (see the usage sketch below)
* :pr:`74`: increases the list of class/architectures

0.4.2
+++++

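A minimal usage sketch of the renamed entry point (not part of this diff; ``model``, ``args``, ``kwargs`` and ``dynamic_shapes`` are placeholders, and the old name is assumed to remain importable as an alias, as the changelog entry states):

.. code-block:: python

    import torch
    from onnx_diagnostic.torch_export_patches import torch_export_patches

    # new name introduced by this PR
    with torch_export_patches(patch_transformers=True):
        ep = torch.export.export(
            model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes
        )

    # the old name is kept, so this import is assumed to keep working:
    # from onnx_diagnostic.torch_export_patches import bypass_export_some_errors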
4 changes: 2 additions & 2 deletions README.rst
@@ -30,13 +30,13 @@ it helps exporting **pytorch models into ONNX**, mostly designed for LLMs using

.. code-block:: python

with bypass_export_some_errors(patch_transformers=True) as f:
with torch_export_patches(patch_transformers=True) as f:
ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
# ...

It also implements tools to investigate and validate exported models (ExportedProgram, ONNXProgram, ...).
See `documentation of onnx-diagnostic <https://sdpython.github.io/doc/onnx-diagnostic/dev/>`_ and
`bypass_export_some_errors <https://sdpython.github.io/doc/onnx-diagnostic/dev/api/torch_export_patches/index.html#onnx_diagnostic.torch_export_patches.bypass_export_some_errors>`_.
`torch_export_patches <https://sdpython.github.io/doc/onnx-diagnostic/dev/api/torch_export_patches/index.html#onnx_diagnostic.torch_export_patches.torch_export_patches>`_.

Getting started
+++++++++++++++
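The validation tools mentioned in the README are not exercised in this part of the diff. A hedged sketch of how an exported program can be compared against the eager model, assuming ``max_diff`` and ``string_diff`` from ``onnx_diagnostic.helpers`` (imported in the examples changed below) respectively compare and format nested outputs; ``model``, ``inputs`` and ``dynamic_shapes`` are placeholders:

.. code-block:: python

    import copy
    import torch
    from onnx_diagnostic.helpers import max_diff, string_diff
    from onnx_diagnostic.torch_export_patches import torch_export_patches

    with torch_export_patches(patch_transformers=True):
        ep = torch.export.export(
            model, (), kwargs=copy.deepcopy(inputs),
            dynamic_shapes=dynamic_shapes, strict=False
        )

    expected = model(**copy.deepcopy(inputs))
    got = ep.module()(**copy.deepcopy(inputs))
    # assumed API: max_diff returns a structure that string_diff renders as text
    print(string_diff(max_diff(expected, got)))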
2 changes: 1 addition & 1 deletion _doc/api/torch_export_patches/index.rst
@@ -13,6 +13,6 @@ onnx_diagnostic.torch_export_patches
:members:
:no-undoc-members:

.. autofunction:: onnx_diagnostic.torch_export_patches.bypass_export_some_errors
.. autofunction:: onnx_diagnostic.torch_export_patches.torch_export_patches

.. autofunction:: onnx_diagnostic.torch_export_patches.register_additional_serialization_functions
6 changes: 3 additions & 3 deletions _doc/examples/plot_export_hub_codellama.py
@@ -29,7 +29,7 @@
get_pretrained_config,
task_from_id,
)
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str

model_id = "codellama/CodeLlama-7b-Python-hf"
@@ -90,9 +90,9 @@
#
# The model uses :class:`transformers.cache_utils.DynamicCache`.
# It still requires patches to be exportable (control flow).
# See :func:`onnx_diagnostic.torch_export_patches.bypass_export_some_errors`
# See :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`

with bypass_export_some_errors(patch_transformers=True) as f:
with torch_export_patches(patch_transformers=True) as f:
ep = torch.export.export(
model,
(),
8 changes: 4 additions & 4 deletions _doc/examples/plot_export_locate_issue.py
@@ -26,7 +26,7 @@
import traceback
import torch
from onnx_diagnostic import doc
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
from onnx_diagnostic.torch_export_patches import torch_export_patches


class ModelWithIssue(torch.nn.Module):
@@ -80,12 +80,12 @@ def forward(self, x: torch.Tensor, ys: list[torch.Tensor]):
# Stop when a dynamic dimension turns static
# ==========================================
#
# We use :func:`bypass_export_some_errors
# <onnx_diagnostic.torch_export_patches.bypass_export_some_errors>`
# We use :func:`torch_export_patches
# <onnx_diagnostic.torch_export_patches.torch_export_patches>`
# to replace the torch implementation with a new one raising the exception
# mentioned in the previous section.

with bypass_export_some_errors(stop_if_static=1, verbose=1):
with torch_export_patches(stop_if_static=1, verbose=1):
try:
torch.export.export(model, inputs, dynamic_shapes=dyn_shapes)
except (AssertionError, torch._dynamo.exc.TorchRuntimeError) as e:
8 changes: 4 additions & 4 deletions _doc/examples/plot_export_tiny_llm_patched.py
@@ -69,7 +69,7 @@
from onnx_diagnostic import doc
from onnx_diagnostic.helpers.cache_helper import is_cache_dynamic_registered
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_models.llms import get_tiny_llm


@@ -101,10 +101,10 @@

# %%
# If they are not registered, function
# func:`onnx_diagnostic.torch_export_patches.bypass_export_some_errors`
# :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
# should take care of it. Then we export.

with bypass_export_some_errors(patch_transformers=True, verbose=10) as modificator:
with torch_export_patches(patch_transformers=True, verbose=10) as modificator:
assert is_cache_dynamic_registered() # it must be true here
ep = torch.export.export(
untrained_model,
@@ -126,7 +126,7 @@

cloned_inputs = copy.deepcopy(inputs)

with bypass_export_some_errors(patch_transformers=True, verbose=10) as modificator:
with torch_export_patches(patch_transformers=True, verbose=10) as modificator:
ep = torch.export.export(
model,
(),
6 changes: 3 additions & 3 deletions _doc/examples/plot_export_tiny_phi2.py
@@ -26,7 +26,7 @@
from onnx_diagnostic.helpers import max_diff, string_diff, string_type
from onnx_diagnostic.helpers.cache_helper import is_cache_dynamic_registered
from onnx_diagnostic.helpers.rt_helper import make_feeds
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
from onnx_diagnostic.torch_models.hghub import (
get_untrained_model_with_inputs,
@@ -76,7 +76,7 @@
# ++++++


with bypass_export_some_errors(patch_transformers=True) as modificator:
with torch_export_patches(patch_transformers=True) as modificator:

# Unnecessary steps but useful in case of an error
# We check the cache is registered.
@@ -110,7 +110,7 @@
# applies :meth:`torch.export.ExportedProgram.run_decompositions`
# may export local pieces of the model again.

with bypass_export_some_errors(patch_transformers=True):
with torch_export_patches(patch_transformers=True):
epo = torch.onnx.export(
ep, (), kwargs=copy.deepcopy(inputs), dynamic_shapes=dynamic_shapes, dynamo=True
)
6 changes: 3 additions & 3 deletions _doc/examples/plot_export_with_dynamic_cache.py
@@ -30,7 +30,7 @@
make_dynamic_cache,
)
from onnx_diagnostic.export import ModelInputs
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
from onnx_diagnostic.torch_export_patches import torch_export_patches


class Model(torch.nn.Module):
@@ -99,14 +99,14 @@ def forward(self, cache, z):
# And finally the export.
# The export is simple if ``transformers>=4.50``, otherwise,
# transformers needs to be patched.
# :func:`onnx_diagnostic.torch_export_patches.bypass_export_some_errors`
# :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
# registers functions to serialize ``DynamicCache``. This one is modified to make
# the shape inference implemented in :epkg:`torch` happy.

if has_transformers("4.50"):
ep = torch.export.export(model, inputs[0], dynamic_shapes=ds[0], strict=False)
else:
with bypass_export_some_errors(patch_transformers=True) as modificator:
with torch_export_patches(patch_transformers=True) as modificator:
ep = torch.export.export(
model, modificator(inputs[0]), dynamic_shapes=ds[0], strict=False
)
4 changes: 2 additions & 2 deletions _doc/index.rst
@@ -24,12 +24,12 @@ Sources available at `github/onnx-diagnostic <https://github.com/sdpython/onnx-d

.. code-block:: python

with bypass_export_some_errors(patch_transformers=True) as f:
with torch_export_patches(patch_transformers=True) as f:
ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
# ...

It also implements tools to investigate and validate exported models (ExportedProgram, ONNXProgram, ...).
:func:`onnx_diagnostic.torch_export_patches.bypass_export_some_errors`.
:func:`onnx_diagnostic.torch_export_patches.torch_export_patches`.

.. toctree::
:maxdepth: 1
6 changes: 3 additions & 3 deletions _doc/recipes/plot_dynamic_shapes_python_int.py
@@ -14,7 +14,7 @@
import math
import torch
from onnx_diagnostic import doc
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
from onnx_diagnostic.torch_export_patches import torch_export_patches


class Model(torch.nn.Module):
@@ -73,12 +73,12 @@ def forward(self, x):
# Find the error
# ++++++++++++++
#
# Function :func:`onnx_diagnostic.torch_export_patches.bypass_export_some_errors`
# Function :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
# has a parameter ``stop_if_static`` which patches torch to raise an exception
# when something like that happens.


with bypass_export_some_errors(stop_if_static=True):
with torch_export_patches(stop_if_static=True):
ep = torch.export.export(model, (x,), dynamic_shapes=({0: DYN, 1: DYN},))
print(ep)

6 changes: 3 additions & 3 deletions _unittests/ut_export/test_dynamic_shapes.py
@@ -5,7 +5,7 @@
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.export import ModelInputs, CoupleInputsDynamicShapes
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
from onnx_diagnostic.torch_export_patches import torch_export_patches


class TestDynamicShapes(ExtTestCase):
@@ -571,7 +571,7 @@ def test_couple_input_ds_cache(self):

kwargs = {"A": T3x4, "B": (T3x1, cache)}
Cls = CoupleInputsDynamicShapes
with bypass_export_some_errors(patch_transformers=True):
with torch_export_patches(patch_transformers=True):
self.assertEqual(
None,
Cls(
@@ -749,7 +749,7 @@ def test_couple_input_ds_change_dynamic_dimensions_dynamic_cache(self):
{"A": make_dynamic_cache([(torch.ones((2, 2, 2, 2)), torch.ones((2, 2, 2, 2)))])},
{"A": [[{0: "batch", 2: "last"}], [{0: "batch", 2: "last"}]]},
)
with bypass_export_some_errors(patch_transformers=True):
with torch_export_patches(patch_transformers=True):
new_inputs = inst.change_dynamic_dimensions()
self.assertIsInstance(new_inputs["A"], transformers.cache_utils.DynamicCache)
self.assertEqual((3, 2, 3, 2), new_inputs["A"].key_cache[0].shape)
4 changes: 2 additions & 2 deletions _unittests/ut_export/test_serialization.py
@@ -7,7 +7,7 @@
flatten_unflatten_for_dynamic_shapes,
)
from onnx_diagnostic.export import ModelInputs
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
from onnx_diagnostic.torch_export_patches import torch_export_patches


class TestSerialization(ExtTestCase):
@@ -49,7 +49,7 @@ def forward(self, cache):
return cache.key_cache[0]

cache = self._get_cache()
with bypass_export_some_errors(patch_transformers=True):
with torch_export_patches(patch_transformers=True):
flat_unflat = flatten_unflatten_for_dynamic_shapes(cache)
s = string_type(flat_unflat, with_shape=True)
self.assertEqual("#2[#2[T1s2x4x1x7,T1s2x4x1x7],#2[T1s2x4x1x7,T1s2x4x1x7]]", s)
8 changes: 4 additions & 4 deletions _unittests/ut_helpers/test_cache_helper.py
@@ -14,7 +14,7 @@
from onnx_diagnostic.torch_export_patches.patch_inputs import (
convert_dynamic_axes_into_dynamic_shapes,
)
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
from onnx_diagnostic.torch_export_patches import torch_export_patches


class TestCacheHelpers(ExtTestCase):
@@ -67,14 +67,14 @@ def test_replace_by(self):
)
self.assertEqual(dynamic_shapes, nds)

with bypass_export_some_errors(patch_transformers=True):
with torch_export_patches(patch_transformers=True):
cpl = CoupleInputsDynamicShapes(tuple(), kwargs, dynamic_shapes)
res = cpl.replace_string_by()
dsc = res["past_key_values"]
self.assertEqual([[{0: batch, 2: DYN}], [{0: batch, 2: DYN}]], dsc)

def test_unflatten_flatten_dynamic_cache(self):
with bypass_export_some_errors(patch_transformers=True):
with torch_export_patches(patch_transformers=True):
c1 = make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))])
self.assertIsInstance(c1, transformers.cache_utils.DynamicCache)
unflat = flatten_unflatten_for_dynamic_shapes(c1)
@@ -87,7 +87,7 @@ def test_unflatten_flatten_dynamic_cache(self):
)

def test_unflatten_flatten_encoder_decoder_cache(self):
with bypass_export_some_errors(patch_transformers=True):
with torch_export_patches(patch_transformers=True):
c2 = make_encoder_decoder_cache(
make_dynamic_cache(
[
4 changes: 2 additions & 2 deletions _unittests/ut_helpers/test_ort_session_tinyllm.py
@@ -12,7 +12,7 @@
InferenceSessionForNumpy,
InferenceSessionForTorch,
)
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_models.llms import get_tiny_llm
from onnx_diagnostic.reference import ExtendedReferenceEvaluator
from onnx_diagnostic.helpers.onnx_helper import np_dtype_to_tensor_dtype
@@ -77,7 +77,7 @@ def test_check_allruntimes_on_tiny_llm(self):
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
expected = model(**copy.deepcopy(inputs))

with bypass_export_some_errors(patch_transformers=True):
with torch_export_patches(patch_transformers=True):
ep = torch.onnx.export(
model, (), kwargs=copy.deepcopy(inputs), dynamic_shapes=ds, dynamo=True
)
20 changes: 10 additions & 10 deletions _unittests/ut_tasks/test_tasks.py
@@ -2,7 +2,7 @@
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, has_transformers
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


@@ -17,7 +17,7 @@ def test_text2text_generation(self):
raise unittest.SkipTest(f"not working for {mid!r}")
model(**inputs)
model(**data["inputs2"])
with bypass_export_some_errors(patch_transformers=True, verbose=10):
with torch_export_patches(patch_transformers=True, verbose=10):
torch.export.export(
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
)
@@ -31,7 +31,7 @@ def test_text_generation(self):
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
model(**inputs)
model(**data["inputs2"])
with bypass_export_some_errors(patch_transformers=True, verbose=10):
with torch_export_patches(patch_transformers=True, verbose=10):
torch.export.export(
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
)
@@ -85,7 +85,7 @@ def test_automatic_speech_recognition(self):
"#1[T1r3]",
self.string_type(torch.utils._pytree.tree_flatten(inputs["encoder_outputs"])[0]),
)
with bypass_export_some_errors(patch_transformers=True, verbose=10):
with torch_export_patches(patch_transformers=True, verbose=10):
flat = torch.utils._pytree.tree_flatten(inputs["past_key_values"])[0]
self.assertIsInstance(flat, list)
self.assertIsInstance(flat[0], torch.Tensor)
@@ -96,7 +96,7 @@
torch.export.export(
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
)
with bypass_export_some_errors(patch_transformers=True, verbose=10):
with torch_export_patches(patch_transformers=True, verbose=10):
flat = torch.utils._pytree.tree_flatten(inputs["past_key_values"])[0]
self.assertIsInstance(flat, list)
self.assertIsInstance(flat[0], torch.Tensor)
@@ -117,7 +117,7 @@ def test_fill_mask(self):
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
model(**inputs)
model(**data["inputs2"])
with bypass_export_some_errors(patch_transformers=True, verbose=10):
with torch_export_patches(patch_transformers=True, verbose=10):
torch.export.export(
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
)
@@ -131,7 +131,7 @@ def test_feature_extraction(self):
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
model(**inputs)
model(**data["inputs2"])
with bypass_export_some_errors(patch_transformers=True, verbose=10):
with torch_export_patches(patch_transformers=True, verbose=10):
torch.export.export(
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
)
@@ -145,7 +145,7 @@ def test_text_classification(self):
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
model(**inputs)
model(**data["inputs2"])
with bypass_export_some_errors(patch_transformers=True, verbose=10):
with torch_export_patches(patch_transformers=True, verbose=10):
torch.export.export(
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
)
@@ -159,7 +159,7 @@ def test_sentence_similary(self):
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
model(**inputs)
model(**data["inputs2"])
with bypass_export_some_errors(patch_transformers=True, verbose=10):
with torch_export_patches(patch_transformers=True, verbose=10):
torch.export.export(
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
)
@@ -175,7 +175,7 @@ def test_falcon_mamba_dev(self):
self.assertIn((data["size"], data["n_weights"]), [(138640384, 34660096)])
if not has_transformers("4.55"):
raise unittest.SkipTest("The model has control flow.")
with bypass_export_some_errors(patch_transformers=True, verbose=10, stop_if_static=1):
with torch_export_patches(patch_transformers=True, verbose=10, stop_if_static=1):
torch.export.export(
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
)