diff --git a/_doc/examples/plot_export_tiny_llm_dim01.py b/_doc/examples/plot_export_tiny_llm_dim01.py index fb2e477b..872eb5b7 100644 --- a/_doc/examples/plot_export_tiny_llm_dim01.py +++ b/_doc/examples/plot_export_tiny_llm_dim01.py @@ -83,31 +83,50 @@ def export_model( - model, dynamic_shapes, inputs, cache=False, oblivious=False, rt=False, cache_patch=False + model, + dynamic_shapes, + inputs, + cache=False, + oblivious=False, + rt=False, + cache_patch=False, + strict=False, ): if cache and not cache_patch: with register_additional_serialization_functions(patch_transformers=True): - return export_model(model, dynamic_shapes, inputs, oblivious=oblivious, rt=rt) + return export_model( + model, dynamic_shapes, inputs, oblivious=oblivious, rt=rt, strict=strict + ) if cache_patch: with torch_export_patches( patch_torch=cache_patch in ("all", "torch", True, 1), patch_transformers=cache_patch in ("all", "transformers", True, 1), ): - return export_model(model, dynamic_shapes, inputs, oblivious=oblivious, rt=rt) + return export_model( + model, dynamic_shapes, inputs, oblivious=oblivious, rt=rt, strict=strict + ) if oblivious: with torch.fx.experimental._config.patch(backed_size_oblivious=True): - return export_model(model, dynamic_shapes, inputs, rt=rt) + return export_model(model, dynamic_shapes, inputs, rt=rt, strict=strict) return torch.export.export( model, (), inputs, dynamic_shapes=dynamic_shapes, + strict=strict, prefer_deferred_runtime_asserts_over_guards=rt, ) def try_export_model( - model, dynamic_shapes, inputs, cache=False, oblivious=False, rt=False, cache_patch=False + model, + dynamic_shapes, + inputs, + cache=False, + oblivious=False, + rt=False, + cache_patch=False, + strict=False, ): try: return export_model( @@ -118,6 +137,7 @@ def try_export_model( oblivious=oblivious, rt=rt, cache_patch=cache_patch, + strict=strict, ) except Exception as e: return e @@ -140,14 +160,16 @@ def validation(ep, input_sets, expected): results = [] -possibilities = [*[[0, 1] for _ in range(4)], list(input_sets)] +possibilities = [*[[0, 1] for _ in range(5)], list(input_sets)] possibilities[1] = [0, "all", "torch", "transformers"] with tqdm(list(itertools.product(*possibilities))) as pbar: - for cache, cache_patch, oblivious, rt, inputs in pbar: + for cache, cache_patch, strict, oblivious, rt, inputs in pbar: if cache_patch and not cache: # patches include caches. continue - kwargs = dict(cache=cache, cache_patch=cache_patch, oblivious=oblivious, rt=rt) + kwargs = dict( + cache=cache, cache_patch=cache_patch, strict=strict, oblivious=oblivious, rt=rt + ) legend = "-".join( (k if isinstance(v, int) else f"{k}:{v}") for k, v in kwargs.items() if v ) @@ -203,7 +225,7 @@ def validation(ep, input_sets, expected): # The validation failures. invalid = df[(df.EXPORT == 1) & (df.WORKS == 0)].pivot( - index=["cache", "cache_patch", "oblivious", "rt", "export_with"], + index=["cache", "cache_patch", "strict", "oblivious", "rt", "export_with"], columns=["run_with"], values=["WORKS", "ERR-RUN"], ) @@ -213,7 +235,7 @@ def validation(ep, input_sets, expected): # %% Successes. success = df[(df.EXPORT == 1) & (df.WORKS == 1)].pivot( - index=["cache", "cache_patch", "oblivious", "rt", "export_with"], + index=["cache", "cache_patch", "strict", "oblivious", "rt", "export_with"], columns=["run_with"], values=["WORKS"], ) diff --git a/_doc/examples/plot_export_tiny_llm_dim01_onnx_custom.py b/_doc/examples/plot_export_tiny_llm_dim01_onnx_custom.py index 34d10a75..02d79dfa 100644 --- a/_doc/examples/plot_export_tiny_llm_dim01_onnx_custom.py +++ b/_doc/examples/plot_export_tiny_llm_dim01_onnx_custom.py @@ -77,17 +77,28 @@ def export_model( - model, dynamic_shapes, inputs, cache=False, oblivious=False, rt=False, cache_patch=False + model, + dynamic_shapes, + inputs, + cache=False, + oblivious=False, + rt=False, + cache_patch=False, + strict=False, ): if cache and not cache_patch: with register_additional_serialization_functions(patch_transformers=True): - return export_model(model, dynamic_shapes, inputs, oblivious=oblivious, rt=rt) + return export_model( + model, dynamic_shapes, inputs, oblivious=oblivious, rt=rt, strict=strict + ) if cache_patch: with torch_export_patches( patch_torch=cache_patch in ("all", "torch", True, 1), patch_transformers=cache_patch in ("all", "transformers", True, 1), ): - return export_model(model, dynamic_shapes, inputs, oblivious=oblivious, rt=rt) + return export_model( + model, dynamic_shapes, inputs, oblivious=oblivious, rt=rt, strict=strict + ) return to_onnx( model, (), @@ -96,12 +107,20 @@ def export_model( export_options=ExportOptions( prefer_deferred_runtime_asserts_over_guards=rt, backed_size_oblivious=oblivious, + strict=strict, ), ) def try_export_model( - model, dynamic_shapes, inputs, cache=False, oblivious=False, rt=False, cache_patch=False + model, + dynamic_shapes, + inputs, + cache=False, + oblivious=False, + rt=False, + cache_patch=False, + strict=False, ): try: return export_model( @@ -112,6 +131,7 @@ def try_export_model( oblivious=oblivious, rt=rt, cache_patch=cache_patch, + strict=strict, ) except Exception as e: return e @@ -155,16 +175,19 @@ def validation(onx, input_sets, expected, catch_exception=True): possibilities = [ [0, 1], [0, "all", "torch", "transformers"], + [0, 1], [0, 1, "auto", "half"], [0, 1], list(input_sets), ] with tqdm(list(itertools.product(*possibilities))) as pbar: - for cache, cache_patch, oblivious, rt, inputs in pbar: + for cache, cache_patch, strict, oblivious, rt, inputs in pbar: if cache_patch and not cache: # patches include caches. continue - kwargs = dict(cache=cache, cache_patch=cache_patch, oblivious=oblivious, rt=rt) + kwargs = dict( + cache=cache, cache_patch=cache_patch, oblivious=oblivious, rt=rt, strict=strict + ) legend = "-".join( (k if isinstance(v, int) else f"{k}:{v}") for k, v in kwargs.items() if v ) @@ -220,7 +243,7 @@ def validation(onx, input_sets, expected, catch_exception=True): # The validation failures. invalid = df[(df.EXPORT == 1) & (df.WORKS == 0)].pivot( - index=["cache", "cache_patch", "oblivious", "rt", "export_with"], + index=["cache", "cache_patch", "strict", "oblivious", "rt", "export_with"], columns=["run_with"], values=["WORKS", "ERR-RUN"], ) @@ -230,7 +253,7 @@ def validation(onx, input_sets, expected, catch_exception=True): # %% Successes. success = df[(df.EXPORT == 1) & (df.WORKS == 1)].pivot( - index=["cache", "cache_patch", "oblivious", "rt", "export_with"], + index=["cache", "cache_patch", "strict", "oblivious", "rt", "export_with"], columns=["run_with"], values=["WORKS"], ) diff --git a/_unittests/ut_xrun_doc/test_documentation_examples.py b/_unittests/ut_xrun_doc/test_documentation_examples.py index 4d79512f..0dc45f6f 100644 --- a/_unittests/ut_xrun_doc/test_documentation_examples.py +++ b/_unittests/ut_xrun_doc/test_documentation_examples.py @@ -109,13 +109,6 @@ def add_test_methods(cls): ): reason = "torch<2.8" - if ( - not reason - and name in {"plot_export_tiny_llm_dim01.py"} - and not has_torch("2.9") - ): - reason = "torch<2.9" - if ( not reason and name in {"plot_dump_intermediate_results.py"} @@ -131,6 +124,7 @@ def add_test_methods(cls): reason = "unstable, let's wait for the next version" if not reason and name in { + "plot_export_tiny_llm_dim01.py", "plot_export_tiny_llm_dim01_onnx.py", "plot_export_tiny_llm_dim01_onnx_custom.py", }: