Skip to content

Commit 05b04f7

Browse files
committed
add support for os_ort
1 parent aff1365 commit 05b04f7

File tree

6 files changed

+48
-17
lines changed

6 files changed

+48
-17
lines changed

_doc/examples/plot_export_tiny_llm_patched.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@
101101

102102
# %%
103103
# If they are not registered, function
104-
# func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
104+
# :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
105105
# should take care of it. Then we export.
106106

107107
with torch_export_patches(patch_transformers=True, verbose=10) as modificator:

_doc/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ Size of the package:
173173
Older versions
174174
++++++++++++++
175175

176+
* `0.4.4 <../v0.4.4/index.html>`_
176177
* `0.4.3 <../v0.4.3/index.html>`_
177178
* `0.4.2 <../v0.4.2/index.html>`_
178179
* `0.4.1 <../v0.4.1/index.html>`_

_doc/recipes/plot_dynamic_shapes_max.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,4 +185,6 @@ def forward(self, x, y, fact):
185185
# is hidden in a custom operator.
186186

187187

188-
doc.plot_legend("max(d1, d2)\nwith d1, d2 dimensions", "dynamic shapes", "green")
188+
doc.plot_legend(
189+
"Fixed in torch==2.8\nmax(d1, d2)\nwith d1, d2\ndimensions", "dynamic shapes", "green"
190+
)

_unittests/ut_torch_models/test_test_helpers.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def test_validate_model_export(self):
6666
@requires_torch("2.7")
6767
@hide_stdout()
6868
@ignore_warnings(FutureWarning)
69-
def test_validate_model_onnx_dynamo(self):
69+
def test_validate_model_onnx_dynamo_ir(self):
7070
mid = "arnir0/Tiny-LLM"
7171
summary, data = validate_model(
7272
mid,
@@ -87,6 +87,27 @@ def test_validate_model_onnx_dynamo(self):
8787
onnx_filename, output_path, num_attention_heads=2, hidden_size=192, verbose=10
8888
)
8989

90+
@requires_torch("2.7")
91+
@hide_stdout()
92+
@ignore_warnings(FutureWarning)
93+
def test_validate_model_onnx_dynamo_os_ort(self):
94+
mid = "arnir0/Tiny-LLM"
95+
summary, data = validate_model(
96+
mid,
97+
do_run=True,
98+
verbose=10,
99+
exporter="onnx-dynamo",
100+
dump_folder="dump_test_validate_model_onnx_dynamo",
101+
patch=True,
102+
stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0,
103+
optimization="os_ort",
104+
)
105+
self.assertIsInstance(summary, dict)
106+
self.assertIsInstance(data, dict)
107+
self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-4)
108+
onnx_filename = data["onnx_filename"]
109+
self.assertExists(onnx_filename)
110+
90111
@requires_torch("2.7")
91112
@hide_stdout()
92113
@ignore_warnings(FutureWarning)

onnx_diagnostic/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@
33
Functions, classes to dig into a model when this one is right, slow, wrong...
44
"""
55

6-
__version__ = "0.4.3"
6+
__version__ = "0.4.4"
77
__author__ = "Xavier Dupré"

onnx_diagnostic/torch_models/test_helper.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
55
import time
66
import onnx
7+
import onnxscript
8+
import onnxscript.rewriter.ort_fusions as ort_fusions
79
import torch
810
from ..export import CoupleInputsDynamicShapes
911
from ..helpers import max_diff, string_type, string_diff
@@ -917,11 +919,10 @@ def call_torch_export_onnx(
917919
:return: two dictionaries, one with some metrics,
918920
another one with whatever the function produces
919921
"""
920-
assert optimization in {
921-
"",
922-
"ir",
923-
None,
924-
}, f"unexpected value for optimization={optimization}"
922+
available = {"", "ir", "os_ort"}
923+
assert (
924+
optimization in available
925+
), f"unexpected value for optimization={optimization}, available={available}"
925926
assert exporter in {
926927
"onnx-dynamo",
927928
"onnx-script",
@@ -1001,16 +1002,22 @@ def call_torch_export_onnx(
10011002
print(epo)
10021003
print("[call_torch_export_onnx] -- End of ONNXProgram")
10031004

1004-
if optimization == "ir":
1005+
if optimization in {"ir", "os_ort"}:
10051006
if verbose:
10061007
print(f"[call_torch_export_onnx] starts optimization={optimization!r}...")
1007-
_quiet_or_not_quiet(
1008-
quiet,
1009-
"export_onnx_opt_ir",
1010-
summary,
1011-
data,
1012-
(lambda epo=epo: epo.optimize()),
1013-
)
1008+
if optimization == "ir":
1009+
label, f_optim = "export_onnx_opt_ir", (lambda epo=epo: epo.optimize())
1010+
else:
1011+
1012+
def _os_ort_optim(epo):
1013+
onnxscript.optimizer.optimize_ir(epo.model)
1014+
optimized = ort_fusions.optimize_for_ort(epo.model)
1015+
epo.model = (
1016+
optimized if isinstance(optimized, onnxscript.ir.Model) else optimized[0]
1017+
)
1018+
1019+
label, f_optim = "export_onnx_opt_os_ort", (lambda epo=epo: _os_ort_optim(epo))
1020+
_quiet_or_not_quiet(quiet, label, summary, data, f_optim)
10141021
if "ERR_export_onnx_opt_ir" in summary:
10151022
return summary, data
10161023
if verbose:

0 commit comments

Comments
 (0)