Skip to content

Commit 7eb831c

Browse files
authored
Renames bypass_export_some_patches into torch_export_patches, keep the old name (#75)
* rename bypass_export_some_patches into torch_export_patches, keep the old name * black * fix * doc
1 parent c579f8e commit 7eb831c

28 files changed

+152
-122
lines changed

CHANGELOGS.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ Change Logs
44
0.4.3
55
+++++
66

7+
* :pr:`75`: renames bypass_export_some_patches into torch_export_patches, keeps the old name
8+
* :pr:`74`: increases the list of class/architectures
9+
710
0.4.2
811
+++++
912

README.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,13 @@ it helps exporting **pytorch models into ONNX**, mostly designed for LLMs using
3030

3131
.. code-block:: python
3232
33-
with bypass_export_some_errors(patch_transformers=True) as f:
33+
with torch_export_patches(patch_transformers=True) as f:
3434
ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
3535
# ...
3636
3737
It also implements tools to investigate, validate exported models (ExportedProgramm, ONNXProgram, ...).
3838
See `documentation of onnx-diagnostic <https://sdpython.github.io/doc/onnx-diagnostic/dev/>`_ and
39-
`bypass_export_some_errors <https://sdpython.github.io/doc/onnx-diagnostic/dev/api/torch_export_patches/index.html#onnx_diagnostic.torch_export_patches.bypass_export_some_errors>`_.
39+
`torch_export_patches <https://sdpython.github.io/doc/onnx-diagnostic/dev/api/torch_export_patches/index.html#onnx_diagnostic.torch_export_patches.torch_export_patches>`_.
4040

4141
Getting started
4242
+++++++++++++++

_doc/api/torch_export_patches/index.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,6 @@ onnx_diagnostic.torch_export_patches
1313
:members:
1414
:no-undoc-members:
1515

16-
.. autofunction:: onnx_diagnostic.torch_export_patches.bypass_export_some_errors
16+
.. autofunction:: onnx_diagnostic.torch_export_patches.torch_export_patches
1717

1818
.. autofunction:: onnx_diagnostic.torch_export_patches.register_additional_serialization_functions

_doc/examples/plot_export_hub_codellama.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
get_pretrained_config,
3030
task_from_id,
3131
)
32-
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
32+
from onnx_diagnostic.torch_export_patches import torch_export_patches
3333
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
3434

3535
model_id = "codellama/CodeLlama-7b-Python-hf"
@@ -90,9 +90,9 @@
9090
#
9191
# The model uses :class:`transformers.cache_utils.DynamicCache`.
9292
# It still requires patches to be exportable (control flow).
93-
# See :func:`onnx_diagnostic.torch_export_patches.bypass_export_some_errors`
93+
# See :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
9494

