1111The example shows a tool which determines the dynamic shapes
1212for :func:`torch.export.export` based on a set of valid inputs.
1313
14- Simple Examples
15- ===============
14+ DynamicCache
15+ ============
1616
17- We first look at examples using positional and named parameters
18- to understand how :func:`torch.export.export` works.
19-
20- args
21- ++++
17+ :func:`torch.export.export` serializes caches and any custom class
18+ if their serialization functions are provided, which is the case for
19+ :class:`transformers.cache_utils.DynamicCache` and ``transformers>=4.50``.
20+ The dynamic shapes must be provided following the serialized form.
2221"""
2322
2423import pprint
2524import torch
2625from onnx_diagnostic import doc
27- from onnx_diagnostic .helpers . cache_helper import make_dynamic_cache
26+ from onnx_diagnostic .ext_test_case import has_transformers
2827from onnx_diagnostic .helpers import string_type
28+ from onnx_diagnostic .helpers .cache_helper import (
29+ flatten_unflatten_for_dynamic_shapes ,
30+ make_dynamic_cache ,
31+ )
2932from onnx_diagnostic .export import ModelInputs
30-
31- # %%
32- # We need an additional import in case ``transformers<4.50``.
33- # Exporting DynamicCache is not supported before that.
34- from onnx_diagnostic .ext_test_case import has_transformers
3533from onnx_diagnostic .torch_export_patches import bypass_export_some_errors
3634
3735
38- class Model (torch .nn .Module ):
39- def forward (self , x , y ):
40- return x + y
41-
42-
43- model = Model ()
44- x = torch .randn ((5 , 6 ))
45- y = torch .randn ((1 , 6 ))
46- model (x , y ) # to check it works
47-
48- ep = torch .export .export (model , (x , y ))
49- print (ep )
50-
51- # %%
52- # As expected there is no dynamic shapes.
53- # We use :class:`onnx_diagnostic.export.ModelInputs`
54- # to define them from two set of valid inputs.
55- # These inputs must have different value for the dynamic
56- # dimensions.
57-
58- inputs = [(x , y ), (torch .randn ((7 , 8 )), torch .randn ((1 , 8 )))]
59- mi = ModelInputs (Model (), inputs )
60- ds = mi .guess_dynamic_shapes ()
61- pprint .pprint (ds )
62-
63- # %%
64- # The function returns a tuple with two objects.
65- # The first one for the positional arguments, the other one
66- # for the named arguments. There are no named arguments.
67- # We use the first result to export.
68-
69- ep = torch .export .export (model , (x , y ), dynamic_shapes = ds [0 ])
70- print (ep )
71-
72- # %%
73- # kwargs
74- # ++++++
75- #
76- # We do the same with named arguments.
77-
78-
79- class Model (torch .nn .Module ):
80- def forward (self , x , y ):
81- return x + y
82-
83-
84- model = Model ()
85- x = torch .randn ((5 , 6 ))
86- y = torch .randn ((1 , 6 ))
87- model (x = x , y = y ) # to check it works
88-
89- # %%
90- # Two sets of valid inputs.
91- inputs = [dict (x = x , y = y ), dict (x = torch .randn ((7 , 8 )), y = torch .randn ((1 , 8 )))]
92- mi = ModelInputs (Model (), inputs )
93- ds = mi .guess_dynamic_shapes ()
94- pprint .pprint (ds )
95-
96- # %%
97- # And we export.
98- ep = torch .export .export (model , (), kwargs = dict (x = x , y = y ), dynamic_shapes = ds [1 ])
99- print (ep )
100-
101- # %%
102- # args and kwargs
103- # +++++++++++++++
104- #
105- # :func:`torch.export.export` does not like having dynamic shapes
106- # for both args and kwargs. We need to define them using one mechanism.
107-
108-
109- class Model (torch .nn .Module ):
110- def forward (self , x , y ):
111- return x + y
112-
113-
114- model = Model ()
115- x = torch .randn ((5 , 6 ))
116- y = torch .randn ((1 , 6 ))
117- model (x , y = y ) # to check it works
118-
119- # %%
120- # Two sets of valid inputs with positional and named arguments.
121-
122- inputs = [((x ,), dict (y = y )), ((torch .randn ((7 , 8 )),), dict (y = torch .randn ((1 , 8 ))))]
123- mi = ModelInputs (Model (), inputs )
124- ds = mi .guess_dynamic_shapes ()
125- pprint .pprint (ds )
126-
127- # %%
128- # This does not work with :func:`torch.export.export` so
129- # we use a method to move the positional dynamic shapes to
130- # named ones. The method relies on the signature of the
131- # forward method.
132-
133- new_args , new_kwargs , new_ds = mi .move_to_kwargs (* mi .inputs [0 ], ds )
134- pprint .pprint (new_ds )
135-
136- # %%
137- # And we export.
138-
139- ep = torch .export .export (model , new_args , kwargs = new_kwargs , dynamic_shapes = new_ds [1 ])
140- print (ep )
141-
142- # %%
143- # DynamicCache
144- # ============
145- #
146- # :func:`torch.export.export` serializes caches and any custom class
147- # if their serialization functions are provided, which is the case for
148- # :class:`transformers.cache_utils.DynamicCache` and ``transformers>=4.50``.
149- # The dynamic shapes must be provided following the serialized form.
150-
151-
15236class Model (torch .nn .Module ):
15337 def forward (self , cache , z ):
15438 return (
@@ -196,14 +80,19 @@ def forward(self, cache, z):
19680]
19781
19882# %%
199- # And the first set of inputs looks like:
200- print (string_type (inputs [0 ], with_shape = True ))
83+ # And the second set of inputs looks like:
84+ print (string_type (inputs [1 ], with_shape = True ))
20185
20286# %%
203- # We can now compute the dynamic shapes.
87+ # Guess the dynamic shapes
88+ # ========================
89+ #
90+ # The following tool can be used to guess the dynamic shapes
91+ # the way :func:`torch.export.export` expects them.
20492
20593mi = ModelInputs (Model (), inputs )
20694ds = mi .guess_dynamic_shapes ()
95+
20796pprint .pprint (ds )
20897
20998# %%
@@ -223,6 +112,26 @@ def forward(self, cache, z):
223112 )
224113print (ep )
225114
115+ # Do we need to guess?
116+ # ++++++++++++++++++++
117+ #
118+ # Function :func:`onnx_diagnostic.helpers.string_type` is using
119+ # the serialization functions to print out the DynamicCache the way
120+ # :func:`torch.export.export` expects them.
121+
122+ print (string_type (cache , with_shape = True ))
123+
226124# %%
125+ # You can also use function
126+ # :func:`onnx_diagnostic.helpers.cache_helper.flatten_unflatten_for_dynamic_shapes`
127+ # to show a DynamicCache restructured the way :func:`torch.export.export` expects
128+ # it to be without the custom class.
129+
130+ print (string_type (flatten_unflatten_for_dynamic_shapes (cache ), with_shape = True ))
131+
132+ # %%
133+ # This code works for any custom class if it was registered
134+ # with :func:`torch.utils._pytree.register_pytree_node`.
135+
227136
228- doc .plot_legend ("dynamic shapes\n for cache " , "torch.export.export" , "tomato" )
137+ doc .plot_legend ("dynamic shapes\n for DynamicCache " , "torch.export.export" , "tomato" )
0 commit comments