Skip to content

Commit a3f9f83

Browse files
authored
tune examples (#255)
* tune examples * fixe
1 parent 651a43e commit a3f9f83

File tree

3 files changed

+64
-25
lines changed

3 files changed

+64
-25
lines changed

_doc/examples/plot_export_tiny_llm_dim01.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -83,31 +83,50 @@
8383

8484

8585
def export_model(
86-
model, dynamic_shapes, inputs, cache=False, oblivious=False, rt=False, cache_patch=False
86+
model,
87+
dynamic_shapes,
88+
inputs,
89+
cache=False,
90+
oblivious=False,
91+
rt=False,
92+
cache_patch=False,
93+
strict=False,
8794
):
8895
if cache and not cache_patch:
8996
with register_additional_serialization_functions(patch_transformers=True):
90-
return export_model(model, dynamic_shapes, inputs, oblivious=oblivious, rt=rt)
97+
return export_model(
98+
model, dynamic_shapes, inputs, oblivious=oblivious, rt=rt, strict=strict
99+
)
91100
if cache_patch:
92101
with torch_export_patches(
93102
patch_torch=cache_patch in ("all", "torch", True, 1),
94103
patch_transformers=cache_patch in ("all", "transformers", True, 1),
95104
):
96-
return export_model(model, dynamic_shapes, inputs, oblivious=oblivious, rt=rt)
105+
return export_model(
106+
model, dynamic_shapes, inputs, oblivious=oblivious, rt=rt, strict=strict
107+
)
97108
if oblivious:
98109
with torch.fx.experimental._config.patch(backed_size_oblivious=True):
99-
return export_model(model, dynamic_shapes, inputs, rt=rt)
110+
return export_model(model, dynamic_shapes, inputs, rt=rt, strict=strict)
100111
return torch.export.export(
101112
model,
102113
(),
103114
inputs,
104115
dynamic_shapes=dynamic_shapes,
116+
strict=strict,
105117
prefer_deferred_runtime_asserts_over_guards=rt,
106118
)
107119

108120

109121
def try_export_model(
110-
model, dynamic_shapes, inputs, cache=False, oblivious=False, rt=False, cache_patch=False
122+
model,
123+
dynamic_shapes,
124+
inputs,
125+
cache=False,
126+
oblivious=False,
127+
rt=False,
128+
cache_patch=False,
129+
strict=False,
111130
):
112131
try:
113132
return export_model(
@@ -118,6 +137,7 @@ def try_export_model(
118137
oblivious=oblivious,
119138
rt=rt,
120139
cache_patch=cache_patch,
140+
strict=strict,
121141
)
122142
except Exception as e:
123143
return e
@@ -140,14 +160,16 @@ def validation(ep, input_sets, expected):
140160

141161
results = []
142162

143-
possibilities = [*[[0, 1] for _ in range(4)], list(input_sets)]
163+
possibilities = [*[[0, 1] for _ in range(5)], list(input_sets)]
144164
possibilities[1] = [0, "all", "torch", "transformers"]
145165
with tqdm(list(itertools.product(*possibilities))) as pbar:
146-
for cache, cache_patch, oblivious, rt, inputs in pbar:
166+
for cache, cache_patch, strict, oblivious, rt, inputs in pbar:
147167
if cache_patch and not cache:
148168
# patches include caches.
149169
continue
150-
kwargs = dict(cache=cache, cache_patch=cache_patch, oblivious=oblivious, rt=rt)
170+
kwargs = dict(
171+
cache=cache, cache_patch=cache_patch, strict=strict, oblivious=oblivious, rt=rt
172+
)
151173
legend = "-".join(
152174
(k if isinstance(v, int) else f"{k}:{v}") for k, v in kwargs.items() if v
153175
)
@@ -203,7 +225,7 @@ def validation(ep, input_sets, expected):
203225
# The validation failures.
204226

205227
invalid = df[(df.EXPORT == 1) & (df.WORKS == 0)].pivot(
206-
index=["cache", "cache_patch", "oblivious", "rt", "export_with"],
228+
index=["cache", "cache_patch", "strict", "oblivious", "rt", "export_with"],
207229
columns=["run_with"],
208230
values=["WORKS", "ERR-RUN"],
209231
)
@@ -213,7 +235,7 @@ def validation(ep, input_sets, expected):
213235
# %% Successes.
214236

215237
success = df[(df.EXPORT == 1) & (df.WORKS == 1)].pivot(
216-
index=["cache", "cache_patch", "oblivious", "rt", "export_with"],
238+
index=["cache", "cache_patch", "strict", "oblivious", "rt", "export_with"],
217239
columns=["run_with"],
218240
values=["WORKS"],
219241
)

_doc/examples/plot_export_tiny_llm_dim01_onnx_custom.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -77,17 +77,28 @@
7777

7878

7979
def export_model(
80-
model, dynamic_shapes, inputs, cache=False, oblivious=False, rt=False, cache_patch=False
80+
model,
81+
dynamic_shapes,
82+
inputs,
83+
cache=False,
84+
oblivious=False,
85+
rt=False,
86+
cache_patch=False,
87+
strict=False,
8188
):
8289
if cache and not cache_patch:
8390
with register_additional_serialization_functions(patch_transformers=True):
84-
return export_model(model, dynamic_shapes, inputs, oblivious=oblivious, rt=rt)
91+
return export_model(
92+
model, dynamic_shapes, inputs, oblivious=oblivious, rt=rt, strict=strict
93+
)
8594
if cache_patch:
8695
with torch_export_patches(
8796
patch_torch=cache_patch in ("all", "torch", True, 1),
8897
patch_transformers=cache_patch in ("all", "transformers", True, 1),
8998
):
90-
return export_model(model, dynamic_shapes, inputs, oblivious=oblivious, rt=rt)
99+
return export_model(
100+
model, dynamic_shapes, inputs, oblivious=oblivious, rt=rt, strict=strict
101+
)
91102
return to_onnx(
92103
model,
93104
(),
@@ -96,12 +107,20 @@ def export_model(
96107
export_options=ExportOptions(
97108
prefer_deferred_runtime_asserts_over_guards=rt,
98109
backed_size_oblivious=oblivious,
110+
strict=strict,
99111
),
100112
)
101113

102114

103115
def try_export_model(
104-
model, dynamic_shapes, inputs, cache=False, oblivious=False, rt=False, cache_patch=False
116+
model,
117+
dynamic_shapes,
118+
inputs,
119+
cache=False,
120+
oblivious=False,
121+
rt=False,
122+
cache_patch=False,
123+
strict=False,
105124
):
106125
try:
107126
return export_model(
@@ -112,6 +131,7 @@ def try_export_model(
112131
oblivious=oblivious,
113132
rt=rt,
114133
cache_patch=cache_patch,
134+
strict=strict,
115135
)
116136
except Exception as e:
117137
return e
@@ -155,16 +175,19 @@ def validation(onx, input_sets, expected, catch_exception=True):
155175
possibilities = [
156176
[0, 1],
157177
[0, "all", "torch", "transformers"],
178+
[0, 1],
158179
[0, 1, "auto", "half"],
159180
[0, 1],
160181
list(input_sets),
161182
]
162183
with tqdm(list(itertools.product(*possibilities))) as pbar:
163-
for cache, cache_patch, oblivious, rt, inputs in pbar:
184+
for cache, cache_patch, strict, oblivious, rt, inputs in pbar:
164185
if cache_patch and not cache:
165186
# patches include caches.
166187
continue
167-
kwargs = dict(cache=cache, cache_patch=cache_patch, oblivious=oblivious, rt=rt)
188+
kwargs = dict(
189+
cache=cache, cache_patch=cache_patch, oblivious=oblivious, rt=rt, strict=strict
190+
)
168191
legend = "-".join(
169192
(k if isinstance(v, int) else f"{k}:{v}") for k, v in kwargs.items() if v
170193
)
@@ -220,7 +243,7 @@ def validation(onx, input_sets, expected, catch_exception=True):
220243
# The validation failures.
221244

222245
invalid = df[(df.EXPORT == 1) & (df.WORKS == 0)].pivot(
223-
index=["cache", "cache_patch", "oblivious", "rt", "export_with"],
246+
index=["cache", "cache_patch", "strict", "oblivious", "rt", "export_with"],
224247
columns=["run_with"],
225248
values=["WORKS", "ERR-RUN"],
226249
)
@@ -230,7 +253,7 @@ def validation(onx, input_sets, expected, catch_exception=True):
230253
# %% Successes.
231254

232255
success = df[(df.EXPORT == 1) & (df.WORKS == 1)].pivot(
233-
index=["cache", "cache_patch", "oblivious", "rt", "export_with"],
256+
index=["cache", "cache_patch", "strict", "oblivious", "rt", "export_with"],
234257
columns=["run_with"],
235258
values=["WORKS"],
236259
)

_unittests/ut_xrun_doc/test_documentation_examples.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -109,13 +109,6 @@ def add_test_methods(cls):
109109
):
110110
reason = "torch<2.8"
111111

112-
if (
113-
not reason
114-
and name in {"plot_export_tiny_llm_dim01.py"}
115-
and not has_torch("2.9")
116-
):
117-
reason = "torch<2.9"
118-
119112
if (
120113
not reason
121114
and name in {"plot_dump_intermediate_results.py"}
@@ -131,6 +124,7 @@ def add_test_methods(cls):
131124
reason = "unstable, let's wait for the next version"
132125

133126
if not reason and name in {
127+
"plot_export_tiny_llm_dim01.py",
134128
"plot_export_tiny_llm_dim01_onnx.py",
135129
"plot_export_tiny_llm_dim01_onnx_custom.py",
136130
}:

0 commit comments

Comments
 (0)