Skip to content

Commit 3bc484c

Browse files
committed
examples checking about dynamic_shapes
1 parent 93b5939 commit 3bc484c

File tree

5 files changed

+681
-0
lines changed

5 files changed

+681
-0
lines changed
Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
"""
2+
.. _l-plot-tiny-llm-export-dim01:
3+
4+
Export with dynamic dimensions in ``{0,1}``
5+
===========================================
6+
7+
The first version of :func:`torch.export.export` did not support
8+
any tensor with a dimension equal to 0, 1 if the dimension was expected
9+
to be dynamic. The latest versions offers more options. Let's check it works.
10+
The experiments consists in exporting the model with different sets of inputs
11+
and checking the exported models works with all set of inputs.
12+
13+
Availabe input sets
14+
+++++++++++++++++++
15+
16+
"""
17+
18+
import itertools
19+
from tqdm import tqdm
20+
import numpy as np
21+
import pandas
22+
import torch
23+
from onnx_diagnostic import doc
24+
from onnx_diagnostic.helpers import max_diff, string_type
25+
from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
26+
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
27+
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
28+
from onnx_diagnostic.torch_export_patches import (
29+
torch_export_patches,
30+
register_additional_serialization_functions,
31+
)
32+
33+
34+
data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", add_second_input=True)
35+
model, dynamic_shapes = data["model"], data["dynamic_shapes"]
36+
37+
input_sets = {k: v for k, v in data.items() if k.startswith("inputs")}
38+
39+
for k, v in input_sets.items():
40+
print(f"{k:20}: {string_type(v, with_shape=True)}")
41+
42+
# %%
43+
# The dynamic shapes are:
44+
45+
print(f"dynamic_shapes: {string_type(dynamic_shapes)}")
46+
47+
# %% The exporter does not support strings.
48+
49+
dynamic_shapes = use_dyn_not_str(dynamic_shapes)
50+
print(f"dynamic_shapes: {string_type(dynamic_shapes)}")
51+
52+
# %%
53+
# Let's check they all work and compute the expected values.
54+
# We use deepcopy because caches are usually modified inplace.
55+
56+
expected = {}
57+
for k, v in input_sets.items():
58+
expected[k] = model(**torch_deepcopy(v))
59+
print(f"{k:20}: {string_type(expected[k], with_shape=True)}")
60+
61+
# %%
62+
# Export with options
63+
# +++++++++++++++++++
64+
#
65+
# We try to export with the following options:
66+
# - cache registration: register cache serialization with
67+
# :func:`onnx_diagnostic.torch_export_patches.register_additional_serialization_functions`
68+
# - oblivious: an option to remove some the exception raises by the exporter
69+
# - rt: see ``prefer_deferred_runtime_asserts_over_guards`` in :func:`torch.export.export`
70+
# - cache_patch: patches the model before exporting with
71+
# :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
72+
#
73+
# Some function first.
74+
75+
76+
def export_model(
77+
model, dynamic_shapes, inputs, cache=False, oblivious=False, rt=False, cache_patch=False
78+
):
79+
if cache and not cache_patch:
80+
with register_additional_serialization_functions(patch_transformers=True):
81+
return export_model(model, dynamic_shapes, inputs, oblivious=oblivious, rt=rt)
82+
if cache_patch:
83+
with torch_export_patches(patch_transformers=True):
84+
return export_model(model, dynamic_shapes, inputs, oblivious=oblivious, rt=rt)
85+
if oblivious:
86+
with torch.fx.experimental._config.patch(backed_size_oblivious=True):
87+
return export_model(model, dynamic_shapes, inputs, rt=rt)
88+
return torch.export.export(
89+
model,
90+
(),
91+
inputs,
92+
dynamic_shapes=dynamic_shapes,
93+
prefer_deferred_runtime_asserts_over_guards=rt,
94+
)
95+
96+
97+
def try_export_model(
98+
model, dynamic_shapes, inputs, cache=False, oblivious=False, rt=False, cache_patch=False
99+
):
100+
try:
101+
return export_model(
102+
model,
103+
dynamic_shapes,
104+
inputs,
105+
cache=cache,
106+
oblivious=oblivious,
107+
rt=rt,
108+
cache_patch=cache_patch,
109+
)
110+
except Exception as e:
111+
return e
112+
113+
114+
def validation(ep, input_sets, expected):
115+
mod = ep.module()
116+
for k, v in input_sets.items():
117+
try:
118+
got = mod(**torch_deepcopy(v))
119+
except Exception as e:
120+
yield k, e
121+
continue
122+
yield k, max_diff(expected[k], got, verbose=0)
123+
124+
125+
# %%
126+
# The main loop.
127+
128+
results = []
129+
130+
possibilities = [*[[0, 1] for _ in range(4)], list(input_sets)]
131+
with tqdm(list(itertools.product(*possibilities))) as pbar:
132+
for cache, cache_patch, oblivious, rt, inputs in pbar:
133+
if cache_patch and not cache:
134+
# patches include caches.
135+
continue
136+
kwargs = dict(cache=cache, cache_patch=cache_patch, oblivious=oblivious, rt=rt)
137+
legend = "-".join(k for k, v in kwargs.items() if v)
138+
legend = f"{legend}/{inputs}"
139+
pbar.set_description(f"{legend} EXPORT")
140+
141+
# export
142+
ep = try_export_model(
143+
model, dynamic_shapes, torch_deepcopy(input_sets[inputs]), **kwargs
144+
)
145+
if isinstance(ep, Exception):
146+
obs = {
147+
**kwargs,
148+
"export_with": inputs,
149+
"EXPORT": 0,
150+
"ERR-EXPORT": str(ep).split("\n")[0],
151+
}
152+
results.append(obs)
153+
continue
154+
155+
pbar.set_description(f"{legend} VALIDATE")
156+
common = {**kwargs, "export_with": inputs, "EXPORT": 1}
157+
for inp, res in validation(ep, input_sets, expected):
158+
if isinstance(res, Exception):
159+
obs = {
160+
**common,
161+
"run_with": inp,
162+
"ERR-RUN": str(res).split("\n")[0],
163+
"WORKS": 0,
164+
}
165+
else:
166+
obs = {
167+
**common,
168+
"run_with": inp,
169+
"WORKS": int(~np.isnan(res["abs"]) and res["abs"] < 1e-3),
170+
}
171+
results.append(obs)
172+
173+
# %%
174+
# Let's save the results.
175+
176+
df = pandas.DataFrame(results)
177+
df.to_excel("plot_export_tiny_llm_dim01.xlsx")
178+
df
179+
180+
# %% The export failures.
181+
182+
no_export = df[df.EXPORT == 0]
183+
no_export.to_excel("plot_export_tiny_llm_dim01.no_export.xlsx")
184+
no_export
185+
186+
# %%
187+
# The validation failures.
188+
189+
invalid = df[(df.EXPORT == 1) & (df.WORKS == 0)].pivot(
190+
index=["cache", "cache_patch", "oblivious", "rt", "export_with"],
191+
columns=["run_with"],
192+
values=["WORKS", "ERR-RUN"],
193+
)
194+
invalid.to_excel("plot_export_tiny_llm_dim01.invalid.xlsx")
195+
invalid
196+
197+
# %% Successes.
198+
199+
success = df[(df.EXPORT == 1) & (df.WORKS == 1)].pivot(
200+
index=["cache", "cache_patch", "oblivious", "rt", "export_with"],
201+
columns=["run_with"],
202+
values=["WORKS"],
203+
)
204+
success.to_excel("plot_export_tiny_llm_dim01.success.xlsx")
205+
success
206+
207+
208+
# %%
209+
# If you have any error, then look at example
210+
# :ref:`l-plot-tiny-llm-export-cache_patched`.
211+
212+
doc.plot_legend("Tiny-LLM\nexport with\ndimension in {0,1}", "torch.export.export", "tomato")

0 commit comments

Comments
 (0)