Skip to content

Commit b6ac5b3

Browse files
authored
add command line to optimize a model (#366)
* add command line to optimize a model * req * doc * fix * mypy
1 parent 78d3520 commit b6ac5b3

File tree

11 files changed

+303
-3
lines changed

11 files changed

+303
-3
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.8.7
55
+++++
66

7+
* :pr:`366`: add command line to optimize a model
78
* :pr:`363`: patch for DynamicDimConstraintPrinter
89
* :pr:`360`, :pr:`364`: preliminary work for phi4
910

_doc/api/helpers/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ onnx_diagnostic.helpers
2121
mini_onnx_builder
2222
model_builder_helper
2323
onnx_helper
24+
optim_helper
2425
ort_session
2526
rt_helper
2627
torch_fx_graph_helper

_doc/api/helpers/optim_helper.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
2+
onnx_diagnostic.helpers.optim_helper
3+
====================================
4+
5+
.. automodule:: onnx_diagnostic.helpers.optim_helper
6+
:members:
7+
:no-undoc-members:

_doc/cmds/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,6 @@ Command Lines
1010

1111
compare
1212
config
13+
optimize
1314
sbs
1415
validate

_doc/cmds/optimize.rst

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
-m onnx_diagnostic optimize ... optimizes an onnx model
2+
=======================================================
3+
4+
Description
5+
+++++++++++
6+
7+
See :func:`onnx_diagnostic.helpers.optim_helper.optimize_model`.
8+
9+
.. runpython::
10+
11+
from onnx_diagnostic._command_lines_parser import get_parser_optimize
12+
13+
get_parser_optimize().print_help()
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import unittest
2+
import numpy as np
3+
import onnx
4+
import onnx.helper as oh
5+
import onnx.numpy_helper as onh
6+
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
7+
from onnx_diagnostic.helpers.optim_helper import optimize_model
8+
9+
# ONNX FLOAT element type, used below for the input and output of the test graph.
TFLOAT = onnx.TensorProto.FLOAT
10+
11+
12+
class TestOptimHelper(ExtTestCase):
    """Smoke tests for :func:`onnx_diagnostic.helpers.optim_helper.optimize_model`."""

    @hide_stdout()
    def test_optimize_model(self):
        """Runs every supported algorithm on a small Shape/Concat/Reshape graph."""
        nodes = [
            oh.make_node("Shape", ["X"], ["D2"], start=2, end=3),
            oh.make_node("Concat", ["I1", "D2"], ["d"], axis=0),
            oh.make_node("Reshape", ["X", "d"], ["Y"]),
        ]
        graph = oh.make_graph(
            nodes,
            "test",
            [oh.make_tensor_value_info("X", TFLOAT, [2, 3, "d"])],
            [oh.make_tensor_value_info("Y", TFLOAT, [6, "d"])],
            [onh.from_array(np.array([-1], dtype=np.int64), name="I1")],
        )
        model = oh.make_model(
            graph,
            opset_imports=[oh.make_operatorsetid("", 18)],
            ir_version=10,
        )
        # The model is dumped once and every algorithm reads it back from disk.
        filename = self.dump_onnx("test_optimize_model.onnx", model)
        algorithms = ["default", "default+onnxruntime", "ir", "os_ort", "slim"]
        for algo in algorithms:
            output = self.get_dump_file(f"test_optimize_model.{algo}.onnx")
            with self.subTest(algo=algo):
                optimize_model(algo, filename, output=output, verbose=1)
35+
36+
37+
if __name__ == "__main__":
38+
unittest.main(verbosity=2)

_unittests/ut_xrun_doc/test_command_lines.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
get_parser_dot,
1111
get_parser_find,
1212
get_parser_lighten,
13+
get_parser_optimize,
1314
get_parser_print,
1415
get_parser_sbs,
1516
get_parser_stats,
@@ -178,6 +179,13 @@ def test_parser_compare(self):
178179
text = st.getvalue()
179180
self.assertIn("compare", text)
180181

182+
def test_parser_optimize(self):
    """Checks that the help of the ``optimize`` parser mentions ``default``."""
    captured = StringIO()
    with redirect_stdout(captured):
        parser = get_parser_optimize()
        parser.print_help()
    self.assertIn("default", captured.getvalue())
188+
181189

182190
if __name__ == "__main__":
183191
unittest.main(verbosity=2)

_unittests/ut_xrun_doc/test_command_lines_exe.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,15 @@ def test_j_parser_compare(self):
210210
text = st.getvalue()
211211
self.assertIn("done with distance 0", text)
212212

213+
def test_l_parser_optimize(self):
    """Runs the ``optimize`` command end to end and checks the output file exists."""
    output = self.get_dump_file("test_parser_optimize.onnx")
    captured = StringIO()
    with redirect_stdout(captured):
        command = ["optimize", "default", self.dummy_path, "-o", output, "-v", "1"]
        main(command)
    # With verbosity 1, the command is expected to echo the algorithm name.
    self.assertIn("default", captured.getvalue())
    self.assertExists(output)
221+
213222

214223
if __name__ == "__main__":
215224
unittest.main(verbosity=2)

onnx_diagnostic/_command_lines_parser.py

Lines changed: 108 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1547,6 +1547,107 @@ def _cmd_compare(argv: List[Any]):
15471547
print(ObsComparePair.to_str(pair_cmp))
15481548

15491549

1550+
def get_parser_optimize() -> ArgumentParser:
    """
    Builds the parser of command line ``optimize``.

    Positional arguments are the algorithm and the model to optimize.
    Options select the output file, the verbosity, the target processor
    and whether shapes are inferred before optimizing or removed before saving.

    :return: the parser
    """
    parser = ArgumentParser(
        prog="optimize",
        formatter_class=RawTextHelpFormatter,
        description=textwrap.dedent(
            """
            Optimizes an onnx model by fusing nodes. It looks for patterns in the graphs
            and replaces them by the corresponding nodes. It also does basic optimization
            such as removing identity nodes or unused nodes.
            """
        ),
        epilog=textwrap.dedent(
            """
            The goal is to make the model faster.
            Argument patterns defines the patterns to apply or the set of patterns.
            It is possible to show statistics or to remove a particular pattern.
            Here are some environment variables which can be used to trigger
            these displays.

            Available options for algorithms default and default+onnxruntime:

            - DROPPATTERN=<pattern1,pattern2,...>: do not apply
                those patterns when optimizing a model
            - DUMPPATTERNS=<folder>: dumps all matched and applied
                nodes when a pattern is applied
            - PATTERN=<pattern1,pattern2,...>: increase verbosity for specific
                patterns to understand why one pattern was not applied,
                this shows which line is rejecting a pattern if it seems one pattern was missed
            """
        ),
    )
    parser.add_argument(
        "algorithm",
        choices=["ir", "os_ort", "slim", "default", "default+onnxruntime"],
        help="algorithm or patterns optimization to apply",
    )
    parser.add_argument("input", type=str, help="onnx model to optimize")
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        required=False,
        # The suffix documented here must match the one built by the command itself.
        help="onnx model to output, if empty, it adds .o-{algorithm}.onnx to the name",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        default=0,
        required=False,
        type=int,
        help="verbosity",
    )
    parser.add_argument(
        "--infer-shapes",
        default=True,
        action=BooleanOptionalAction,
        help="infer shapes before optimizing the model",
    )
    parser.add_argument(
        "--processor",
        default="",
        help=textwrap.dedent(
            """
            optimization for a specific processor, CPU, CUDA or both CPU,CUDA,
            some operators are only available in one processor, it might be not used
            with all
            """
        ).strip("\n"),
    )
    parser.add_argument(
        "--remove-shape-info",
        default=True,
        action=BooleanOptionalAction,
        help="remove shape information before outputting the model",
    )
    return parser
