
Commit 3352388

Add method guess_dynamic_shapes (#11)
* first step
* stat
* add cache
* cache
* any
* CHANGELOGS.rst
1 parent bea8445 commit 3352388

13 files changed: +1188 -8 lines changed


CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ Change Logs
 0.2.0
 +++++
 
+* :pr:`11`: adds ``ModelInputs`` to guess dynamic shapes
 * :pr:`9`: adds ``OnnxruntimeEvaluator``
 * :pr:`8`: adds ``ExtendedReferenceEvaluator``
 * :pr:`7`: improves function ``investigate_onnxruntime_issue``

README.rst

Lines changed: 6 additions & 6 deletions
@@ -25,6 +25,9 @@ onnx-diagnostic: investigate onnx models
 .. image:: https://codecov.io/gh/sdpython/onnx-diagnostic/branch/main/graph/badge.svg?token=Wb9ZGDta8J
     :target: https://codecov.io/gh/sdpython/onnx-diagnostic
 
+Helps investigating onnx models, exporting models into onnx.
+See :epkg:`documentation of onnx-diagnostic <https://sdpython.github.io/doc/onnx-diagnostic/dev/>`_.
+
 Getting started
 +++++++++++++++
 
@@ -42,8 +45,10 @@ or
 
 **Enlightening Examples**
 
-* `Use DYNAMIC or AUTO when dynamic shapes has constraints
+* `Use DYNAMIC or AUTO when exporting if dynamic shapes has constraints
   <https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_with_dynamic_shapes_auto.html>`_
+* `Export with DynamicCache and dynamic shapes
+  <https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_with_dynamic_cache.html>`_
 * `Steal method forward to guess the dynamic shapes
   <https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_tiny_llm.html>`_
 * `Running ReferenceEvaluator on a failing model
@@ -95,8 +100,3 @@ Snapshot of usefuls tools
 **max_diff**
 
 Returns the maximum discrepancies across nested containers containing tensors.
-
-Documentation
-+++++++++++++
-
-See `onnx-diagnostic <https://sdpython.github.io/doc/onnx-diagnostic/dev/>`_.

_doc/api/export/dynamic_shapes.rst

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+
+onnx_diagnostic.export.dynamic_shapes
+=====================================
+
+.. automodule:: onnx_diagnostic.export.dynamic_shapes
+    :members:
+    :no-undoc-members:
+    :exclude-members: onnx_diagnostic.export.dynamic_shapes

_doc/api/export/index.rst

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+onnx_diagnostic.export
+======================
+
+.. toctree::
+    :maxdepth: 1
+    :caption: modules
+
+    dynamic_shapes
+
+ModelInputs
++++++++++++
+
+.. autoclass:: onnx_diagnostic.export.ModelInputs
+    :members:
+
+Other functions
++++++++++++++++
+
+.. automodule:: onnx_diagnostic.export
+    :members:
+    :no-undoc-members:

_doc/api/index.rst

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@ API of onnx_diagnostic
     :maxdepth: 1
     :caption: submodules
 
+    export/index
     reference/index
     torch_export_patches/index
     torch_models/index
Lines changed: 206 additions & 0 deletions
@@ -0,0 +1,206 @@
+"""
+.. _l-plot-export-with-dynamic-shape:
+
+===========================================
+Export with DynamicCache and dynamic shapes
+===========================================
+
+Every LLM implemented in :epkg:`transformers` uses a cache.
+One of the most used is :class:`transformers.cache_utils.DynamicCache`.
+The cache size is dynamic to cope with the growing context.
+The example shows a tool which determines the dynamic shapes
+for :func:`torch.export.export` based on a set of valid inputs.
+
+Simple Examples
+===============
+
+We first look at examples using positional and named parameters
+to understand how :func:`torch.export.export` works.
+
+args
+++++
+"""
+
+import pprint
+import torch
+from onnx_diagnostic.cache_helpers import make_dynamic_cache
+from onnx_diagnostic.helpers import string_type
+from onnx_diagnostic.export import ModelInputs
+
+
+class Model(torch.nn.Module):
+    def forward(self, x, y):
+        return x + y
+
+
+model = Model()
+x = torch.randn((5, 6))
+y = torch.randn((1, 6))
+model(x, y)  # to check it works
+
+ep = torch.export.export(model, (x, y))
+print(ep)
+
+# %%
+# As expected, there are no dynamic shapes.
+# We use :class:`onnx_diagnostic.export.ModelInputs`
+# to define them from two sets of valid inputs.
+# These inputs must have different values for the dynamic
+# dimensions.
+
+inputs = [(x, y), (torch.randn((7, 8)), torch.randn((1, 8)))]
+mi = ModelInputs(Model(), inputs)
+ds = mi.guess_dynamic_shapes()
+pprint.pprint(ds)
+
+# %%
+# The function returns a tuple with two objects.
+# The first one is for the positional arguments, the other one
+# for the named arguments. There are no named arguments, so
+# we use the first result to export.
+
+ep = torch.export.export(model, (x, y), dynamic_shapes=ds[0])
+print(ep)
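
The idea behind guessing dynamic shapes from two sets of valid inputs can be illustrated without torch. The sketch below is not the implementation of ``ModelInputs.guess_dynamic_shapes``; it only assumes that an axis is considered dynamic when its size differs between the observed inputs. The helper ``guess_dynamic_dims`` and the string marker ``DYNAMIC`` are hypothetical names introduced for illustration (the marker stands in for something like ``torch.export.Dim.DYNAMIC``).

```python
from typing import Dict, Sequence, Tuple

Shape = Tuple[int, ...]
# Hypothetical stand-in for a marker such as torch.export.Dim.DYNAMIC.
DYNAMIC = "DYNAMIC"


def guess_dynamic_dims(shapes: Sequence[Shape]) -> Dict[int, str]:
    """Maps every axis whose size differs across the observed shapes to DYNAMIC."""
    if len(shapes) < 2:
        raise ValueError("at least two sets of inputs are needed to see a dimension change")
    rank = len(shapes[0])
    if any(len(s) != rank for s in shapes):
        raise ValueError(f"all shapes must have the same rank, got {list(shapes)}")
    return {
        axis: DYNAMIC
        for axis in range(rank)
        if len({s[axis] for s in shapes}) > 1
    }


# Shapes of x and y in the two sets of inputs used in the example.
x_shapes = [(5, 6), (7, 8)]
y_shapes = [(1, 6), (1, 8)]
ds_args = (guess_dynamic_dims(x_shapes), guess_dynamic_dims(y_shapes))
print(ds_args)
```

The first axis of ``y`` equals 1 in both sets of inputs, so this sketch leaves it static. That is why the two sets of inputs must differ on every dimension that is meant to be dynamic.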
+
+# %%
+# kwargs
+# ++++++
+#
+# We do the same with named arguments.
+
+
+class Model(torch.nn.Module):
+    def forward(self, x, y):
+        return x + y
+
+
+model = Model()
+x = torch.randn((5, 6))
+y = torch.randn((1, 6))
+model(x=x, y=y)  # to check it works
+
+# %%
+# Two sets of valid inputs.
+inputs = [dict(x=x, y=y), dict(x=torch.randn((7, 8)), y=torch.randn((1, 8)))]
+mi = ModelInputs(Model(), inputs)
+ds = mi.guess_dynamic_shapes()
+pprint.pprint(ds)
+
+# %%
+# And we export.
+ep = torch.export.export(model, (), kwargs=dict(x=x, y=y), dynamic_shapes=ds[1])
+print(ep)
+
+# %%
+# args and kwargs
+# +++++++++++++++
+#
+# :func:`torch.export.export` does not like having dynamic shapes
+# for both args and kwargs. We need to define them using one mechanism.
+
+
+class Model(torch.nn.Module):
+    def forward(self, x, y):
+        return x + y
+
+
+model = Model()
+x = torch.randn((5, 6))
+y = torch.randn((1, 6))
+model(x, y=y)  # to check it works
+
+# %%
+# Two sets of valid inputs with positional and named arguments.
+
+inputs = [((x,), dict(y=y)), ((torch.randn((7, 8)),), dict(y=torch.randn((1, 8))))]
+mi = ModelInputs(Model(), inputs)
+ds = mi.guess_dynamic_shapes()
+pprint.pprint(ds)
+
+# %%
+# This does not work with :func:`torch.export.export` so
+# we use a method to move the positional dynamic shapes to
+# named ones. The method relies on the signature of the
+# forward method.
+
+new_args, new_kwargs, new_ds = mi.move_to_kwargs(*mi.inputs[0], ds)
+pprint.pprint(new_ds)
+
+# %%
+# And we export.
+
+ep = torch.export.export(model, new_args, kwargs=new_kwargs, dynamic_shapes=new_ds[1])
+print(ep)
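
Moving positional arguments to named ones "based on the signature of the forward method" can be sketched with ``inspect.signature``. This is an illustration under stated assumptions, not the code of ``ModelInputs.move_to_kwargs``: the helper ``move_positional_to_kwargs`` and its five-argument signature are hypothetical, and plain integers and dictionaries stand in for tensors and dynamic shapes.

```python
import inspect
from typing import Any, Callable, Dict, Tuple


def move_positional_to_kwargs(
    forward: Callable[..., Any],
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    ds_args: Tuple[Any, ...],
    ds_kwargs: Dict[str, Any],
) -> Tuple[Tuple[Any, ...], Dict[str, Any], Dict[str, Any]]:
    """Renames positional arguments (and their dynamic shapes) after the
    parameters of ``forward`` so that everything is passed by name."""
    if len(args) != len(ds_args):
        raise ValueError("one dynamic shape is expected per positional argument")
    params = list(inspect.signature(forward).parameters.values())
    if len(args) > len(params):
        raise ValueError("more positional arguments than parameters")
    new_kwargs = dict(kwargs)
    new_ds = dict(ds_kwargs)
    for param, value, shape in zip(params, args, ds_args):
        if param.kind is not inspect.Parameter.POSITIONAL_OR_KEYWORD:
            raise ValueError(f"parameter {param.name!r} cannot be passed by name")
        if param.name in new_kwargs:
            raise ValueError(f"parameter {param.name!r} is given twice")
        new_kwargs[param.name] = value
        new_ds[param.name] = shape
    # No positional argument is left once everything is named.
    return (), new_kwargs, new_ds


class Model:
    def forward(self, x, y):
        return x + y


# x is positional, y is named; integers stand in for tensors.
args, kwargs = (10,), {"y": 1}
ds_args, ds_kwargs = ({0: "DYNAMIC", 1: "DYNAMIC"},), {"y": {1: "DYNAMIC"}}
new_args, new_kwargs, new_ds = move_positional_to_kwargs(
    Model().forward, args, kwargs, ds_args, ds_kwargs
)
print(new_args, new_kwargs, new_ds)
```

Calling ``inspect.signature`` on the bound method skips ``self``, so the first positional value is matched with ``x``.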
+
+# %%
+# DynamicCache
+# ============
+#
+# :func:`torch.export.export` serializes caches and any custom class
+# if the serialization functions are provided, which is the case for
+# :class:`transformers.cache_utils.DynamicCache` and ``transformers>=4.50``.
+# The dynamic shapes must be provided following the serialized form.
+
+
+class Model(torch.nn.Module):
+    def forward(self, cache, z):
+        return (
+            z
+            + cache.key_cache[0]
+            + cache.key_cache[1]
+            + cache.value_cache[0]
+            + cache.value_cache[1]
+        )
+
+
+model = Model()
+
+n_layers = 2
+bsize, nheads, slen, dim = 2, 4, 3, 7
+cache = make_dynamic_cache(
+    [
+        (torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))
+        for i in range(n_layers)
+    ]
+)
+z = torch.randn((1, 1, 1, 7))
+model(cache, z)  # to check it works.
+
+# %%
+# The cache looks like this:
+
+print(string_type(cache, with_shape=True))
+
+
+# %% Let's create another set of inputs.
+
+cache2 = make_dynamic_cache(
+    [
+        (
+            torch.randn(bsize + 1, nheads, slen + 1, dim + 1),
+            torch.randn(bsize + 1, nheads, slen + 1, dim + 1),
+        )
+        for i in range(n_layers)
+    ]
+)
+inputs = [
+    (cache, z),
+    (cache2, torch.randn((1, 1, 1, 8))),
+]
+
+# %%
+# And the first set of inputs looks like:
+print(string_type(inputs[0], with_shape=True))
+
+# %%
+# We can now compute the dynamic shapes.
+
+mi = ModelInputs(Model(), inputs)
+ds = mi.guess_dynamic_shapes()
+pprint.pprint(ds)
+
+# %%
+# And finally the export.
+
+ep = torch.export.export(model, inputs[0], dynamic_shapes=ds[0], strict=False)
+print(ep)
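
Dynamic shapes that "follow the serialized form" have to mirror the nesting of the flattened inputs. The sketch below shows that idea on plain Python containers. It assumes, purely for illustration, that a cache flattens to a pair of lists (keys, then values) with one shape per layer; the actual serialized form of ``DynamicCache`` may differ. ``guess_nested`` and ``DYNAMIC`` are hypothetical names, not part of onnx-diagnostic.

```python
from typing import Any, Sequence

# Hypothetical stand-in for a marker such as torch.export.Dim.DYNAMIC.
DYNAMIC = "DYNAMIC"


def is_shape(obj: Any) -> bool:
    """A leaf is a tuple of integers, e.g. (2, 4, 3, 7)."""
    return isinstance(obj, tuple) and all(isinstance(d, int) for d in obj)


def guess_nested(samples: Sequence[Any]) -> Any:
    """Walks several samples sharing the same nesting and returns a structure
    with the same nesting where every leaf lists its changing axes."""
    first = samples[0]
    if is_shape(first):
        if any(not is_shape(s) or len(s) != len(first) for s in samples):
            raise ValueError("leaves must be shapes of the same rank")
        return {
            axis: DYNAMIC
            for axis in range(len(first))
            if len({s[axis] for s in samples}) > 1
        }
    if isinstance(first, list):
        if any(not isinstance(s, list) or len(s) != len(first) for s in samples):
            raise ValueError("lists must have the same length in every sample")
        return [guess_nested([s[i] for s in samples]) for i in range(len(first))]
    if isinstance(first, dict):
        return {k: guess_nested([s[k] for s in samples]) for k in first}
    raise TypeError(f"unexpected type {type(first)}")


# Shapes taken from the example: two layers, keys and values, then z.
k1, k2 = (2, 4, 3, 7), (3, 4, 4, 8)
sample1 = [[[k1, k1], [k1, k1]], (1, 1, 1, 7)]
sample2 = [[[k2, k2], [k2, k2]], (1, 1, 1, 8)]
nested_ds = guess_nested([sample1, sample2])
print(nested_ds)
```

The number of heads (axis 1) is identical in both samples, so it stays static, while batch size, sequence length and the last dimension are marked as dynamic for every tensor of the cache.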

_doc/examples/plot_export_with_dynamic_shapes_auto.py

Lines changed: 2 additions & 2 deletions
@@ -1,8 +1,8 @@
 """
 .. _l-plot-sxport-with-dynamio-shapes-auto:
 
-Use DYNAMIC or AUTO when dynamic shapes has constraints
-=======================================================
+Use DYNAMIC or AUTO when exporting if dynamic shapes has constraints
+====================================================================
 
 Setting the dynamic shapes is not always easy.
 Here are a few tricks to make it work.

_doc/index.rst

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@ Source are `sdpython/onnx-diagnostic
 
 * :ref:`l-plot-export-cond`
 * :ref:`l-plot-sxport-with-dynamio-shapes-auto`
+* :ref:`l-plot-export-with-dynamic-shape`
 * :ref:`l-plot-tiny-llm-export`
 * :ref:`l-plot-failing-reference-evaluator`
 * :ref:`l-plot-failing-onnxruntime-evaluator`
