Skip to content

Commit 651a43e

Browse files
authored
adds examples checking about dynamic_shapes (#254)
* examples checking about dynamic_shapes * add more comments * spell * adds more options
1 parent 93b5939 commit 651a43e

File tree

6 files changed

+736
-1
lines changed

6 files changed

+736
-1
lines changed

_doc/conf.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,11 +146,12 @@ def linkcode_resolve(domain, info):
146146
("py:class", "transformers.cache_utils.EncoderDecoderCache"),
147147
("py:class", "transformers.cache_utils.HybridCache"),
148148
("py:class", "transformers.cache_utils.MambaCache"),
149-
("py:class", "transformers.models.mamba.modeling_mamba.MambaCache"),
150149
("py:class", "transformers.cache_utils.SlidingWindowCache"),
151150
("py:class", "transformers.cache_utils.StaticCache"),
152151
("py:class", "transformers.configuration_utils.PretrainedConfig"),
152+
("py:class", "transformers.configuration_utils.PreTrainedConfig"),
153153
("py:class", "transformers.modeling_outputs.BaseModelOutput"),
154+
("py:class", "transformers.models.mamba.modeling_mamba.MambaCache"),
154155
("py:class", "transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding"),
155156
("py:func", "torch.export._draft_export.draft_export"),
156157
("py:func", "torch._export.tools.report_exportability"),
Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
"""
2+
.. _l-plot-tiny-llm-export-dim01:
3+
4+
Export with dynamic dimensions in ``{0,1}``
5+
===========================================
6+
7+
The first version of :func:`torch.export.export` did not support
8+
any tensor with a dimension equal to 0 or 1 if that dimension was expected
9+
to be dynamic. The latest versions offer more options. Let's check that it works.
10+
The experiment consists in exporting the model with different sets of inputs
11+
and checking that every exported model works with all sets of inputs.
12+
13+
Available input sets
14+
++++++++++++++++++++
15+
16+
"""
17+
18+
import itertools
19+
from tqdm import tqdm
20+
import numpy as np
21+
import pandas
22+
import torch
23+
from onnx_diagnostic import doc
24+
from onnx_diagnostic.helpers import max_diff, string_type
25+
from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
26+
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
27+
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
28+
from onnx_diagnostic.torch_export_patches import (
29+
torch_export_patches,
30+
register_additional_serialization_functions,
31+
)
32+
33+
34+
data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", add_second_input=True)
model = data["model"]
dynamic_shapes = data["dynamic_shapes"]

# %%
# The trained model can be obtained with:
#
# .. code-block:: python
#
#     MODEL_NAME = "arnir0/Tiny-LLM"
#     tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
#     model = transformers.AutoModelForCausalLM.from_pretrained(MODEL_NAME)

# Every entry of ``data`` whose key starts with "inputs" is one input set.
input_sets = {name: value for name, value in data.items() if name.startswith("inputs")}

for name, value in input_sets.items():
    print(f"{name:20}: {string_type(value, with_shape=True)}")

# %%
# The dynamic shapes are:

print(f"dynamic_shapes: {string_type(dynamic_shapes)}")

# %% The exporter does not support strings.

dynamic_shapes = use_dyn_not_str(dynamic_shapes)
print(f"dynamic_shapes: {string_type(dynamic_shapes)}")

# %%
# Let's check they all work and compute the expected values.
# We use deepcopy because caches are usually modified inplace.

expected = {}
for name, value in input_sets.items():
    # The model receives a copy so that the stored input set stays untouched.
    outputs = model(**torch_deepcopy(value))
    expected[name] = outputs
    print(f"{name:20}: {string_type(outputs, with_shape=True)}")
69+
70+
# %%
71+
# Export with options
72+
# +++++++++++++++++++
73+
#
74+
# We try to export with the following options:
#
75+
# - cache registration: register cache serialization with
76+
#   :func:`onnx_diagnostic.torch_export_patches.register_additional_serialization_functions`
77+
# - oblivious: an option to remove some of the exceptions raised by the exporter
78+
# - rt: see ``prefer_deferred_runtime_asserts_over_guards`` in :func:`torch.export.export`
79+
# - cache_patch: patches the model before exporting with
80+
#   :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
81+
#
82+
# Some function first.
83+
84+
85+
def export_model(
    model, dynamic_shapes, inputs, cache=False, oblivious=False, rt=False, cache_patch=False
):
    """
    Exports ``model`` with :func:`torch.export.export`, optionally inside
    the context managers selected by the options.

    Every enabled option opens its context and calls this function again
    without that option, until the plain export is reached.

    :param model: model to export
    :param dynamic_shapes: dynamic shapes given to the exporter
    :param inputs: inputs given to the exporter as keyword arguments,
        the positional arguments are left empty
    :param cache: exports within ``register_additional_serialization_functions(
        patch_transformers=True)``, only used when *cache_patch* is not set
    :param oblivious: exports within
        ``torch.fx.experimental._config.patch(backed_size_oblivious=True)``
    :param rt: value given to ``prefer_deferred_runtime_asserts_over_guards``
    :param cache_patch: exports within ``torch_export_patches``;
        ``"all"``, ``True`` or ``1`` enable both patch families,
        ``"torch"`` or ``"transformers"`` enable only that one
    :return: the result of :func:`torch.export.export`
    """
    if cache_patch:
        # The patches take precedence over the plain cache registration.
        with_torch = cache_patch in ("all", "torch", True, 1)
        with_transformers = cache_patch in ("all", "transformers", True, 1)
        with torch_export_patches(
            patch_torch=with_torch, patch_transformers=with_transformers
        ):
            return export_model(model, dynamic_shapes, inputs, oblivious=oblivious, rt=rt)

    if cache:
        with register_additional_serialization_functions(patch_transformers=True):
            return export_model(model, dynamic_shapes, inputs, oblivious=oblivious, rt=rt)

    if oblivious:
        with torch.fx.experimental._config.patch(backed_size_oblivious=True):
            return export_model(model, dynamic_shapes, inputs, rt=rt)

    export_options = dict(
        dynamic_shapes=dynamic_shapes,
        prefer_deferred_runtime_asserts_over_guards=rt,
    )
    return torch.export.export(model, (), inputs, **export_options)
107+
108+
109+
def try_export_model(
    model, dynamic_shapes, inputs, cache=False, oblivious=False, rt=False, cache_patch=False
):
    """
    Calls ``export_model`` with the same arguments and returns
    the exception instead of raising it when the export fails.

    :param model: model to export
    :param dynamic_shapes: dynamic shapes given to the exporter
    :param inputs: inputs used to export
    :param cache: see ``export_model``
    :param oblivious: see ``export_model``
    :param rt: see ``export_model``
    :param cache_patch: see ``export_model``
    :return: the value returned by ``export_model`` or the exception it raised
    """
    options = dict(cache=cache, oblivious=oblivious, rt=rt, cache_patch=cache_patch)
    try:
        exported = export_model(model, dynamic_shapes, inputs, **options)
    except Exception as exc:
        # A failure is a result of the experiment, the caller records it.
        return exc
    return exported
124+
125+
126+
def validation(ep, input_sets, expected):
    """
    Runs the module built from ``ep`` on every input set and compares
    the outputs with the expected ones.

    :param ep: exported program, only ``ep.module()`` is used
    :param input_sets: dictionary ``{name: keyword inputs}``
    :param expected: dictionary ``{name: expected outputs}``, same keys
    :return: iterator on ``(name, result)`` where *result* is either
        the exception raised while running the module or
        the output of ``max_diff(expected[name], got, verbose=0)``
    """
    module = ep.module()
    for name, kwargs in input_sets.items():
        try:
            # The module runs on a copy of the inputs.
            got = module(**torch_deepcopy(kwargs))
        except Exception as exc:
            yield name, exc
        else:
            yield name, max_diff(expected[name], got, verbose=0)
135+
136+
137+
# %%
138+
# The main loop
139+
# +++++++++++++
140+
141+
# One observation (a dictionary) per failed export
# or per couple (export configuration, input set used to run).
results = []

# One axis per option (cache, cache_patch, oblivious, rt) and a last one for
# the input set used to export. ``cache_patch`` takes more than two values.
possibilities = [*[[0, 1] for _ in range(4)], list(input_sets)]
possibilities[1] = [0, "all", "torch", "transformers"]
with tqdm(list(itertools.product(*possibilities))) as pbar:
    for cache, cache_patch, oblivious, rt, inputs in pbar:
        if cache_patch and not cache:
            # patches include caches.
            continue
        kwargs = dict(cache=cache, cache_patch=cache_patch, oblivious=oblivious, rt=rt)
        # Enabled options joined with "-", the value is added when it is not an integer.
        legend = "-".join(
            (k if isinstance(v, int) else f"{k}:{v}") for k, v in kwargs.items() if v
        )
        legend = f"{legend}/{inputs}"
        pbar.set_description(f"{legend} EXPORT")

        # export
        ep = try_export_model(
            model, dynamic_shapes, torch_deepcopy(input_sets[inputs]), **kwargs
        )
        if isinstance(ep, Exception):
            # Only the first line of the error message is kept.
            obs = {
                **kwargs,
                "export_with": inputs,
                "EXPORT": 0,
                "ERR-EXPORT": str(ep).split("\n")[0],
            }
            results.append(obs)
            continue

        pbar.set_description(f"{legend} VALIDATE")
        common = {**kwargs, "export_with": inputs, "EXPORT": 1}
        for inp, res in validation(ep, input_sets, expected):
            if isinstance(res, Exception):
                obs = {
                    **common,
                    "run_with": inp,
                    "ERR-RUN": str(res).split("\n")[0],
                    "WORKS": 0,
                }
            else:
                # ``not`` and not ``~``: the bitwise operator only behaves as a
                # negation on numpy booleans, ``~True`` is -2 for a python bool.
                abs_err = res["abs"]
                obs = {
                    **common,
                    "run_with": inp,
                    "WORKS": int(not np.isnan(abs_err) and abs_err < 1e-3),
                }
            results.append(obs)
188+
189+
# %%
190+
# Let's save the results.
191+
192+
df = pandas.DataFrame(results)
# These columns only exist if at least one observation filled them:
# ``ERR-RUN`` is missing when no exported program raised at run time,
# ``WORKS`` and ``run_with`` are missing when every export failed.
# They are added empty so that the filters and pivots below do not fail.
for column in ("run_with", "WORKS", "ERR-RUN"):
    if column not in df.columns:
        df[column] = float("nan")
df.to_excel("plot_export_tiny_llm_dim01.xlsx")
df

# %% The export failures.

no_export = df[df.EXPORT == 0]
no_export.to_excel("plot_export_tiny_llm_dim01.no_export.xlsx")
no_export

# %%
# The validation failures.

# One row per export configuration, one column per input set used to run.
invalid = df[(df.EXPORT == 1) & (df.WORKS == 0)].pivot(
    index=["cache", "cache_patch", "oblivious", "rt", "export_with"],
    columns=["run_with"],
    values=["WORKS", "ERR-RUN"],
)
invalid.to_excel("plot_export_tiny_llm_dim01.invalid.xlsx")
invalid

# %% Successes.

success = df[(df.EXPORT == 1) & (df.WORKS == 1)].pivot(
    index=["cache", "cache_patch", "oblivious", "rt", "export_with"],
    columns=["run_with"],
    values=["WORKS"],
)
success.to_excel("plot_export_tiny_llm_dim01.success.xlsx")
success


# %%
# If you have any error, then look at example
# :ref:`l-plot-tiny-llm-export-patched`.

doc.plot_legend("Tiny-LLM\nexport with\ndimension in {0,1}", "torch.export.export", "tomato")

0 commit comments

Comments
 (0)