Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 32 additions & 10 deletions _doc/examples/plot_export_tiny_llm_dim01.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,31 +83,50 @@


def export_model(
model, dynamic_shapes, inputs, cache=False, oblivious=False, rt=False, cache_patch=False
model,
dynamic_shapes,
inputs,
cache=False,
oblivious=False,
rt=False,
cache_patch=False,
strict=False,
):
if cache and not cache_patch:
with register_additional_serialization_functions(patch_transformers=True):
return export_model(model, dynamic_shapes, inputs, oblivious=oblivious, rt=rt)
return export_model(
model, dynamic_shapes, inputs, oblivious=oblivious, rt=rt, strict=strict
)
if cache_patch:
with torch_export_patches(
patch_torch=cache_patch in ("all", "torch", True, 1),
patch_transformers=cache_patch in ("all", "transformers", True, 1),
):
return export_model(model, dynamic_shapes, inputs, oblivious=oblivious, rt=rt)
return export_model(
model, dynamic_shapes, inputs, oblivious=oblivious, rt=rt, strict=strict
)
if oblivious:
with torch.fx.experimental._config.patch(backed_size_oblivious=True):
return export_model(model, dynamic_shapes, inputs, rt=rt)
return export_model(model, dynamic_shapes, inputs, rt=rt, strict=strict)
return torch.export.export(
model,
(),
inputs,
dynamic_shapes=dynamic_shapes,
strict=strict,
prefer_deferred_runtime_asserts_over_guards=rt,
)


