Skip to content

Commit 6d6dd3e

Browse files
committed
cache
1 parent 2ed97a5 commit 6d6dd3e

File tree

4 files changed

+212
-3
lines changed

4 files changed

+212
-3
lines changed

README.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,10 @@ or
4545
4646
**Enlightening Examples**
4747
48-
* `Use DYNAMIC or AUTO when dynamic shapes has constraints
48+
* `Use DYNAMIC or AUTO when exporting if dynamic shapes has constraints
4949
<https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_with_dynamic_shapes_auto.html>`_
50+
* `Export with DynamicCache and dynamic shapes
51+
<https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_with_dynamic_cache.html>`_
5052
* `Steal method forward to guess the dynamic shapes
5153
<https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_tiny_llm.html>`_
5254
* `Running ReferenceEvaluator on a failing model
Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
"""
2+
.. _l-plot-export-with-dynamic-shape:
3+
4+
===========================================
5+
Export with DynamicCache and dynamic shapes
6+
===========================================
7+
8+
Every LLM implemented in :epkg:`transformers` uses a cache.
9+
One of the most used is :class:`transformers.cache_utils.DynamicCache`.
10+
The cache size is dynamic to cope with the growing context.
11+
The example shows a tool which determines the dynamic shapes
12+
for :func:`torch.export.export` based on a set of valid inputs.
13+
14+
Simple Examples
15+
===============
16+
17+
We first look at examples playing with positional and named parameters
18+
to understand how :func:`torch.export.export` works.
19+
20+
args
21+
++++
22+
"""
23+
24+
import pprint
25+
import torch
26+
from onnx_diagnostic.cache_helpers import make_dynamic_cache
27+
from onnx_diagnostic.helpers import string_type
28+
from onnx_diagnostic.export import ModelInputs
29+
30+
31+
class Model(torch.nn.Module):
32+
def forward(self, x, y):
33+
return x + y
34+
35+
36+
model = Model()
37+
x = torch.randn((5, 6))
38+
y = torch.randn((1, 6))
39+
model(x, y) # to check it works
40+
41+
ep = torch.export.export(model, (x, y))
42+
print(ep)
43+
44+
# %%
45+
# As expected there are no dynamic shapes.
46+
# We use :class:`onnx_diagnostic.export.ModelInputs`
47+
# to define them from two sets of valid inputs.
48+
# These inputs must have different values for the dynamic
49+
# dimensions.
50+
51+
inputs = [(x, y), (torch.randn((7, 8)), torch.randn((1, 8)))]
52+
mi = ModelInputs(Model(), inputs)
53+
ds = mi.guess_dynamic_shapes()
54+
pprint.pprint(ds)
55+
56+
# %%
57+
# The function returns a tuple with two objects.
58+
# The first one for the positional arguments, the other one
59+
# for the named arguments. There are no named arguments here, so
60+
# we use the first result to export.
61+
62+
ep = torch.export.export(model, (x, y), dynamic_shapes=ds[0])
63+
print(ep)
64+
65+
# %%
66+
# kwargs
67+
# ++++++
68+
#
69+
# We do the same with named arguments.
70+
71+
72+
class Model(torch.nn.Module):
73+
def forward(self, x, y):
74+
return x + y
75+
76+
77+
model = Model()
78+
x = torch.randn((5, 6))
79+
y = torch.randn((1, 6))
80+
model(x=x, y=y) # to check it works
81+
82+
# %%
83+
# Two sets of valid inputs.
84+
inputs = [dict(x=x, y=y), dict(x=torch.randn((7, 8)), y=torch.randn((1, 8)))]
85+
mi = ModelInputs(Model(), inputs)
86+
ds = mi.guess_dynamic_shapes()
87+
pprint.pprint(ds)
88+
89+
# %%
90+
# And we export.
91+
ep = torch.export.export(model, (), kwargs=dict(x=x, y=y), dynamic_shapes=ds[1])
92+
print(ep)
93+
94+
# %%
95+
# args and kwargs
96+
# +++++++++++++++
97+
#
98+
# :func:`torch.export.export` does not like having dynamic shapes
99+
# for both args and kwargs. We need to define them using one mechanism.
100+
101+
102+
class Model(torch.nn.Module):
103+
def forward(self, x, y):
104+
return x + y
105+
106+
107+
model = Model()
108+
x = torch.randn((5, 6))
109+
y = torch.randn((1, 6))
110+
model(x, y=y) # to check it works
111+
112+
# %%
113+
# Two sets of valid inputs with positional and named arguments.
114+
115+
inputs = [((x,), dict(y=y)), ((torch.randn((7, 8)),), dict(y=torch.randn((1, 8))))]
116+
mi = ModelInputs(Model(), inputs)
117+
ds = mi.guess_dynamic_shapes()
118+
pprint.pprint(ds)
119+
120+
# %%
121+
# This does not work with :func:`torch.export.export` so
122+
# we use a method to move the positional dynamic shapes to
123+
# named ones. The method relies on the signature of the
124+
# forward method.
125+
126+
new_args, new_kwargs, new_ds = mi.move_to_kwargs(*mi.inputs[0], ds)
127+
pprint.pprint(new_ds)
128+
129+
# %%
130+
# And we export.
131+
132+
ep = torch.export.export(model, new_args, kwargs=new_kwargs, dynamic_shapes=new_ds[1])
133+
print(ep)
134+
135+
# %%
136+
# DynamicCache
137+
# ============
138+
#
139+
# :func:`torch.export.export` serializes caches and any custom class
140+
# if these serialization functions are provided, which is the case for
141+
# :class:`transformers.cache_utils.DynamicCache` and ``transformers>=4.50``.
142+
# The dynamic shapes must be provided following the serialized form.
143+
144+
145+
class Model(torch.nn.Module):
146+
def forward(self, cache, z):
147+
return (
148+
z
149+
+ cache.key_cache[0]
150+
+ cache.key_cache[1]
151+
+ cache.value_cache[0]
152+
+ cache.value_cache[1]
153+
)
154+
155+
156+
model = Model()
157+
158+
n_layers = 2
159+
bsize, nheads, slen, dim = 2, 4, 3, 7
160+
cache = make_dynamic_cache(
161+
[
162+
(torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))
163+
for i in range(n_layers)
164+
]
165+
)
166+
z = torch.randn((1, 1, 1, 7))
167+
model(cache, z) # to check it works.
168+
169+
# %%
170+
# The cache looks like this:
171+
172+
print(string_type(cache, with_shape=True))
173+
174+
175+
# %% Let's create another set of inputs.
176+
177+
cache2 = make_dynamic_cache(
178+
[
179+
(
180+
torch.randn(bsize + 1, nheads, slen + 1, dim + 1),
181+
torch.randn(bsize + 1, nheads, slen + 1, dim + 1),
182+
)
183+
for i in range(n_layers)
184+
]
185+
)
186+
inputs = [
187+
(cache, z),
188+
(cache2, torch.randn((1, 1, 1, 8))),
189+
]
190+
191+
# %%
192+
# And the first set of inputs looks like:
193+
print(string_type(inputs[0], with_shape=True))
194+
195+
# %%
196+
# We can now compute the dynamic shapes.
197+
198+
mi = ModelInputs(Model(), inputs)
199+
ds = mi.guess_dynamic_shapes()
200+
pprint.pprint(ds)
201+
202+
# %%
203+
# And finally the export.
204+
205+
ep = torch.export.export(model, inputs[0], dynamic_shapes=ds[0], strict=False)
206+
print(ep)

_doc/examples/plot_export_with_dynamic_shapes_auto.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
"""
22
.. _l-plot-sxport-with-dynamio-shapes-auto:
33
4-
Use DYNAMIC or AUTO when dynamic shapes has constraints
5-
=======================================================
4+
Use DYNAMIC or AUTO when exporting if dynamic shapes has constraints
5+
====================================================================
66
77
Setting the dynamic shapes is not always easy.
88
Here are a few tricks to make it work.

_doc/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ Source are `sdpython/onnx-diagnostic
4242

4343
* :ref:`l-plot-export-cond`
4444
* :ref:`l-plot-sxport-with-dynamio-shapes-auto`
45+
* :ref:`l-plot-export-with-dynamic-shape`
4546
* :ref:`l-plot-tiny-llm-export`
4647
* :ref:`l-plot-failing-reference-evaluator`
4748
* :ref:`l-plot-failing-onnxruntime-evaluator`

0 commit comments

Comments
 (0)