Skip to content

Commit c573072

Browse files
committed
adds more options
1 parent 7fce5ae commit c573072

File tree

3 files changed

+30
-7
lines changed

3 files changed

+30
-7
lines changed

_doc/examples/plot_export_tiny_llm_dim01.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,10 @@ def export_model(
8989
with register_additional_serialization_functions(patch_transformers=True):
9090
return export_model(model, dynamic_shapes, inputs, oblivious=oblivious, rt=rt)
9191
if cache_patch:
92-
with torch_export_patches(patch_transformers=True):
92+
with torch_export_patches(
93+
patch_torch=cache_patch in ("all", "torch", True, 1),
94+
patch_transformers=cache_patch in ("all", "transformers", True, 1),
95+
):
9396
return export_model(model, dynamic_shapes, inputs, oblivious=oblivious, rt=rt)
9497
if oblivious:
9598
with torch.fx.experimental._config.patch(backed_size_oblivious=True):
@@ -138,13 +141,16 @@ def validation(ep, input_sets, expected):
138141
results = []
139142

140143
possibilities = [*[[0, 1] for _ in range(4)], list(input_sets)]
144+
possibilities[1] = [0, "all", "torch", "transformers"]
141145
with tqdm(list(itertools.product(*possibilities))) as pbar:
142146
for cache, cache_patch, oblivious, rt, inputs in pbar:
143147
if cache_patch and not cache:
144148
# patches include caches.
145149
continue
146150
kwargs = dict(cache=cache, cache_patch=cache_patch, oblivious=oblivious, rt=rt)
147-
legend = "-".join(k for k, v in kwargs.items() if v)
151+
legend = "-".join(
152+
(k if isinstance(v, int) else f"{k}:{v}") for k, v in kwargs.items() if v
153+
)
148154
legend = f"{legend}/{inputs}"
149155
pbar.set_description(f"{legend} EXPORT")
150156

_doc/examples/plot_export_tiny_llm_dim01_onnx.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,10 @@ def export_model(
8484
with register_additional_serialization_functions(patch_transformers=True):
8585
return export_model(model, dynamic_shapes, inputs, oblivious=oblivious, rt=rt)
8686
if cache_patch:
87-
with torch_export_patches(patch_transformers=True):
87+
with torch_export_patches(
88+
patch_torch=cache_patch in ("all", "torch", True, 1),
89+
patch_transformers=cache_patch in ("all", "transformers", True, 1),
90+
):
8891
return export_model(model, dynamic_shapes, inputs, oblivious=oblivious, rt=rt)
8992
if oblivious:
9093
with torch.fx.experimental._config.patch(backed_size_oblivious=True):
@@ -152,13 +155,16 @@ def validation(ep, input_sets, expected, catch_exception=True):
152155
results = []
153156

154157
possibilities = [*[[0, 1] for _ in range(4)], list(input_sets)]
158+
possibilities[1] = [0, "all", "torch", "transformers"]
155159
with tqdm(list(itertools.product(*possibilities))) as pbar:
156160
for cache, cache_patch, oblivious, rt, inputs in pbar:
157161
if cache_patch and not cache:
158162
# patches include caches.
159163
continue
160164
kwargs = dict(cache=cache, cache_patch=cache_patch, oblivious=oblivious, rt=rt)
161-
legend = "-".join(k for k, v in kwargs.items() if v)
165+
legend = "-".join(
166+
(k if isinstance(v, int) else f"{k}:{v}") for k, v in kwargs.items() if v
167+
)
162168
legend = f"{legend}/{inputs}"
163169
pbar.set_description(f"{legend} EXPORT")
164170

_doc/examples/plot_export_tiny_llm_dim01_onnx_custom.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,10 @@ def export_model(
8383
with register_additional_serialization_functions(patch_transformers=True):
8484
return export_model(model, dynamic_shapes, inputs, oblivious=oblivious, rt=rt)
8585
if cache_patch:
86-
with torch_export_patches(patch_transformers=True):
86+
with torch_export_patches(
87+
patch_torch=cache_patch in ("all", "torch", True, 1),
88+
patch_transformers=cache_patch in ("all", "transformers", True, 1),
89+
):
8790
return export_model(model, dynamic_shapes, inputs, oblivious=oblivious, rt=rt)
8891
return to_onnx(
8992
model,
@@ -149,14 +152,22 @@ def validation(onx, input_sets, expected, catch_exception=True):
149152

150153
results = []
151154

152-
possibilities = [*[[0, 1] for _ in range(4)], list(input_sets)]
155+
possibilities = [
156+
[0, 1],
157+
[0, "all", "torch", "transformers"],
158+
[0, 1, "auto", "half"],
159+
[0, 1],
160+
list(input_sets),
161+
]
153162
with tqdm(list(itertools.product(*possibilities))) as pbar:
154163
for cache, cache_patch, oblivious, rt, inputs in pbar:
155164
if cache_patch and not cache:
156165
# patches include caches.
157166
continue
158167
kwargs = dict(cache=cache, cache_patch=cache_patch, oblivious=oblivious, rt=rt)
159-
legend = "-".join(k for k, v in kwargs.items() if v)
168+
legend = "-".join(
169+
(k if isinstance(v, int) else f"{k}:{v}") for k, v in kwargs.items() if v
170+
)
160171
legend = f"{legend}/{inputs}"
161172
pbar.set_description(f"{legend} EXPORT")
162173

0 commit comments

Comments
 (0)