diff --git a/_doc/examples/plot_export_tiny_llm_patched.py b/_doc/examples/plot_export_tiny_llm_patched.py index 60a20d15..5ed9566e 100644 --- a/_doc/examples/plot_export_tiny_llm_patched.py +++ b/_doc/examples/plot_export_tiny_llm_patched.py @@ -101,7 +101,7 @@ # %% # If they are not registered, function -# func:`onnx_diagnostic.torch_export_patches.torch_export_patches` +# :func:`onnx_diagnostic.torch_export_patches.torch_export_patches` # should take care of it. Then we export. with torch_export_patches(patch_transformers=True, verbose=10) as modificator: diff --git a/_doc/index.rst b/_doc/index.rst index 6d4792dd..5d105c5e 100644 --- a/_doc/index.rst +++ b/_doc/index.rst @@ -173,6 +173,7 @@ Size of the package: Older versions ++++++++++++++ +* `0.4.4 <../v0.4.4/index.html>`_ * `0.4.3 <../v0.4.3/index.html>`_ * `0.4.2 <../v0.4.2/index.html>`_ * `0.4.1 <../v0.4.1/index.html>`_ diff --git a/_doc/recipes/plot_dynamic_shapes_max.py b/_doc/recipes/plot_dynamic_shapes_max.py index 83844c7c..1d681740 100644 --- a/_doc/recipes/plot_dynamic_shapes_max.py +++ b/_doc/recipes/plot_dynamic_shapes_max.py @@ -185,4 +185,6 @@ def forward(self, x, y, fact): # is hidden in a custom operator. -doc.plot_legend("max(d1, d2)\nwith d1, d2 dimensions", "dynamic shapes", "green") +doc.plot_legend( + "Fixed in torch==2.8\nmax(d1, d2)\nwith d1, d2\ndimensions", "dynamic shapes", "green" +) diff --git a/_unittests/ut_torch_models/test_test_helpers.py b/_unittests/ut_torch_models/test_test_helpers.py index 1a2a99d9..90dc1b52 100644 --- a/_unittests/ut_torch_models/test_test_helpers.py +++ b/_unittests/ut_torch_models/test_test_helpers.py @@ -66,7 +66,7 @@ def test_validate_model_export(self): @requires_torch("2.7") @hide_stdout() @ignore_warnings(FutureWarning) - def test_validate_model_onnx_dynamo(self): + def test_validate_model_onnx_dynamo_ir(self): mid = "arnir0/Tiny-LLM" summary, data = validate_model( mid, @@ -87,6 +87,49 @@ def test_validate_model_onnx_dynamo(self): onnx_filename, output_path, num_attention_heads=2, hidden_size=192, verbose=10 ) + @requires_torch("2.7") + @hide_stdout() + @ignore_warnings(FutureWarning) + def test_validate_model_onnx_dynamo_os_ort(self): + mid = "arnir0/Tiny-LLM" + summary, data = validate_model( + mid, + do_run=True, + verbose=10, + exporter="onnx-dynamo", + dump_folder="dump_test_validate_model_onnx_dynamo", + patch=True, + stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0, + optimization="os_ort", + ) + self.assertIsInstance(summary, dict) + self.assertIsInstance(data, dict) + self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-4) + onnx_filename = data["onnx_filename"] + self.assertExists(onnx_filename) + + @requires_torch("2.7") + @hide_stdout() + @ignore_warnings(FutureWarning) + @requires_experimental() + def test_validate_model_custom_os_ort(self): + mid = "arnir0/Tiny-LLM" + summary, data = validate_model( + mid, + do_run=True, + verbose=10, + exporter="custom", + dump_folder="dump_validate_model_custom_os_ort", + patch=True, + stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0, + optimization="default+os_ort", + ) + self.assertIsInstance(summary, dict) + self.assertIsInstance(data, dict) + self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-4) + onnx_filename = data["onnx_filename"] + self.assertExists(onnx_filename) + @requires_torch("2.7") @hide_stdout() @ignore_warnings(FutureWarning) diff --git a/onnx_diagnostic/__init__.py b/onnx_diagnostic/__init__.py index 2d02cd2f..7a45dfe1 100644 --- a/onnx_diagnostic/__init__.py +++ b/onnx_diagnostic/__init__.py @@ -3,5 +3,5 @@ Functions, classes to dig into a model when this one is right, slow, wrong... """ -__version__ = "0.4.3" +__version__ = "0.4.4" __author__ = "Xavier Dupré" diff --git a/onnx_diagnostic/torch_models/test_helper.py b/onnx_diagnostic/torch_models/test_helper.py index aa93ece7..06c907de 100644 --- a/onnx_diagnostic/torch_models/test_helper.py +++ b/onnx_diagnostic/torch_models/test_helper.py @@ -4,6 +4,8 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import time import onnx +import onnxscript +import onnxscript.rewriter.ort_fusions as ort_fusions import torch from ..export import CoupleInputsDynamicShapes from ..helpers import max_diff, string_type, string_diff @@ -917,11 +919,10 @@ def call_torch_export_onnx( :return: two dictionaries, one with some metrics, another one with whatever the function produces """ - assert optimization in { - "", - "ir", - None, - }, f"unexpected value for optimization={optimization}" + available = {"", "ir", "os_ort"} + assert ( + optimization in available + ), f"unexpected value for optimization={optimization}, available={available}" assert exporter in { "onnx-dynamo", "onnx-script", @@ -1001,16 +1002,25 @@ def call_torch_export_onnx( print(epo) print("[call_torch_export_onnx] -- End of ONNXProgram") - if optimization == "ir": + if optimization in {"ir", "os_ort"}: if verbose: print(f"[call_torch_export_onnx] starts optimization={optimization!r}...") - _quiet_or_not_quiet( - quiet, - "export_onnx_opt_ir", - summary, - data, - (lambda epo=epo: epo.optimize()), - ) + if optimization == "ir": + label, f_optim = "export_onnx_opt_ir", (lambda epo=epo: epo.optimize()) + else: + + def _os_ort_optim(epo): + onnxscript.optimizer.optimize_ir(epo.model) + optimized = ort_fusions.optimize_for_ort(epo.model) + if isinstance(optimized, tuple): + for k, v in optimized[1].items(): + summary[f"op_opt_fused_{k}"] = v + epo.model = optimized[0] + else: + epo.model = optimized + + label, f_optim = "export_onnx_opt_os_ort", (lambda epo=epo: _os_ort_optim(epo)) + _quiet_or_not_quiet(quiet, label, summary, data, f_optim) if "ERR_export_onnx_opt_ir" in summary: return summary, data if verbose: @@ -1039,12 +1049,17 @@ def call_torch_export_custom( :return: two dictionaries, one with some metrics, another one with whatever the function produces """ - assert optimization in { + available = { "", "default", "default+onnxruntime", + "default+os_ort", + "default+onnxruntime+os_ort", None, - }, f"unexpected value for optimization={optimization}" + } + assert ( + optimization in available + ), f"unexpected value for optimization={optimization}, available={available}" assert exporter in { "custom", "custom-strict", @@ -1078,6 +1093,10 @@ def call_torch_export_custom( from experimental_experiment.torch_interpreter import to_onnx, ExportOptions from experimental_experiment.xbuilder import OptimizationOptions + spl = optimization.split("+") if optimization else [] + os_ort = "os_ort" in spl + optimization = "+".join(_ for _ in spl if _ != "os_ort") + export_options = ExportOptions( strict=strict, decomposition_table=( @@ -1181,6 +1200,31 @@ def call_torch_export_custom( assert epo is not None, "no onnx export was found" if verbose: print("[call_torch_export_custom] done (export)") + + if os_ort: + if verbose: + print("[call_torch_export_custom] conversion to IR...") + begin = time.perf_counter() + ir_model = epo.to_ir() + duration = time.perf_counter() - begin + summary["time_optim_to_ir"] = duration + if verbose: + print(f"[call_torch_export_custom] done in {duration}") + print("[call_torch_export_custom] start optimization...") + begin = time.perf_counter() + onnxscript.optimizer.optimize_ir(ir_model) + ir_optimized = ort_fusions.optimize_for_ort(ir_model) + if isinstance(ir_optimized, tuple): + report = ir_optimized[1] + for k, v in report.items(): + summary[f"op_opt_fused_{k}"] = v + ir_optimized = ir_optimized[0] + epo.model = ir_optimized + duration = time.perf_counter() - begin + summary["time_optim_os_ort"] = duration + if verbose: + print(f"[call_torch_export_custom] done in {duration}") + data["onnx_program"] = epo return summary, data