Skip to content

Commit 93b00d9

Browse files
committed
changes
1 parent c44be6a commit 93b00d9

File tree

9 files changed

+278
-160
lines changed

9 files changed

+278
-160
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.3.0
55
+++++
66

7+
* :pr:`38`, uses the registered serialization functions when it is available
78
* :pr:`30`, :pr:`31`: adds command to test a model id, validate the export
89
* :pr:`29`: adds helpers to measure the memory peak and run benchmark
910
on different processes
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
"""
2+
.. _l-plot-export-with-args-kwargs:
3+
4+
==================================
5+
Dynamic Shapes for *args, **kwargs
6+
==================================
7+
8+
Quick tour of dynamic shapes.
9+
10+
Simple Examples
11+
===============
12+
13+
We first look at examples playing positional and names parameters
14+
to understand how :func:`torch.export.export` works.
15+
16+
args
17+
++++
18+
"""
19+
20+
import pprint
21+
import torch
22+
from onnx_diagnostic import doc
23+
from onnx_diagnostic.export import ModelInputs
24+
25+
26+
class Model(torch.nn.Module):
27+
def forward(self, x, y):
28+
return x + y
29+
30+
31+
model = Model()
32+
x = torch.randn((5, 6))
33+
y = torch.randn((1, 6))
34+
model(x, y) # to check it works
35+
36+
ep = torch.export.export(model, (x, y))
37+
print(ep)
38+
39+
# %%
40+
# As expected there is no dynamic shapes.
41+
# We use :class:`onnx_diagnostic.export.ModelInputs`
42+
# to define them from two set of valid inputs.
43+
# These inputs must have different value for the dynamic
44+
# dimensions.
45+
46+
inputs = [(x, y), (torch.randn((7, 8)), torch.randn((1, 8)))]
47+
mi = ModelInputs(Model(), inputs)
48+
ds = mi.guess_dynamic_shapes()
49+
pprint.pprint(ds)
50+
51+
# %%
52+
# The function returns a tuple with two objects.
53+
# The first one for the positional arguments, the other one
54+
# for the named arguments. There is no named arguments. We
55+
# we used the first result to export.
56+
57+
ep = torch.export.export(model, (x, y), dynamic_shapes=ds[0])
58+
print(ep)
59+
60+
# %%
61+
# kwargs
62+
# ++++++
63+
#
64+
# We do the same with named arguments.
65+
66+
67+
class Model(torch.nn.Module):
68+
def forward(self, x, y):
69+
return x + y
70+
71+
72+
model = Model()
73+
x = torch.randn((5, 6))
74+
y = torch.randn((1, 6))
75+
model(x=x, y=y) # to check it works
76+
77+
# %%
78+
# Two sets of valid inputs.
79+
inputs = [dict(x=x, y=y), dict(x=torch.randn((7, 8)), y=torch.randn((1, 8)))]
80+
mi = ModelInputs(Model(), inputs)
81+
ds = mi.guess_dynamic_shapes()
82+
pprint.pprint(ds)
83+
84+
# %%
85+
# And we export.
86+
ep = torch.export.export(model, (), kwargs=dict(x=x, y=y), dynamic_shapes=ds[1])
87+
print(ep)
88+
89+
# %%
90+
# args and kwargs
91+
# +++++++++++++++
92+
#
93+
# :func:`torch.export.export` does not like having dynami shapes
94+
# for both args and kwargs. We need to define them using one mechanism.
95+
96+
97+
class Model(torch.nn.Module):
98+
def forward(self, x, y):
99+
return x + y
100+
101+
102+
model = Model()
103+
x = torch.randn((5, 6))
104+
y = torch.randn((1, 6))
105+
model(x, y=y) # to check it works
106+
107+
# %%
108+
# Two sets of valid inputs with positional and names arguments.
109+
110+
inputs = [((x,), dict(y=y)), ((torch.randn((7, 8)),), dict(y=torch.randn((1, 8))))]
111+
mi = ModelInputs(Model(), inputs)
112+
ds = mi.guess_dynamic_shapes()
113+
pprint.pprint(ds)
114+
115+
# %%
116+
# This does not work with :func:`torch.export.export` so
117+
# we use a method to move the positional dynamic shapes to
118+
# named one. The method relies on the signature of the
119+
# forward method.
120+
121+
new_args, new_kwargs, new_ds = mi.move_to_kwargs(*mi.inputs[0], ds)
122+
pprint.pprint(new_ds)
123+
124+
# %%
125+
# And we export.
126+
127+
ep = torch.export.export(model, new_args, kwargs=new_kwargs, dynamic_shapes=new_ds[1])
128+
print(ep)
129+
130+
# %%
131+
132+
doc.plot_legend("dynamic shapes\n*args, **kwargs", "torch.export.export", "tomato")

_doc/examples/plot_export_with_dynamic_shapes_auto.py renamed to _doc/examples/plot_export_with_auto.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
.. _l-plot-sxport-with-dynamio-shapes-auto:
2+
.. _l-plot-sxport-with-auto:
33
44
Use DYNAMIC or AUTO when exporting if dynamic shapes has constraints
55
====================================================================
@@ -94,4 +94,4 @@ def forward(self, x, y, z):
9494

9595
# %%
9696

97-
doc.plot_legend("dynamic shapes\ninferred", "torch.export.export", "tomato")
97+
doc.plot_legend("torch.export.Dim\nor DYNAMIC\nor AUTO", "torch.export.export", "tomato")

_doc/examples/plot_export_with_dynamic_cache.py

Lines changed: 40 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -11,144 +11,28 @@
1111
The example shows a tool which determines the dynamic shapes
1212
for :func:`torch.export.export` based on a set of valid inputs.
1313
14-
Simple Examples
15-
===============
14+
DynamicCache
15+
============
1616
17-
We first look at examples playing positional and names parameters
18-
to understand how :func:`torch.export.export` works.
19-
20-
args
21-
++++
17+
:func:`torch.export.export` serializes caches and any custom class
18+
if these serialization functions are provided with is the case for
19+
:class:`transformers.cache_utils.DynamicCache` and ``transformers>=4.50``.
20+
The dynamic shapes must be provided following the serialized form.
2221
"""
2322

2423
import pprint
2524
import torch
2625
from onnx_diagnostic import doc
27-
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
26+
from onnx_diagnostic.ext_test_case import has_transformers
2827
from onnx_diagnostic.helpers import string_type
28+
from onnx_diagnostic.helpers.cache_helper import (
29+
flatten_unflatten_for_dynamic_shapes,
30+
make_dynamic_cache,
31+
)
2932
from onnx_diagnostic.export import ModelInputs
30-
31-
# %%
32-
# We need addition import in case ``transformers<4.50``.
33-
# Exporting DynamicCache is not supported before that.
34-
from onnx_diagnostic.ext_test_case import has_transformers
3533
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
3634

3735

38-
class Model(torch.nn.Module):
39-
def forward(self, x, y):
40-
return x + y
41-
42-
43-
model = Model()
44-
x = torch.randn((5, 6))
45-
y = torch.randn((1, 6))
46-
model(x, y) # to check it works
47-
48-
ep = torch.export.export(model, (x, y))
49-
print(ep)
50-
51-
# %%
52-
# As expected there is no dynamic shapes.
53-
# We use :class:`onnx_diagnostic.export.ModelInputs`
54-
# to define them from two set of valid inputs.
55-
# These inputs must have different value for the dynamic
56-
# dimensions.
57-
58-
inputs = [(x, y), (torch.randn((7, 8)), torch.randn((1, 8)))]
59-
mi = ModelInputs(Model(), inputs)
60-
ds = mi.guess_dynamic_shapes()
61-
pprint.pprint(ds)
62-
63-
# %%
64-
# The function returns a tuple with two objects.
65-
# The first one for the positional arguments, the other one
66-
# for the named arguments. There is no named arguments. We
67-
# we used the first result to export.
68-
69-
ep = torch.export.export(model, (x, y), dynamic_shapes=ds[0])
70-
print(ep)
71-
72-
# %%
73-
# kwargs
74-
# ++++++
75-
#
76-
# We do the same with named arguments.
77-
78-
79-
class Model(torch.nn.Module):
80-
def forward(self, x, y):
81-
return x + y
82-
83-
84-
model = Model()
85-
x = torch.randn((5, 6))
86-
y = torch.randn((1, 6))
87-
model(x=x, y=y) # to check it works
88-
89-
# %%
90-
# Two sets of valid inputs.
91-
inputs = [dict(x=x, y=y), dict(x=torch.randn((7, 8)), y=torch.randn((1, 8)))]
92-
mi = ModelInputs(Model(), inputs)
93-
ds = mi.guess_dynamic_shapes()
94-
pprint.pprint(ds)
95-
96-
# %%
97-
# And we export.
98-
ep = torch.export.export(model, (), kwargs=dict(x=x, y=y), dynamic_shapes=ds[1])
99-
print(ep)
100-
101-
# %%
102-
# args and kwargs
103-
# +++++++++++++++
104-
#
105-
# :func:`torch.export.export` does not like having dynami shapes
106-
# for both args and kwargs. We need to define them using one mechanism.
107-
108-
109-
class Model(torch.nn.Module):
110-
def forward(self, x, y):
111-
return x + y
112-
113-
114-
model = Model()
115-
x = torch.randn((5, 6))
116-
y = torch.randn((1, 6))
117-
model(x, y=y) # to check it works
118-
119-
# %%
120-
# Two sets of valid inputs with positional and names arguments.
121-
122-
inputs = [((x,), dict(y=y)), ((torch.randn((7, 8)),), dict(y=torch.randn((1, 8))))]
123-
mi = ModelInputs(Model(), inputs)
124-
ds = mi.guess_dynamic_shapes()
125-
pprint.pprint(ds)
126-
127-
# %%
128-
# This does not work with :func:`torch.export.export` so
129-
# we use a method to move the positional dynamic shapes to
130-
# named one. The method relies on the signature of the
131-
# forward method.
132-
133-
new_args, new_kwargs, new_ds = mi.move_to_kwargs(*mi.inputs[0], ds)
134-
pprint.pprint(new_ds)
135-
136-
# %%
137-
# And we export.
138-
139-
ep = torch.export.export(model, new_args, kwargs=new_kwargs, dynamic_shapes=new_ds[1])
140-
print(ep)
141-
142-
# %%
143-
# DynamicCache
144-
# ============
145-
#
146-
# :func:`torch.export.export` serializes caches and any custom class
147-
# if these serialization functions are provided with is the case for
148-
# :class:`transformers.cache_utils.DynamicCache` and ``transformers>=4.50``.
149-
# The dynamic shapes must be provided following the serialized form.
150-
151-
15236
class Model(torch.nn.Module):
15337
def forward(self, cache, z):
15438
return (
@@ -196,14 +80,19 @@ def forward(self, cache, z):
19680
]
19781

19882
# %%
199-
# And the first set of inputs looks like:
200-
print(string_type(inputs[0], with_shape=True))
83+
# And the second set of inputs looks like:
84+
print(string_type(inputs[1], with_shape=True))
20185

20286
# %%
203-
# We can now compute the dynamic shapes.
87+
# Guess the dynamic shapes
88+
# ========================
89+
#
90+
# The following tool can be used to guess the dynamic shapes
91+
# the way :func:`torch.export.export` expects them.
20492

20593
mi = ModelInputs(Model(), inputs)
20694
ds = mi.guess_dynamic_shapes()
95+
20796
pprint.pprint(ds)
20897

20998
# %%
@@ -223,6 +112,26 @@ def forward(self, cache, z):
223112
)
224113
print(ep)
225114

115+
# Do we need to guess?
116+
# ++++++++++++++++++++
117+
#
118+
# Function :func:`onnx_diagnostic.helpers.string_type` is using
119+
# the serialization functions to print out the DynamicCache the was
120+
# :func:`torch.export.export` expects them.
121+
122+
print(string_type(cache, with_shape=True))
123+
226124
# %%
125+
# You can also use function
126+
# :func:`onnx_diagnostic.helpers.cache_helper.flatten_unflatten_for_dynamic_shapes`
127+
# to show a DynamicCache restructured the way :func:`torch.export.export` expects
128+
# it to be without the custom class.
129+
130+
print(string_type(flatten_unflatten_for_dynamic_shapes(cache), with_shape=True))
131+
132+
# %%
133+
# This code works for any custom class if it was registered
134+
# with :func:`torch.utils._pytree.register_pytree_node`.
135+
227136

228-
doc.plot_legend("dynamic shapes\nfor cache", "torch.export.export", "tomato")
137+
doc.plot_legend("dynamic shapes\nfor DynamicCache", "torch.export.export", "tomato")

_doc/index.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ Enlightening Examples
6060
**Torch Export**
6161

6262
* :ref:`l-plot-export-cond`
63-
* :ref:`l-plot-sxport-with-dynamio-shapes-auto`
63+
* :ref:`l-plot-sxport-with-auto`
6464
* :ref:`l-plot-export-with-dynamic-shape`
6565
* :ref:`l-plot-export-locale-issue`
6666
* :ref:`l-plot-tiny-llm-export`

_unittests/ut_export/test_dynamic_shapes.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,21 @@ def forward(self, x, y):
4040
ds = mi.guess_dynamic_shapes()
4141
self.assertEqual(ds, ((), {}))
4242

43+
def test_guess_dynamic_shapes_auto(self):
44+
class Model(torch.nn.Module):
45+
def forward(self, x, y):
46+
return x + y
47+
48+
model = Model()
49+
x = torch.randn((5, 6))
50+
y = torch.randn((1, 6))
51+
model(x, y)
52+
self.assertNotEmpty(y)
53+
54+
mi = ModelInputs(Model(), [])
55+
ds = mi.guess_dynamic_shapes(auto=True)
56+
self.assertEqual(ds, ((), {}))
57+
4358
def test_guess_dynamic_shapes_1args(self):
4459
class Model(torch.nn.Module):
4560
def forward(self, x, y):

0 commit comments

Comments
 (0)