def try_export_model(
model, dynamic_shapes, inputs, cache=False, oblivious=False, rt=False, cache_patch=False
model,
dynamic_shapes,
inputs,
cache=False,
oblivious=False,
rt=False,
cache_patch=False,
strict=False,
):
try:
return export_model(
Expand All @@ -118,6 +137,7 @@ def try_export_model(
oblivious=oblivious,
rt=rt,
cache_patch=cache_patch,
strict=strict,
)
except Exception as e:
return e
Expand All @@ -140,14 +160,16 @@ def validation(ep, input_sets, expected):

results = []

possibilities = [*[[0, 1] for _ in range(4)], list(input_sets)]
possibilities = [*[[0, 1] for _ in range(5)], list(input_sets)]
possibilities[1] = [0, "all", "torch", "transformers"]
with tqdm(list(itertools.product(*possibilities))) as pbar:
for cache, cache_patch, oblivious, rt, inputs in pbar:
for cache, cache_patch, strict, oblivious, rt, inputs in pbar:
if cache_patch and not cache:
# patches include caches.
continue
kwargs = dict(cache=cache, cache_patch=cache_patch, oblivious=oblivious, rt=rt)
kwargs = dict(
cache=cache, cache_patch=cache_patch, strict=strict, oblivious=oblivious, rt=rt
)
legend = "-".join(
(k if isinstance(v, int) else f"{k}:{v}") for k, v in kwargs.items() if v
)
Expand Down Expand Up @@ -203,7 +225,7 @@ def validation(ep, input_sets, expected):
# The validation failures.

invalid = df[(df.EXPORT == 1) & (df.WORKS == 0)].pivot(
index=["cache", "cache_patch", "oblivious", "rt", "export_with"],
index=["cache", "cache_patch", "strict", "oblivious", "rt", "export_with"],
columns=["run_with"],
values=["WORKS", "ERR-RUN"],
)
Expand All @@ -213,7 +235,7 @@ def validation(ep, input_sets, expected):
# %% Successes.

success = df[(df.EXPORT == 1) & (df.WORKS == 1)].pivot(
index=["cache", "cache_patch", "oblivious", "rt", "export_with"],
index=["cache", "cache_patch", "strict", "oblivious", "rt", "export_with"],
columns=["run_with"],
values=["WORKS"],
)
Expand Down
39 changes: 31 additions & 8 deletions _doc/examples/plot_export_tiny_llm_dim01_onnx_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,17 +77,28 @@


def export_model(
model, dynamic_shapes, inputs, cache=False, oblivious=False, rt=False, cache_patch=False
model,
dynamic_shapes,
inputs,
cache=False,
oblivious=False,
rt=False,
cache_patch=False,
strict=False,
):
if cache and not cache_patch:
with register_additional_serialization_functions(patch_transformers=True):
return export_model(model, dynamic_shapes, inputs, oblivious=oblivious, rt=rt)
return export_model(
model, dynamic_shapes, inputs, oblivious=oblivious, rt=rt, strict=strict
)
if cache_patch:
with torch_export_patches(
patch_torch=cache_patch in ("all", "torch", True, 1),
patch_transformers=cache_patch in ("all", "transformers", True, 1),
):
return export_model(model, dynamic_shapes, inputs, oblivious=oblivious, rt=rt)
return export_model(
model, dynamic_shapes, inputs, oblivious=oblivious, rt=rt, strict=strict
)
return to_onnx(
model,
(),
Expand All @@ -96,12 +107,20 @@ def export_model(
export_options=ExportOptions(
prefer_deferred_runtime_asserts_over_guards=rt,
backed_size_oblivious=oblivious,
strict=strict,
),
)


def try_export_model(
model, dynamic_shapes, inputs, cache=False, oblivious=False, rt=False, cache_patch=False
model,
dynamic_shapes,
inputs,
cache=False,
oblivious=False,
rt=False,
cache_patch=False,
strict=False,
):
try:
return export_model(
Expand All @@ -112,6 +131,7 @@ def try_export_model(
oblivious=oblivious,
rt=rt,
cache_patch=cache_patch,
strict=strict,
)
except Exception as e:
return e
Expand Down Expand Up @@ -155,16 +175,19 @@ def validation(onx, input_sets, expected, catch_exception=True):
possibilities = [
[0, 1],
[0, "all", "torch", "transformers"],
[0, 1],
[0, 1, "auto", "half"],
[0, 1],
list(input_sets),
]
with tqdm(list(itertools.product(*possibilities))) as pbar:
for cache, cache_patch, oblivious, rt, inputs in pbar:
for cache, cache_patch, strict, oblivious, rt, inputs in pbar:
if cache_patch and not cache:
# patches include caches.
continue
kwargs = dict(cache=cache, cache_patch=cache_patch, oblivious=oblivious, rt=rt)
kwargs = dict(
cache=cache, cache_patch=cache_patch, oblivious=oblivious, rt=rt, strict=strict
)
legend = "-".join(
(k if isinstance(v, int) else f"{k}:{v}") for k, v in kwargs.items() if v
)
Expand Down Expand Up @@ -220,7 +243,7 @@ def validation(onx, input_sets, expected, catch_exception=True):
# The validation failures.

invalid = df[(df.EXPORT == 1) & (df.WORKS == 0)].pivot(
index=["cache", "cache_patch", "oblivious", "rt", "export_with"],
index=["cache", "cache_patch", "strict", "oblivious", "rt", "export_with"],
columns=["run_with"],
values=["WORKS", "ERR-RUN"],
)
Expand All @@ -230,7 +253,7 @@ def validation(onx, input_sets, expected, catch_exception=True):
# %% Successes.

success = df[(df.EXPORT == 1) & (df.WORKS == 1)].pivot(
index=["cache", "cache_patch", "oblivious", "rt", "export_with"],
index=["cache", "cache_patch", "strict", "oblivious", "rt", "export_with"],
columns=["run_with"],
values=["WORKS"],
)
Expand Down
8 changes: 1 addition & 7 deletions _unittests/ut_xrun_doc/test_documentation_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,6 @@ def add_test_methods(cls):
):
reason = "torch<2.8"

if (
not reason
and name in {"plot_export_tiny_llm_dim01.py"}
and not has_torch("2.9")
):
reason = "torch<2.9"

if (
not reason
and name in {"plot_dump_intermediate_results.py"}
Expand All @@ -131,6 +124,7 @@ def add_test_methods(cls):
reason = "unstable, let's wait for the next version"

if not reason and name in {
"plot_export_tiny_llm_dim01.py",
"plot_export_tiny_llm_dim01_onnx.py",
"plot_export_tiny_llm_dim01_onnx_custom.py",
}:
Expand Down
Loading