Skip to content

Commit 695814d

Browse files
committed
update titles
1 parent 731a746 commit 695814d

File tree

3 files changed

+25
-5
lines changed

3 files changed

+25
-5
lines changed

README.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ Enlightening Examples
5252
<https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_with_dynamic_shapes_auto.html>`_
5353
* `Export with DynamicCache and dynamic shapes
5454
<https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_with_dynamic_cache.html>`_
55-
* `Steel method forward to guess the dynamic shapes
55+
* `Steel method forward to guess the dynamic shapes (with Tiny-LLM)
5656
<https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_tiny_llm.html>`_
5757

5858
**Investigate ONNX models**

_doc/examples/plot_export_tiny_llm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
"""
22
.. _l-plot-tiny-llm-export:
33
4-
Steel method forward to guess the dynamic shapes
5-
================================================
4+
Steel method forward to guess the dynamic shapes (with Tiny-LLM)
5+
================================================================
66
77
Inputs are always dynamic with LLMs that is why dynamic shapes
88
needs to be specified when a LLM is exported with:func:`torch.export.export`.

_doc/examples/plot_export_with_dynamic_cache.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@
2727
from onnx_diagnostic.helpers import string_type
2828
from onnx_diagnostic.export import ModelInputs
2929

30+
# %%
31+
# We need addition import in case ``transformers<4.50``.
32+
# Exporting DynamicCache is not supported before that.
33+
from onnx_diagnostic.ext_test_case import has_transformers
34+
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
35+
3036

3137
class Model(torch.nn.Module):
3238
def forward(self, x, y):
@@ -201,6 +207,20 @@ def forward(self, cache, z):
201207

202208
# %%
203209
# And finally the export.
204-
205-
ep = torch.export.export(model, inputs[0], dynamic_shapes=ds[0], strict=False)
210+
# The export is simple if ``transformers>=4.50``, otherwise,
211+
# transformers needs to be patched.
212+
# :func:`onnx_diagnostic.torch_export_patches.bypass_export_some_errors`
213+
# registers functions to serialize ``DynamicCache`` and another class
214+
# called ``patched_DynamicCache``. This one is modified to make
215+
# the shape inference implemented in :epkg:`torch` happy.
216+
217+
if has_transformers("4.50"):
218+
ep = torch.export.export(model, inputs[0], dynamic_shapes=ds[0], strict=False)
219+
else:
220+
with bypass_export_some_errors(
221+
patch_transformers=True, replace_dynamic_cache=True
222+
) as modificator:
223+
ep = torch.export.export(
224+
model, modificator(inputs[0]), dynamic_shapes=ds[0], strict=False
225+
)
206226
print(ep)

0 commit comments

Comments
 (0)