
Commit 1bd84c5

Add os_ort optimization (#77)
* first draft
* add support for os_ort
* pass
* fix issues
* ut
* clean
1 parent 7eb831c commit 1bd84c5
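
A minimal usage sketch of the new option, based on the unit tests added in this commit. The import path of validate_model is an assumption (the tests exercise the module modified here, onnx_diagnostic.torch_models.test_helper); the model id and the summary key come from the tests:

    # Request the new onnxscript/onnxruntime fusion pass when validating an export.
    from onnx_diagnostic.torch_models.test_helper import validate_model  # assumed location

    summary, data = validate_model(
        "arnir0/Tiny-LLM",        # tiny model used by the unit tests
        do_run=True,              # run the exported model and measure discrepancies
        exporter="onnx-dynamo",   # the "custom" exporter accepts "default+os_ort"
        optimization="os_ort",    # new value introduced by this commit
        patch=True,
    )
    print(summary["disc_onnx_ort_run_abs"])  # discrepancy the tests check against 1e-4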

File tree

6 files changed, +109 -19 lines changed


_doc/examples/plot_export_tiny_llm_patched.py

Lines changed: 1 addition & 1 deletion
@@ -101,7 +101,7 @@

 # %%
 # If they are not registered, function
-# func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
+# :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
 # should take care of it. Then we export.

 with torch_export_patches(patch_transformers=True, verbose=10) as modificator:

_doc/index.rst

Lines changed: 1 addition & 0 deletions
@@ -173,6 +173,7 @@ Size of the package:
 Older versions
 ++++++++++++++

+* `0.4.4 <../v0.4.4/index.html>`_
 * `0.4.3 <../v0.4.3/index.html>`_
 * `0.4.2 <../v0.4.2/index.html>`_
 * `0.4.1 <../v0.4.1/index.html>`_

_doc/recipes/plot_dynamic_shapes_max.py

Lines changed: 3 additions & 1 deletion
@@ -185,4 +185,6 @@ def forward(self, x, y, fact):
 # is hidden in a custom operator.


-doc.plot_legend("max(d1, d2)\nwith d1, d2 dimensions", "dynamic shapes", "green")
+doc.plot_legend(
+    "Fixed in torch==2.8\nmax(d1, d2)\nwith d1, d2\ndimensions", "dynamic shapes", "green"
+)

_unittests/ut_torch_models/test_test_helpers.py

Lines changed: 44 additions & 1 deletion
@@ -66,7 +66,7 @@ def test_validate_model_export(self):
     @requires_torch("2.7")
     @hide_stdout()
     @ignore_warnings(FutureWarning)
-    def test_validate_model_onnx_dynamo(self):
+    def test_validate_model_onnx_dynamo_ir(self):
         mid = "arnir0/Tiny-LLM"
         summary, data = validate_model(
             mid,
@@ -87,6 +87,49 @@ def test_validate_model_onnx_dynamo(self):
             onnx_filename, output_path, num_attention_heads=2, hidden_size=192, verbose=10
         )

+    @requires_torch("2.7")
+    @hide_stdout()
+    @ignore_warnings(FutureWarning)
+    def test_validate_model_onnx_dynamo_os_ort(self):
+        mid = "arnir0/Tiny-LLM"
+        summary, data = validate_model(
+            mid,
+            do_run=True,
+            verbose=10,
+            exporter="onnx-dynamo",
+            dump_folder="dump_test_validate_model_onnx_dynamo",
+            patch=True,
+            stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0,
+            optimization="os_ort",
+        )
+        self.assertIsInstance(summary, dict)
+        self.assertIsInstance(data, dict)
+        self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-4)
+        onnx_filename = data["onnx_filename"]
+        self.assertExists(onnx_filename)
+
+    @requires_torch("2.7")
+    @hide_stdout()
+    @ignore_warnings(FutureWarning)
+    @requires_experimental()
+    def test_validate_model_custom_os_ort(self):
+        mid = "arnir0/Tiny-LLM"
+        summary, data = validate_model(
+            mid,
+            do_run=True,
+            verbose=10,
+            exporter="custom",
+            dump_folder="dump_validate_model_custom_os_ort",
+            patch=True,
+            stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0,
+            optimization="default+os_ort",
+        )
+        self.assertIsInstance(summary, dict)
+        self.assertIsInstance(data, dict)
+        self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-4)
+        onnx_filename = data["onnx_filename"]
+        self.assertExists(onnx_filename)
+
     @requires_torch("2.7")
     @hide_stdout()
     @ignore_warnings(FutureWarning)

onnx_diagnostic/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -3,5 +3,5 @@
 Functions, classes to dig into a model when this one is right, slow, wrong...
 """

-__version__ = "0.4.3"
+__version__ = "0.4.4"
 __author__ = "Xavier Dupré"

onnx_diagnostic/torch_models/test_helper.py

Lines changed: 59 additions & 15 deletions
@@ -4,6 +4,8 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import time
 import onnx
+import onnxscript
+import onnxscript.rewriter.ort_fusions as ort_fusions
 import torch
 from ..export import CoupleInputsDynamicShapes
 from ..helpers import max_diff, string_type, string_diff
@@ -917,11 +919,10 @@ def call_torch_export_onnx(
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces
     """
-    assert optimization in {
-        "",
-        "ir",
-        None,
-    }, f"unexpected value for optimization={optimization}"
+    available = {"", "ir", "os_ort"}
+    assert (
+        optimization in available
+    ), f"unexpected value for optimization={optimization}, available={available}"
     assert exporter in {
         "onnx-dynamo",
         "onnx-script",
@@ -1001,16 +1002,25 @@ def call_torch_export_onnx(
         print(epo)
         print("[call_torch_export_onnx] -- End of ONNXProgram")

-    if optimization == "ir":
+    if optimization in {"ir", "os_ort"}:
         if verbose:
             print(f"[call_torch_export_onnx] starts optimization={optimization!r}...")
-        _quiet_or_not_quiet(
-            quiet,
-            "export_onnx_opt_ir",
-            summary,
-            data,
-            (lambda epo=epo: epo.optimize()),
-        )
+        if optimization == "ir":
+            label, f_optim = "export_onnx_opt_ir", (lambda epo=epo: epo.optimize())
+        else:
+
+            def _os_ort_optim(epo):
+                onnxscript.optimizer.optimize_ir(epo.model)
+                optimized = ort_fusions.optimize_for_ort(epo.model)
+                if isinstance(optimized, tuple):
+                    for k, v in optimized[1].items():
+                        summary[f"op_opt_fused_{k}"] = v
+                    epo.model = optimized[0]
+                else:
+                    epo.model = optimized
+
+            label, f_optim = "export_onnx_opt_os_ort", (lambda epo=epo: _os_ort_optim(epo))
+        _quiet_or_not_quiet(quiet, label, summary, data, f_optim)
     if "ERR_export_onnx_opt_ir" in summary:
         return summary, data
     if verbose:
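
The new branch rewrites the exported program's model in place; the same fusions can be run on a standalone IR model. A minimal sketch mirroring the branch above, assuming onnxscript.ir.load/save as the IR entry points and a hypothetical model.onnx on disk (optimize_ir and optimize_for_ort are the calls used in the diff):

    import onnxscript
    import onnxscript.rewriter.ort_fusions as ort_fusions
    from onnxscript import ir

    ir_model = ir.load("model.onnx")                     # hypothetical input file
    onnxscript.optimizer.optimize_ir(ir_model)           # generic IR-level cleanups, in place
    optimized = ort_fusions.optimize_for_ort(ir_model)   # onnxruntime-specific fusions
    if isinstance(optimized, tuple):                     # may also return fusion counters
        ir_model, report = optimized
        print(report)                                    # e.g. counts per fused pattern
    else:
        ir_model = optimized
    ir.save(ir_model, "model.os_ort.onnx")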
@@ -1039,12 +1049,17 @@ def call_torch_export_custom(
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces
     """
-    assert optimization in {
+    available = {
         "",
         "default",
         "default+onnxruntime",
+        "default+os_ort",
+        "default+onnxruntime+os_ort",
         None,
-    }, f"unexpected value for optimization={optimization}"
+    }
+    assert (
+        optimization in available
+    ), f"unexpected value for optimization={optimization}, available={available}"
     assert exporter in {
         "custom",
         "custom-strict",
@@ -1078,6 +1093,10 @@ def call_torch_export_custom(
     from experimental_experiment.torch_interpreter import to_onnx, ExportOptions
     from experimental_experiment.xbuilder import OptimizationOptions

+    spl = optimization.split("+") if optimization else []
+    os_ort = "os_ort" in spl
+    optimization = "+".join(_ for _ in spl if _ != "os_ort")
+
     export_options = ExportOptions(
         strict=strict,
         decomposition_table=(
@@ -1181,6 +1200,31 @@ def call_torch_export_custom(
     assert epo is not None, "no onnx export was found"
     if verbose:
         print("[call_torch_export_custom] done (export)")
+
+    if os_ort:
+        if verbose:
+            print("[call_torch_export_custom] conversion to IR...")
+        begin = time.perf_counter()
+        ir_model = epo.to_ir()
+        duration = time.perf_counter() - begin
+        summary["time_optim_to_ir"] = duration
+        if verbose:
+            print(f"[call_torch_export_custom] done in {duration}")
+            print("[call_torch_export_custom] start optimization...")
+        begin = time.perf_counter()
+        onnxscript.optimizer.optimize_ir(ir_model)
+        ir_optimized = ort_fusions.optimize_for_ort(ir_model)
+        if isinstance(ir_optimized, tuple):
+            report = ir_optimized[1]
+            for k, v in report.items():
+                summary[f"op_opt_fused_{k}"] = v
+            ir_optimized = ir_optimized[0]
+        epo.model = ir_optimized
+        duration = time.perf_counter() - begin
+        summary["time_optim_os_ort"] = duration
+        if verbose:
+            print(f"[call_torch_export_custom] done in {duration}")
+
     data["onnx_program"] = epo
     return summary, data
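
For the custom exporter the optimization value is a composite string: "os_ort" is peeled off here and the remainder is forwarded to experimental_experiment. A small sketch of that decomposition, using one of the values accepted by the new assertion:

    optimization = "default+onnxruntime+os_ort"
    spl = optimization.split("+") if optimization else []
    os_ort = "os_ort" in spl                                   # True -> run the ort fusions after export
    optimization = "+".join(_ for _ in spl if _ != "os_ort")
    print(os_ort, optimization)                                # True default+onnxruntime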