1626+
1627+
1628+
def _cmd_optimize(argv: List[Any]):
    """
    Implements command line ``optimize``.

    :param argv: command line arguments, the first one is the command name
        and is skipped before parsing
    """
    args = get_parser_optimize().parse_args(argv[1:])

    from .helpers.optim_helper import optimize_model

    # Without an explicit output, the result is written next to the input model.
    output = args.output
    if not output:
        stem = os.path.splitext(args.input)[0]
        output = f"{stem}.o-{args.algorithm}.onnx"

    options = dict(
        output=output,
        verbose=args.verbose,
        processor=args.processor,
        infer_shapes=args.infer_shapes,
        remove_shape_info=args.remove_shape_info,
    )
    optimize_model(args.algorithm, args.input, **options)
1649+
1650+
15501651
#############
15511652
# main parser
15521653
#############
@@ -1563,16 +1664,17 @@ def get_main_parser() -> ArgumentParser:
15631664
to get help for a specific command.
15641665
15651666
agg - aggregates statistics from multiple files
1566-
config - prints a configuration for a model id
1667+
config - prints a configuration for a model id (on HuggingFace Hub)
15671668
dot - converts an onnx model into dot format
15681669
exportsample - produces a code to export a model
15691670
find - find node consuming or producing a result
1570-
lighten - makes an onnx model lighter by removing the weights,
1671+
lighten - makes an onnx model lighter by removing the weights
1672+
optimize - optimizes an onnx model
15711673
print - prints the model on standard output
15721674
sbs - compares an exported program and a onnx model
15731675
stats - produces statistics on a model
15741676
unlighten - restores an onnx model produces by the previous experiment
1575-
validate - validate a model
1677+
validate - validate a model (knowing its model id on HuggginFace Hub)
15761678
"""
15771679
),
15781680
)
@@ -1585,6 +1687,7 @@ def get_main_parser() -> ArgumentParser:
15851687
"exportsample",
15861688
"find",
15871689
"lighten",
1690+
"optimize",
15881691
"print",
15891692
"sbs",
15901693
"stats",
@@ -1605,6 +1708,7 @@ def main(argv: Optional[List[Any]] = None):
16051708
exportsample=_cmd_export_sample,
16061709
find=_cmd_find,
16071710
lighten=_cmd_lighten,
1711+
optimize=_cmd_optimize,
16081712
print=_cmd_print,
16091713
sbs=_cmd_sbs,
16101714
stats=_cmd_stats,
@@ -1631,6 +1735,7 @@ def main(argv: Optional[List[Any]] = None):
16311735
exportsample=lambda: get_parser_validate("exportsample"), # type: ignore[operator]
16321736
find=get_parser_find,
16331737
lighten=get_parser_lighten,
1738+
optimize=get_parser_optimize,
16341739
print=get_parser_print,
16351740
sbs=get_parser_sbs,
16361741
stats=get_parser_stats,
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
from typing import Optional, Union
2+
import pprint
3+
import onnx
4+
5+
6+
def optimize_model(
    algorithm: str,
    model: Union[onnx.ModelProto, str],
    output: Optional[str] = None,
    processor: Optional[str] = None,
    infer_shapes: bool = True,
    remove_shape_info: bool = False,
    verbose: int = 1,
):
    """
    Optimizes an onnx model by fusing nodes. It looks for patterns in the graphs
    and replaces them by the corresponding nodes. It also does basic optimization
    such as removing identity nodes or unused nodes.

    :param algorithm: algorithm to choose, one of ``default``,
        ``default+onnxruntime``, ``ir``, ``os_ort``, ``slim``
    :param model: model to optimize as a proto or a filename
    :param output: if not empty, the optimized model is saved
        (weights are saved as external data)
    :param processor: optimization are done for the processor,
        only used by ``default`` and ``default+onnxruntime``,
        ``CPU`` is used if empty
    :param infer_shapes: infer shapes before optimizing, this might not be
        available for all algorithm
    :param remove_shape_info: remove shape information before saving the model
    :param verbose: verbosity level
    :return: optimized model
    :raises ValueError: if *algorithm* is not supported

    The goal is to make the model faster.
    Argument patterns defines the patterns to apply or the set of patterns.
    It is possible to show statistics or to remove a particular pattern.
    Here are some environment variables which can be used to trigger
    these displays.

    Available options for algorithms ``default`` and ``default+onnxruntime``:

    - ``DROPPATTERN=<pattern1,pattern2,...>``: do not apply
      those patterns when optimizing a model
    - ``DUMPPATTERNS=<folder>``: dumps all matched and applied nodes when a pattern is applied
    - ``PATTERN=<pattern1,pattern2,...>``: increase verbosity
      for specific patterns to understand why one pattern was not applied,
      this shows which line is rejecting a pattern if it seems one pattern was missed
    """
    # Fails early, before loading a potentially big model, instead of ending
    # with an unbound local variable once every branch below was skipped.
    supported = ("default", "default+onnxruntime", "ir", "os_ort", "slim")
    if algorithm not in supported:
        raise ValueError(
            f"Unexpected algorithm {algorithm!r}, it should be one of {supported}."
        )

    if isinstance(model, str):
        if verbose:
            print(f"[optimize_model] load {model!r}")
        proto = onnx.load(model)
        if verbose:
            print("[optimize_model] done loading.")
    else:
        proto = model

    if verbose:
        print(f"[optimize_model] optimize with {algorithm!r}")
    if algorithm in {"default", "default+onnxruntime"}:
        from experimental_experiment.xoptim import get_pattern_list
        from experimental_experiment.xbuilder import GraphBuilder, OptimizationOptions

        pats = get_pattern_list(algorithm)

        gr = GraphBuilder(
            proto,
            infer_shapes_options=infer_shapes,
            optimization_options=OptimizationOptions(
                patterns=pats,
                verbose=verbose,
                remove_unused=True,
                constant_folding=True,
                remove_identity=True,
                # the number of iterations grows with the size of the model
                max_iter=max(100, len(proto.graph.node) // 2),
                processor=processor or "CPU",
            ),
        )
        if verbose:
            print(f"[optimize_model] starts optimizing with {len(pats)} patterns")
            print(f"[optimize_model] model has {len(proto.graph.node)} nodes")
        opt_onx, report = gr.to_onnx(optimize=True, return_optimize_report=True)
        if verbose:
            print("[optimize_model] optimization report")
            pprint.pprint(report)
            print("[optimize_model] done")

    elif algorithm == "slim":
        import onnxslim

        opt_onx = onnxslim.slim(proto, no_shape_infer=not infer_shapes)
    else:
        # algorithm is "ir" or "os_ort", the only values left after the validation
        import onnx_ir
        import onnxscript.optimizer
        from onnxscript.rewriter.ort_fusions import optimize_for_ort

        model_ir = onnx_ir.from_proto(proto)
        if algorithm == "ir":
            onnxscript.optimizer.optimize(model_ir)
        else:
            optimize_for_ort(model_ir)
        opt_onx = onnx_ir.serde.serialize_model(model_ir)

    # drops the local reference to the original model, it is not needed anymore
    del proto
    if verbose:
        print(f"[optimize_model] done optimizing, model has {len(opt_onx.graph.node)} nodes")
    if remove_shape_info:
        if verbose:
            print(f"[optimize_model] remove shape information {len(opt_onx.graph.value_info)}")
        del opt_onx.graph.value_info[:]
        if verbose:
            print("[optimize_model] done removing shape info")

    if output:
        if verbose:
            print(f"[optimize_model] save file into {output!r}")
        onnx.save(opt_onx, output, save_as_external_data=True)
        if verbose:
            print("[optimize_model] done saving")
    return opt_onx

0 commit comments

Comments
 (0)