95-
with bypass_export_some_errors(patch_transformers=True) as f:
95+
with torch_export_patches(patch_transformers=True) as f:
9696
ep = torch.export.export(
9797
model,
9898
(),

_doc/examples/plot_export_locate_issue.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
import traceback
2727
import torch
2828
from onnx_diagnostic import doc
29-
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
29+
from onnx_diagnostic.torch_export_patches import torch_export_patches
3030

3131

3232
class ModelWithIssue(torch.nn.Module):
@@ -80,12 +80,12 @@ def forward(self, x: torch.Tensor, ys: list[torch.Tensor]):
8080
# Stop when a dynamic dimension turns static
8181
# ==========================================
8282
#
83-
# We use :func:`bypass_export_some_errors
84-
# <onnx_diagnostic.torch_export_patches.bypass_export_some_errors>`
83+
# We use :func:`torch_export_patches
84+
# <onnx_diagnostic.torch_export_patches.torch_export_patches>`
8585
# to replace torch implementation by a new one raising the exception
8686
# mentioned in previous section.
8787

88-
with bypass_export_some_errors(stop_if_static=1, verbose=1):
88+
with torch_export_patches(stop_if_static=1, verbose=1):
8989
try:
9090
torch.export.export(model, inputs, dynamic_shapes=dyn_shapes)
9191
except (AssertionError, torch._dynamo.exc.TorchRuntimeError) as e:

_doc/examples/plot_export_tiny_llm_patched.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@
6969
from onnx_diagnostic import doc
7070
from onnx_diagnostic.helpers.cache_helper import is_cache_dynamic_registered
7171
from onnx_diagnostic.helpers import string_type
72-
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
72+
from onnx_diagnostic.torch_export_patches import torch_export_patches
7373
from onnx_diagnostic.torch_models.llms import get_tiny_llm
7474

7575

@@ -101,10 +101,10 @@
101101

102102
# %%
103103
# If they are not registered, function
104-
# func:`onnx_diagnostic.torch_export_patches.bypass_export_some_errors`
104+
# func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
105105
# should take care of it. Then we export.
106106

107-
with bypass_export_some_errors(patch_transformers=True, verbose=10) as modificator:
107+
with torch_export_patches(patch_transformers=True, verbose=10) as modificator:
108108
assert is_cache_dynamic_registered() # it must be true here
109109
ep = torch.export.export(
110110
untrained_model,
@@ -126,7 +126,7 @@
126126

127127
cloned_inputs = copy.deepcopy(inputs)
128128

129-
with bypass_export_some_errors(patch_transformers=True, verbose=10) as modificator:
129+
with torch_export_patches(patch_transformers=True, verbose=10) as modificator:
130130
ep = torch.export.export(
131131
model,
132132
(),

_doc/examples/plot_export_tiny_phi2.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from onnx_diagnostic.helpers import max_diff, string_diff, string_type
2727
from onnx_diagnostic.helpers.cache_helper import is_cache_dynamic_registered
2828
from onnx_diagnostic.helpers.rt_helper import make_feeds
29-
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
29+
from onnx_diagnostic.torch_export_patches import torch_export_patches
3030
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
3131
from onnx_diagnostic.torch_models.hghub import (
3232
get_untrained_model_with_inputs,
@@ -76,7 +76,7 @@
7676
# ++++++
7777

7878

79-
with bypass_export_some_errors(patch_transformers=True) as modificator:
79+
with torch_export_patches(patch_transformers=True) as modificator:
8080

8181
# Unnecessary steps but useful in case of an error
8282
# We check the cache is registered.
@@ -110,7 +110,7 @@
110110
# applies :meth:`torch.export.ExportedProgram.run_decompositions`
111111
# may export local pieces of the model again.
112112

113-
with bypass_export_some_errors(patch_transformers=True):
113+
with torch_export_patches(patch_transformers=True):
114114
epo = torch.onnx.export(
115115
ep, (), kwargs=copy.deepcopy(inputs), dynamic_shapes=dynamic_shapes, dynamo=True
116116
)

_doc/examples/plot_export_with_dynamic_cache.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
make_dynamic_cache,
3131
)
3232
from onnx_diagnostic.export import ModelInputs
33-
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
33+
from onnx_diagnostic.torch_export_patches import torch_export_patches
3434

3535

3636
class Model(torch.nn.Module):
@@ -99,14 +99,14 @@ def forward(self, cache, z):
9999
# And finally the export.
100100
# The export is simple if ``transformers>=4.50``, otherwise,
101101
# transformers needs to be patched.
102-
# :func:`onnx_diagnostic.torch_export_patches.bypass_export_some_errors`
102+
# :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
103103
# registers functions to serialize ``DynamicCache``. This one is modified to make
104104
# the shape inference implemented in :epkg:`torch` happy.
105105

106106
if has_transformers("4.50"):
107107
ep = torch.export.export(model, inputs[0], dynamic_shapes=ds[0], strict=False)
108108
else:
109-
with bypass_export_some_errors(patch_transformers=True) as modificator:
109+
with torch_export_patches(patch_transformers=True) as modificator:
110110
ep = torch.export.export(
111111
model, modificator(inputs[0]), dynamic_shapes=ds[0], strict=False
112112
)

_doc/index.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@ Sources available at `github/onnx-diagnostic <https://github.com/sdpython/onnx-d
2424

2525
.. code-block:: python
2626
27-
with bypass_export_some_errors(patch_transformers=True) as f:
27+
with torch_export_patches(patch_transformers=True) as f:
2828
ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
2929
# ...
3030
3131
It also implements tools to investigate, validate exported models (ExportedProgramm, ONNXProgram, ...).
32-
:func:`onnx_diagnostic.torch_export_patches.bypass_export_some_errors`.
32+
:func:`onnx_diagnostic.torch_export_patches.torch_export_patches`.
3333

3434
.. toctree::
3535
:maxdepth: 1

_doc/recipes/plot_dynamic_shapes_python_int.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import math
1515
import torch
1616
from onnx_diagnostic import doc
17-
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
17+
from onnx_diagnostic.torch_export_patches import torch_export_patches
1818

1919

2020
class Model(torch.nn.Module):
@@ -73,12 +73,12 @@ def forward(self, x):
7373
# Find the error
7474
# ++++++++++++++
7575
#
76-
# Function :func:`onnx_diagnostic.torch_export_patches.bypass_export_some_errors`
76+
# Function :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
7777
# has a parameter ``stop_if_static`` which patches torch to raise exception
7878
# when something like that is happening.
7979

8080

81-
with bypass_export_some_errors(stop_if_static=True):
81+
with torch_export_patches(stop_if_static=True):
8282
ep = torch.export.export(model, (x,), dynamic_shapes=({0: DYN, 1: DYN},))
8383
print(ep)
8484

0 commit comments

Comments
 (0)