diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst
index b6d36fa3..26ab8af8 100644
--- a/CHANGELOGS.rst
+++ b/CHANGELOGS.rst
@@ -4,6 +4,9 @@ Change Logs
0.4.3
+++++
+* :pr:`75`: renames bypass_export_some_errors into torch_export_patches, keeps the old name
+* :pr:`74`: increases the list of class/architectures
+
0.4.2
+++++
diff --git a/README.rst b/README.rst
index cf240828..e89ded82 100644
--- a/README.rst
+++ b/README.rst
@@ -30,13 +30,13 @@ it helps exporting **pytorch models into ONNX**, mostly designed for LLMs using
.. code-block:: python
- with bypass_export_some_errors(patch_transformers=True) as f:
+ with torch_export_patches(patch_transformers=True) as f:
ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
# ...
It also implements tools to investigate, validate exported models (ExportedProgramm, ONNXProgram, ...).
See `documentation of onnx-diagnostic `_ and
-`bypass_export_some_errors `_.
+`torch_export_patches `_.
Getting started
+++++++++++++++
diff --git a/_doc/api/torch_export_patches/index.rst b/_doc/api/torch_export_patches/index.rst
index a536d9dc..b35e8c62 100644
--- a/_doc/api/torch_export_patches/index.rst
+++ b/_doc/api/torch_export_patches/index.rst
@@ -13,6 +13,6 @@ onnx_diagnostic.torch_export_patches
:members:
:no-undoc-members:
-.. autofunction:: onnx_diagnostic.torch_export_patches.bypass_export_some_errors
+.. autofunction:: onnx_diagnostic.torch_export_patches.torch_export_patches
.. autofunction:: onnx_diagnostic.torch_export_patches.register_additional_serialization_functions
diff --git a/_doc/examples/plot_export_hub_codellama.py b/_doc/examples/plot_export_hub_codellama.py
index 533ad0dd..bf932afc 100644
--- a/_doc/examples/plot_export_hub_codellama.py
+++ b/_doc/examples/plot_export_hub_codellama.py
@@ -29,7 +29,7 @@
get_pretrained_config,
task_from_id,
)
-from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
+from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
model_id = "codellama/CodeLlama-7b-Python-hf"
@@ -90,9 +90,9 @@
#
# The model uses :class:`transformers.cache_utils.DynamicCache`.
# It still requires patches to be exportable (control flow).
-# See :func:`onnx_diagnostic.torch_export_patches.bypass_export_some_errors`
+# See :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
-with bypass_export_some_errors(patch_transformers=True) as f:
+with torch_export_patches(patch_transformers=True) as f:
ep = torch.export.export(
model,
(),
diff --git a/_doc/examples/plot_export_locate_issue.py b/_doc/examples/plot_export_locate_issue.py
index 3cfbdad1..3c1006aa 100644
--- a/_doc/examples/plot_export_locate_issue.py
+++ b/_doc/examples/plot_export_locate_issue.py
@@ -26,7 +26,7 @@
import traceback
import torch
from onnx_diagnostic import doc
-from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
+from onnx_diagnostic.torch_export_patches import torch_export_patches
class ModelWithIssue(torch.nn.Module):
@@ -80,12 +80,12 @@ def forward(self, x: torch.Tensor, ys: list[torch.Tensor]):
# Stop when a dynamic dimension turns static
# ==========================================
#
-# We use :func:`bypass_export_some_errors
-# `
+# We use :func:`torch_export_patches
+# `
# to replace torch implementation by a new one raising the exception
# mentioned in previous section.
-with bypass_export_some_errors(stop_if_static=1, verbose=1):
+with torch_export_patches(stop_if_static=1, verbose=1):
try:
torch.export.export(model, inputs, dynamic_shapes=dyn_shapes)
except (AssertionError, torch._dynamo.exc.TorchRuntimeError) as e:
diff --git a/_doc/examples/plot_export_tiny_llm_patched.py b/_doc/examples/plot_export_tiny_llm_patched.py
index f1576baf..60a20d15 100644
--- a/_doc/examples/plot_export_tiny_llm_patched.py
+++ b/_doc/examples/plot_export_tiny_llm_patched.py
@@ -69,7 +69,7 @@
from onnx_diagnostic import doc
from onnx_diagnostic.helpers.cache_helper import is_cache_dynamic_registered
from onnx_diagnostic.helpers import string_type
-from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
+from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_models.llms import get_tiny_llm
@@ -101,10 +101,10 @@
# %%
# If they are not registered, function
-# func:`onnx_diagnostic.torch_export_patches.bypass_export_some_errors`
+# :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
# should take care of it. Then we export.
-with bypass_export_some_errors(patch_transformers=True, verbose=10) as modificator:
+with torch_export_patches(patch_transformers=True, verbose=10) as modificator:
assert is_cache_dynamic_registered() # it must be true here
ep = torch.export.export(
untrained_model,
@@ -126,7 +126,7 @@
cloned_inputs = copy.deepcopy(inputs)
-with bypass_export_some_errors(patch_transformers=True, verbose=10) as modificator:
+with torch_export_patches(patch_transformers=True, verbose=10) as modificator:
ep = torch.export.export(
model,
(),
diff --git a/_doc/examples/plot_export_tiny_phi2.py b/_doc/examples/plot_export_tiny_phi2.py
index ddcf1480..b91bdcb4 100644
--- a/_doc/examples/plot_export_tiny_phi2.py
+++ b/_doc/examples/plot_export_tiny_phi2.py
@@ -26,7 +26,7 @@
from onnx_diagnostic.helpers import max_diff, string_diff, string_type
from onnx_diagnostic.helpers.cache_helper import is_cache_dynamic_registered
from onnx_diagnostic.helpers.rt_helper import make_feeds
-from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
+from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
from onnx_diagnostic.torch_models.hghub import (
get_untrained_model_with_inputs,
@@ -76,7 +76,7 @@
# ++++++
-with bypass_export_some_errors(patch_transformers=True) as modificator:
+with torch_export_patches(patch_transformers=True) as modificator:
# Unnecessary steps but useful in case of an error
# We check the cache is registered.
@@ -110,7 +110,7 @@
# applies :meth:`torch.export.ExportedProgram.run_decompositions`
# may export local pieces of the model again.
-with bypass_export_some_errors(patch_transformers=True):
+with torch_export_patches(patch_transformers=True):
epo = torch.onnx.export(
ep, (), kwargs=copy.deepcopy(inputs), dynamic_shapes=dynamic_shapes, dynamo=True
)
diff --git a/_doc/examples/plot_export_with_dynamic_cache.py b/_doc/examples/plot_export_with_dynamic_cache.py
index ed7d9ae7..5dfa9c3f 100644
--- a/_doc/examples/plot_export_with_dynamic_cache.py
+++ b/_doc/examples/plot_export_with_dynamic_cache.py
@@ -30,7 +30,7 @@
make_dynamic_cache,
)
from onnx_diagnostic.export import ModelInputs
-from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
+from onnx_diagnostic.torch_export_patches import torch_export_patches
class Model(torch.nn.Module):
@@ -99,14 +99,14 @@ def forward(self, cache, z):
# And finally the export.
# The export is simple if ``transformers>=4.50``, otherwise,
# transformers needs to be patched.
-# :func:`onnx_diagnostic.torch_export_patches.bypass_export_some_errors`
+# :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
# registers functions to serialize ``DynamicCache``. This one is modified to make
# the shape inference implemented in :epkg:`torch` happy.
if has_transformers("4.50"):
ep = torch.export.export(model, inputs[0], dynamic_shapes=ds[0], strict=False)
else:
- with bypass_export_some_errors(patch_transformers=True) as modificator:
+ with torch_export_patches(patch_transformers=True) as modificator:
ep = torch.export.export(
model, modificator(inputs[0]), dynamic_shapes=ds[0], strict=False
)
diff --git a/_doc/index.rst b/_doc/index.rst
index 917df6b1..6d4792dd 100644
--- a/_doc/index.rst
+++ b/_doc/index.rst
@@ -24,12 +24,12 @@ Sources available at `github/onnx-diagnostic =4.52")
- with bypass_export_some_errors(patch_transformers=True, verbose=10):
+ with torch_export_patches(patch_transformers=True, verbose=10):
torch.export.export(
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
)
diff --git a/_unittests/ut_tasks/test_tasks_image_text_to_text.py b/_unittests/ut_tasks/test_tasks_image_text_to_text.py
index a99db3f4..731371ef 100644
--- a/_unittests/ut_tasks/test_tasks_image_text_to_text.py
+++ b/_unittests/ut_tasks/test_tasks_image_text_to_text.py
@@ -2,7 +2,7 @@
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, has_transformers, has_torch
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
-from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
+from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
@@ -20,7 +20,7 @@ def test_image_text_to_text(self):
raise unittest.SkipTest("The model has control flow.")
if not has_torch("2.7.99"):
raise unittest.SkipTest("sym_max does not work with dynamic dimension")
- with bypass_export_some_errors(patch_transformers=True, verbose=10):
+ with torch_export_patches(patch_transformers=True, verbose=10):
torch.export.export(
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
)
diff --git a/_unittests/ut_tasks/test_tasks_zero_shot_image_classification.py b/_unittests/ut_tasks/test_tasks_zero_shot_image_classification.py
index b441646a..bbfb34ab 100644
--- a/_unittests/ut_tasks/test_tasks_zero_shot_image_classification.py
+++ b/_unittests/ut_tasks/test_tasks_zero_shot_image_classification.py
@@ -2,7 +2,7 @@
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, requires_torch
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
-from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
+from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
@@ -17,7 +17,7 @@ def test_zero_shot_image_classification(self):
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
model(**inputs)
model(**data["inputs2"])
- with bypass_export_some_errors(patch_transformers=True, verbose=10):
+ with torch_export_patches(patch_transformers=True, verbose=10):
torch.export.export(
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
)
diff --git a/_unittests/ut_torch_export_patches/test_dynamic_class.py b/_unittests/ut_torch_export_patches/test_dynamic_class.py
index 32831134..1cdb19eb 100644
--- a/_unittests/ut_torch_export_patches/test_dynamic_class.py
+++ b/_unittests/ut_torch_export_patches/test_dynamic_class.py
@@ -13,7 +13,7 @@
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
- bypass_export_some_errors,
+ torch_export_patches,
)
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
@@ -56,7 +56,7 @@ def forward(self, x, cache):
DYN = torch.export.Dim.DYNAMIC
# patching
- with bypass_export_some_errors(patch_transformers=True):
+ with torch_export_patches(patch_transformers=True):
got = model(*inputs)
self.assertEqualArray(expected, got)
ep = torch.export.export(
@@ -260,7 +260,7 @@ def forward(self, x, dc):
self.assertEqualArray(expected, got)
return
- with bypass_export_some_errors(patch_transformers=True):
+ with torch_export_patches(patch_transformers=True):
ep = torch.export.export(model, (), kwargs=inputs)
args, _spec = torch.utils._pytree.tree_flatten(inputs)
@@ -292,7 +292,7 @@ def test_phi2_export_module(self):
str_inputs, string_type(inputs_copied, with_shape=True, with_min_max=True)
)
- with bypass_export_some_errors(patch_transformers=True):
+ with torch_export_patches(patch_transformers=True):
ep = torch.export.export(
model,
(),
@@ -330,7 +330,7 @@ def test_phi2_export_interpreter(self):
str_inputs, string_type(inputs_copied, with_shape=True, with_min_max=True)
)
- with bypass_export_some_errors(patch_transformers=True):
+ with torch_export_patches(patch_transformers=True):
ep = torch.export.export(
model,
(),
diff --git a/_unittests/ut_torch_export_patches/test_onnx_export_errors.py b/_unittests/ut_torch_export_patches/test_onnx_export_errors.py
index 699ccdd7..b9adbc7d 100644
--- a/_unittests/ut_torch_export_patches/test_onnx_export_errors.py
+++ b/_unittests/ut_torch_export_patches/test_onnx_export_errors.py
@@ -10,7 +10,7 @@
)
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
- bypass_export_some_errors,
+ torch_export_patches,
)
@@ -34,7 +34,7 @@ def __init__(self):
cache = MambaCache(_config(), max_batch_size=1, device="cpu")
- with bypass_export_some_errors(verbose=1):
+ with torch_export_patches(verbose=1):
values, spec = py_pytree.tree_flatten(cache)
cache2 = py_pytree.tree_unflatten(values, spec)
self.assertEqual(cache.max_batch_size, cache2.max_batch_size)
@@ -78,7 +78,7 @@ def forward(self, x: torch.Tensor, cache: MambaCache):
model = Model()
model(x, cache)
- with bypass_export_some_errors(verbose=1):
+ with torch_export_patches(verbose=1):
cache = MambaCache(_config(), max_batch_size=1, device="cpu")
torch.export.export(Model(), (x, cache))
@@ -113,7 +113,7 @@ def forward(self, x: torch.Tensor, cache: MambaCache):
model(x, cache)
DYN = torch.export.Dim.DYNAMIC
- with bypass_export_some_errors():
+ with torch_export_patches():
cache = MambaCache(_config(), max_batch_size=2, device="cpu")
torch.export.export(
Model(),
diff --git a/_unittests/ut_torch_export_patches/test_patch_base_class.py b/_unittests/ut_torch_export_patches/test_patch_base_class.py
index 7dc4c708..6e10e91f 100644
--- a/_unittests/ut_torch_export_patches/test_patch_base_class.py
+++ b/_unittests/ut_torch_export_patches/test_patch_base_class.py
@@ -1,7 +1,10 @@
import unittest
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
-from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
+from onnx_diagnostic.torch_export_patches import (
+ torch_export_patches,
+ bypass_export_some_errors,
+)
class TestPatchBaseClass(ExtTestCase):
@@ -73,6 +76,28 @@ def m1(self, x):
model = Model()
x = torch.arange(4)
self.assertEqualArray(x * x, model(x))
+ with torch_export_patches(custom_patches=[patched_Model], verbose=10):
+ self.assertEqualArray(x**3, model(x))
+
+ @hide_stdout()
+ def test_bypass_export_some_errors(self):
+ class Model2(torch.nn.Module):
+ def m2(self, x):
+ return x * x
+
+ def forward(self, x):
+ return self.m2(x)
+
+ class patched_Model:
+ _PATCHED_CLASS_ = Model2
+ _PATCHES_ = ["m2"]
+
+ def m2(self, x):
+ return x**3
+
+ model = Model2()
+ x = torch.arange(4)
+ self.assertEqualArray(x * x, model(x))
with bypass_export_some_errors(custom_patches=[patched_Model], verbose=10):
self.assertEqualArray(x**3, model(x))
diff --git a/_unittests/ut_torch_export_patches/test_patch_serialization.py b/_unittests/ut_torch_export_patches/test_patch_serialization.py
index 619bd76d..1e0bcb12 100644
--- a/_unittests/ut_torch_export_patches/test_patch_serialization.py
+++ b/_unittests/ut_torch_export_patches/test_patch_serialization.py
@@ -9,7 +9,7 @@
flatten_unflatten_for_dynamic_shapes,
)
from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
- bypass_export_some_errors,
+ torch_export_patches,
)
from onnx_diagnostic.helpers.torch_test_helper import torch_deepcopy
@@ -21,7 +21,7 @@ def test_encoder_decoder_cache_flatten(self):
make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
make_dynamic_cache([(torch.rand((5, 5, 5)), torch.rand((5, 5, 5)))]),
)
- with bypass_export_some_errors():
+ with torch_export_patches():
flat, _spec = torch.utils._pytree.tree_flatten(cache)
self.assertEqual(
"#4[T1s4x4x4,T1s4x4x4,T1s5x5x5,T1s5x5x5]",
@@ -39,7 +39,7 @@ def test_encoder_decoder_cache_deepcopy(self):
make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
make_dynamic_cache([(torch.rand((5, 5, 5)), torch.rand((5, 5, 5)))]),
)
- with bypass_export_some_errors():
+ with torch_export_patches():
cache2 = torch_deepcopy([cache])
self.assertEqualAny([cache], cache2)
@@ -65,13 +65,13 @@ def forward(self, cache):
[[{0: DYN}, {0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}, {0: DYN}]],
]
- with bypass_export_some_errors(patch_transformers=True):
+ with torch_export_patches(patch_transformers=True):
torch.export.export(model, (cache,), dynamic_shapes=(ds,))
@ignore_warnings(UserWarning)
def test_dynamic_cache_flatten(self):
cache = make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))])
- with bypass_export_some_errors():
+ with torch_export_patches():
flat, _spec = torch.utils._pytree.tree_flatten(cache)
self.assertEqual(
"#2[T1s4x4x4,T1s4x4x4]",
@@ -97,13 +97,13 @@ def forward(self, cache):
DYN = torch.export.Dim.DYNAMIC
ds = [[{0: DYN}, {0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}, {0: DYN}]]
- with bypass_export_some_errors():
+ with torch_export_patches():
torch.export.export(model, (cache,), dynamic_shapes=(ds,))
@ignore_warnings(UserWarning)
def test_dynamic_cache_deepcopy(self):
cache = make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))])
- with bypass_export_some_errors():
+ with torch_export_patches():
cache2 = torch_deepcopy([cache])
self.assertEqualAny([cache], cache2)
@@ -111,7 +111,7 @@ def test_dynamic_cache_deepcopy(self):
def test_base_model_output_deepcopy(self):
bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4)))
self.assertEqual(bo.__class__.__name__, "BaseModelOutput")
- with bypass_export_some_errors():
+ with torch_export_patches():
bo2 = torch_deepcopy([bo])
self.assertIsInstance(bo2, list)
self.assertEqual(bo2[0].__class__.__name__, "BaseModelOutput")
@@ -120,7 +120,7 @@ def test_base_model_output_deepcopy(self):
@ignore_warnings(UserWarning)
def test_base_model_output_string_type(self):
bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4)))
- with bypass_export_some_errors():
+ with torch_export_patches():
self.assertEqual(
"BaseModelOutput(last_hidden_state:T1s4x4x4)",
self.string_type(bo, with_shape=True),
@@ -129,7 +129,7 @@ def test_base_model_output_string_type(self):
@ignore_warnings(UserWarning)
def test_base_model_output_flatten(self):
bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4)))
- with bypass_export_some_errors():
+ with torch_export_patches():
flat, _spec = torch.utils._pytree.tree_flatten(bo)
self.assertEqual(
"#1[T1s4x4x4]",
@@ -153,13 +153,13 @@ def forward(self, cache):
DYN = torch.export.Dim.DYNAMIC
ds = [{0: DYN}]
- with bypass_export_some_errors():
+ with torch_export_patches():
torch.export.export(model, (bo,), dynamic_shapes=(ds,))
@ignore_warnings(UserWarning)
def test_base_model_output_unflatten_flatten(self):
bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4)))
- with bypass_export_some_errors(patch_transformers=True):
+ with torch_export_patches(patch_transformers=True):
flat, _spec = torch.utils._pytree.tree_flatten(bo)
unflat = flatten_unflatten_for_dynamic_shapes(bo, use_dict=True)
self.assertIsInstance(unflat, dict)
@@ -170,7 +170,7 @@ def test_base_sliding_window_cache_unflatten_flatten(self):
cache = make_sliding_window_cache(
[(torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4)))]
)
- with bypass_export_some_errors():
+ with torch_export_patches():
cache2 = torch_deepcopy([cache])
self.assertEqualAny([cache], cache2)
@@ -192,7 +192,7 @@ def forward(self, cache):
DYN = torch.export.Dim.DYNAMIC
ds = [[{0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}]]
- with bypass_export_some_errors(patch_transformers=True):
+ with torch_export_patches(patch_transformers=True):
torch.export.export(model, (cache,), dynamic_shapes=(ds,))
@ignore_warnings(UserWarning)
@@ -200,7 +200,7 @@ def test_sliding_window_cache_flatten(self):
cache = make_sliding_window_cache(
[(torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4)))]
)
- with bypass_export_some_errors():
+ with torch_export_patches():
flat, _spec = torch.utils._pytree.tree_flatten(cache)
self.assertEqual(
"#2[T1s4x4x4x4,T1s4x4x4x4]",
diff --git a/_unittests/ut_torch_models/test_hghub_model.py b/_unittests/ut_torch_models/test_hghub_model.py
index 6083812f..7b2b49e1 100644
--- a/_unittests/ut_torch_models/test_hghub_model.py
+++ b/_unittests/ut_torch_models/test_hghub_model.py
@@ -10,7 +10,7 @@
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
from onnx_diagnostic.torch_models.hghub.hub_api import get_pretrained_config
from onnx_diagnostic.torch_models.hghub.hub_data import load_models_testing
-from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
+from onnx_diagnostic.torch_export_patches import torch_export_patches
class TestHuggingFaceHubModel(ExtTestCase):
@@ -91,7 +91,7 @@ def test_get_untrained_model_with_inputs_clip_vit(self):
mid = "openai/clip-vit-base-patch16"
data = get_untrained_model_with_inputs(mid, verbose=1)
model, inputs = data["model"], data["inputs"]
- with bypass_export_some_errors(patch_transformers=True):
+ with torch_export_patches(patch_transformers=True):
model(**inputs)
# different expected value for different version of transformers
self.assertIn((data["size"], data["n_weights"]), [(188872708, 47218177)])
diff --git a/_unittests/ut_torch_models/test_tiny_llms_bypassed.py b/_unittests/ut_torch_models/test_tiny_llms_bypassed.py
index 4ec3e943..35b3c224 100644
--- a/_unittests/ut_torch_models/test_tiny_llms_bypassed.py
+++ b/_unittests/ut_torch_models/test_tiny_llms_bypassed.py
@@ -6,7 +6,7 @@
from onnx_diagnostic.torch_models.llms import get_tiny_llm
from onnx_diagnostic.torch_models.llms import get_phi2
from onnx_diagnostic.helpers import string_type
-from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
+from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
patched_DynamicCache,
)
@@ -22,7 +22,7 @@ def test_export_tiny_llm_2_bypassed(self):
{"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
)
- with bypass_export_some_errors(
+ with torch_export_patches(
patch_torch=False, patch_transformers=True, catch_constraints=False, verbose=10
) as modificator:
@@ -61,11 +61,11 @@ def test_export_phi2_2_bypassed(self):
self.assertEqual(
{"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
)
- with bypass_export_some_errors(patch_transformers=True) as modificator:
+ with torch_export_patches(patch_transformers=True) as modificator:
inputs = modificator(inputs)
ep = torch.export.export(model, (), kwargs=inputs, dynamic_shapes=ds, strict=False)
assert ep
- with bypass_export_some_errors(patch_transformers=True) as modificator:
+ with torch_export_patches(patch_transformers=True) as modificator:
inputs = modificator(inputs)
ep = torch.export.export(model, (), kwargs=inputs, dynamic_shapes=ds, strict=False)
assert ep
diff --git a/_unittests/ut_torch_models/test_tiny_llms_onnx.py b/_unittests/ut_torch_models/test_tiny_llms_onnx.py
index eda2d3de..48c13f06 100644
--- a/_unittests/ut_torch_models/test_tiny_llms_onnx.py
+++ b/_unittests/ut_torch_models/test_tiny_llms_onnx.py
@@ -10,7 +10,7 @@
requires_transformers,
)
from onnx_diagnostic.torch_models.llms import get_tiny_llm
-from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
+from onnx_diagnostic.torch_export_patches import torch_export_patches
try:
from experimental_experiment.torch_interpreter import to_onnx, ExportOptions
@@ -58,7 +58,7 @@ def test_onnx_export_tiny_llm_xdbg(self):
self.assertEqual(
{"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
)
- with bypass_export_some_errors(patch_transformers=True):
+ with torch_export_patches(patch_transformers=True):
onx = to_onnx(
model, (), kwargs=inputs, dynamic_shapes=data["dynamic_shapes"], verbose=1
)
@@ -74,7 +74,7 @@ def test_bypass_onnx_export_tiny_llm_official_nopositionids(self):
del inputs["position_ids"]
del ds["position_ids"]
self.assertEqual({"attention_mask", "past_key_values", "input_ids"}, set(inputs))
- with bypass_export_some_errors(patch_transformers=True, verbose=1) as modificator:
+ with torch_export_patches(patch_transformers=True, verbose=1) as modificator:
new_inputs = modificator(copy.deepcopy(inputs))
ep = torch.onnx.export(
model,
@@ -101,7 +101,7 @@ def test_bypass_onnx_export_tiny_llm_official_full(self):
self.assertEqual(
{"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
)
- with bypass_export_some_errors(patch_transformers=True, verbose=1) as modificator:
+ with torch_export_patches(patch_transformers=True, verbose=1) as modificator:
new_inputs = modificator(copy.deepcopy(inputs))
ep = torch.onnx.export(
model,
@@ -134,7 +134,7 @@ def test_bypass_onnx_export_tiny_llm_xdbg(self):
self.assertEqual(
{"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
)
- with bypass_export_some_errors(patch_transformers=True, verbose=2) as modificator:
+ with torch_export_patches(patch_transformers=True, verbose=2) as modificator:
new_inputs = modificator(inputs)
onx = to_onnx(
model,
diff --git a/onnx_diagnostic/torch_export_patches/__init__.py b/onnx_diagnostic/torch_export_patches/__init__.py
index ff978ee3..0b241ba0 100644
--- a/onnx_diagnostic/torch_export_patches/__init__.py
+++ b/onnx_diagnostic/torch_export_patches/__init__.py
@@ -1,4 +1,8 @@
from .onnx_export_errors import (
- bypass_export_some_errors,
+ torch_export_patches,
register_additional_serialization_functions,
)
+
+
+# bypass_export_some_errors is the first name given to the patches.
+bypass_export_some_errors = torch_export_patches # type: ignore
diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py
index 54b3e875..27cbfe14 100644
--- a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py
+++ b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py
@@ -93,7 +93,7 @@ def register_additional_serialization_functions(
@contextlib.contextmanager
-def bypass_export_some_errors(
+def torch_export_patches(
patch_sympy: bool = True,
patch_torch: bool = True,
patch_transformers: bool = False,
@@ -145,13 +145,13 @@ def bypass_export_some_errors(
::
- with bypass_export_some_errors(patch_transformers=True) as modificator:
+ with torch_export_patches(patch_transformers=True) as modificator:
inputs = modificator(inputs)
onx = to_onnx(..., inputs, ...)
::
- with bypass_export_some_errors(patch_transformers=True) as modificator:
+ with torch_export_patches(patch_transformers=True) as modificator:
inputs = modificator(inputs)
onx = torch.onnx.export(..., inputs, ...)
@@ -159,7 +159,7 @@ def bypass_export_some_errors(
::
- with bypass_export_some_errors(patch_transformers=True) as modificator:
+ with torch_export_patches(patch_transformers=True) as modificator:
inputs = modificator(inputs)
ep = torch.export.export(..., inputs, ...)
@@ -190,7 +190,7 @@ def bypass_export_some_errors(
if verbose:
print(
- "[bypass_export_some_errors] replace torch.jit.isinstance, "
+ "[torch_export_patches] replace torch.jit.isinstance, "
"torch._dynamo.mark_static_address"
)
@@ -210,8 +210,8 @@ def bypass_export_some_errors(
f_sympy_name = getattr(sympy.core.numbers.IntegerConstant, "name", None)
if verbose:
- print(f"[bypass_export_some_errors] sympy.__version__={sympy.__version__!r}")
- print("[bypass_export_some_errors] patch sympy")
+ print(f"[torch_export_patches] sympy.__version__={sympy.__version__!r}")
+ print("[torch_export_patches] patch sympy")
sympy.core.numbers.IntegerConstant.name = lambda self: f"IntCst{str(self)}"
@@ -228,9 +228,9 @@ def bypass_export_some_errors(
)
if verbose:
- print(f"[bypass_export_some_errors] torch.__version__={torch.__version__!r}")
- print(f"[bypass_export_some_errors] stop_if_static={stop_if_static!r}")
- print("[bypass_export_some_errors] patch pytorch")
+ print(f"[torch_export_patches] torch.__version__={torch.__version__!r}")
+ print(f"[torch_export_patches] stop_if_static={stop_if_static!r}")
+ print("[torch_export_patches] patch pytorch")
# torch.jit.isinstance
f_jit_isinstance = torch.jit.isinstance
@@ -252,7 +252,7 @@ def bypass_export_some_errors(
# torch._export.non_strict_utils.produce_guards_and_solve_constraints
if catch_constraints:
if verbose:
- print("[bypass_export_some_errors] modifies shape constraints")
+ print("[torch_export_patches] modifies shape constraints")
f_produce_guards_and_solve_constraints = (
torch._export.non_strict_utils.produce_guards_and_solve_constraints
)
@@ -277,22 +277,20 @@ def bypass_export_some_errors(
ShapeEnv._log_guard_remember = ShapeEnv._log_guard
if verbose:
- print(
- "[bypass_export_some_errors] assert when a dynamic dimension turns static"
- )
- print("[bypass_export_some_errors] replaces ShapeEnv._set_replacement")
+ print("[torch_export_patches] assert when a dynamic dimension turns static")
+ print("[torch_export_patches] replaces ShapeEnv._set_replacement")
f_shape_env__set_replacement = ShapeEnv._set_replacement
ShapeEnv._set_replacement = patched_ShapeEnv._set_replacement
if verbose:
- print("[bypass_export_some_errors] replaces ShapeEnv._log_guard")
+ print("[torch_export_patches] replaces ShapeEnv._log_guard")
f_shape_env__log_guard = ShapeEnv._log_guard
ShapeEnv._log_guard = patched_ShapeEnv._log_guard
if stop_if_static > 1:
if verbose:
- print("[bypass_export_some_errors] replaces ShapeEnv._check_frozen")
+ print("[torch_export_patches] replaces ShapeEnv._check_frozen")
f_shape_env__check_frozen = ShapeEnv._check_frozen
ShapeEnv._check_frozen = patched_ShapeEnv._check_frozen
@@ -305,7 +303,7 @@ def bypass_export_some_errors(
import transformers
print(
- f"[bypass_export_some_errors] transformers.__version__="
+ f"[torch_export_patches] transformers.__version__="
f"{transformers.__version__!r}"
)
revert_patches_info = patch_module_or_classes(
@@ -314,7 +312,7 @@ def bypass_export_some_errors(
if custom_patches:
if verbose:
- print("[bypass_export_some_errors] applies custom patches")
+ print("[torch_export_patches] applies custom patches")
revert_custom_patches_info = patch_module_or_classes(
custom_patches, verbose=verbose
)
@@ -326,7 +324,7 @@ def bypass_export_some_errors(
fct_callable = replacement_before_exporting if patch_transformers else (lambda x: x)
if verbose:
- print("[bypass_export_some_errors] done patching")
+ print("[torch_export_patches] done patching")
try:
yield fct_callable
@@ -336,7 +334,7 @@ def bypass_export_some_errors(
#######
if verbose:
- print("[bypass_export_some_errors] remove patches")
+ print("[torch_export_patches] remove patches")
if patch_sympy:
# tracked by https://github.com/pytorch/pytorch/issues/143494
@@ -346,7 +344,7 @@ def bypass_export_some_errors(
delattr(sympy.core.numbers.IntegerConstant, "name")
if verbose:
- print("[bypass_export_some_errors] restored sympy functions")
+ print("[torch_export_patches] restored sympy functions")
#######
# torch
@@ -362,22 +360,22 @@ def bypass_export_some_errors(
torch._meta_registrations._broadcast_shapes = f__broadcast_shapes
if verbose:
- print("[bypass_export_some_errors] restored pytorch functions")
+ print("[torch_export_patches] restored pytorch functions")
if stop_if_static:
if verbose:
- print("[bypass_export_some_errors] restored ShapeEnv._set_replacement")
+ print("[torch_export_patches] restored ShapeEnv._set_replacement")
ShapeEnv._set_replacement = f_shape_env__set_replacement
if verbose:
- print("[bypass_export_some_errors] restored ShapeEnv._log_guard")
+ print("[torch_export_patches] restored ShapeEnv._log_guard")
ShapeEnv._log_guard = f_shape_env__log_guard
if stop_if_static > 1:
if verbose:
- print("[bypass_export_some_errors] restored ShapeEnv._check_frozen")
+ print("[torch_export_patches] restored ShapeEnv._check_frozen")
ShapeEnv._check_frozen = f_shape_env__check_frozen
if catch_constraints:
@@ -389,11 +387,11 @@ def bypass_export_some_errors(
f__check_input_constraints_for_graph
)
if verbose:
- print("[bypass_export_some_errors] restored shape constraints")
+ print("[torch_export_patches] restored shape constraints")
if custom_patches:
if verbose:
- print("[bypass_export_some_errors] unpatch custom patches")
+ print("[torch_export_patches] unpatch custom patches")
unpatch_module_or_classes(
custom_patches, revert_custom_patches_info, verbose=verbose
)
@@ -404,7 +402,7 @@ def bypass_export_some_errors(
if patch_transformers:
if verbose:
- print("[bypass_export_some_errors] unpatch transformers")
+ print("[torch_export_patches] unpatch transformers")
unpatch_module_or_classes(
patch_transformers_list, revert_patches_info, verbose=verbose
)
diff --git a/onnx_diagnostic/torch_models/test_helper.py b/onnx_diagnostic/torch_models/test_helper.py
index a39266c2..aa93ece7 100644
--- a/onnx_diagnostic/torch_models/test_helper.py
+++ b/onnx_diagnostic/torch_models/test_helper.py
@@ -12,7 +12,7 @@
from ..helpers.torch_test_helper import to_any, torch_deepcopy
from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
from ..tasks import random_input_kwargs
-from ..torch_export_patches import bypass_export_some_errors
+from ..torch_export_patches import torch_export_patches
from ..torch_export_patches.patch_inputs import use_dyn_not_str
from .hghub import get_untrained_model_with_inputs
@@ -242,9 +242,9 @@ def validate_model(
depend on the the exporter
:param quiet: if quiet, catches exception if any issue
:param patch: applies patches (``patch_transformers=True``) before exporting,
- see :func:`onnx_diagnostic.torch_export_patches.bypass_export_some_errors`
+ see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
:param stop_if_static: stops if a dynamic dimension becomes static,
- see :func:`onnx_diagnostic.torch_export_patches.bypass_export_some_errors`
+ see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
:param dump_folder: dumps everything in a subfolder of this one
:param drop_inputs: drops this list of inputs (given their names)
:param ortfusiontype: runs ort fusion, the parameters defines the fusion type,
@@ -417,7 +417,7 @@ def validate_model(
f"[validate_model] applies patches before exporting "
f"stop_if_static={stop_if_static}"
)
- with bypass_export_some_errors( # type: ignore
+ with torch_export_patches( # type: ignore
patch_transformers=True,
stop_if_static=stop_if_static,
verbose=max(0, verbose - 1),