Skip to content

Commit 1060b6d

Browse files
authored
Adds a helper to convert an onnx model into dot (#331)
* Adds a helper to convert an onnx model into dot * doc * mypy * improve dot * improve rendering * fix * json
1 parent 093c104 commit 1060b6d

File tree

11 files changed

+425
-2
lines changed

11 files changed

+425
-2
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.8.3
55
+++++
66

7+
* :pr:`331`: adds a helper to convert an onnx model into dot
78
* :pr:`330`: fixes access rope_parameters for ``transformers>=5``
89
* :pr:`329`: supports lists with OnnxruntimeEvaluator
910
* :pr:`326`: use ConcatFromSequence in LoopMHA with the loop

_doc/api/helpers/dot_helper.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
2+
onnx_diagnostic.helpers.dot_helper
3+
==================================
4+
5+
.. automodule:: onnx_diagnostic.helpers.dot_helper
6+
:members:
7+
:no-undoc-members:

_doc/api/helpers/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ onnx_diagnostic.helpers
1111
cache_helper
1212
config_helper
1313
doc_helper
14+
dot_helper
1415
fake_tensor_helper
1516
graph_helper
1617
helper

_doc/recipes/plot_dynamic_shapes_json.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def flatten_unflatten_like_dynamic_shapes(obj):
7474
start = 0
7575
end = 0
7676
subtrees = []
77-
for subspec in spec.children_specs:
77+
for subspec in (spec.children() if hasattr(spec, "children") else spec.children_specs):
7878
end += subspec.num_leaves
7979
value = subspec.unflatten(flat[start:end])
8080
value = flatten_unflatten_like_dynamic_shapes(value)
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import textwrap
2+
import unittest
3+
import onnx
4+
import onnx.helper as oh
5+
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers
6+
from onnx_diagnostic.helpers.dot_helper import to_dot
7+
from onnx_diagnostic.export.api import to_onnx
8+
from onnx_diagnostic.torch_export_patches import torch_export_patches
9+
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
10+
11+
12+
class TestDotHelper(ExtTestCase):
13+
def test_custom_doc_kernels_layer_normalization(self):
14+
TFLOAT16 = onnx.TensorProto.FLOAT16
15+
model = oh.make_model(
16+
oh.make_graph(
17+
[
18+
oh.make_node(
19+
"LayerNormalization",
20+
["X", "W", "B"],
21+
["ln"],
22+
axis=-1,
23+
epsilon=9.999999974752427e-7,
24+
),
25+
oh.make_node(
26+
"Add", ["ln", "W"], ["Z"], axis=-1, epsilon=9.999999974752427e-7
27+
),
28+
],
29+
"dummy",
30+
[
31+
oh.make_tensor_value_info("X", TFLOAT16, ["b", "c", "d"]),
32+
oh.make_tensor_value_info("W", TFLOAT16, ["d"]),
33+
oh.make_tensor_value_info("B", TFLOAT16, ["d"]),
34+
],
35+
[oh.make_tensor_value_info("Z", TFLOAT16, ["b", "c", "d"])],
36+
),
37+
ir_version=9,
38+
opset_imports=[oh.make_opsetid("", 18)],
39+
)
40+
dot = to_dot(model)
41+
expected = textwrap.dedent(
42+
"""
43+
digraph {
44+
graph [rankdir=TB, splines=true, overlap=false, nodesep=0.2, ranksep=0.2, fontsize=8];
45+
node [style="rounded,filled", color="#888888", fontcolor="#222222", shape=box];
46+
edge [arrowhead=vee, fontsize=7, labeldistance=-5, labelangle=0];
47+
I_0 [label="X\\nFLOAT16(b,c,d)", fillcolor="#aaeeaa"];
48+
I_1 [label="W\\nFLOAT16(d)", fillcolor="#aaeeaa"];
49+
I_2 [label="B\\nFLOAT16(d)", fillcolor="#aaeeaa"];
50+
LayerNormalization_3 [label="LayerNormalization(., ., ., axis=-1)", fillcolor="#cccccc"];
51+
Add_4 [label="Add(., ., axis=-1)", fillcolor="#cccccc"];
52+
I_0 -> LayerNormalization_3 [label="FLOAT16(b,c,d)"];
53+
I_1 -> LayerNormalization_3 [label="FLOAT16(d)"];
54+
I_2 -> LayerNormalization_3 [label="FLOAT16(d)"];
55+
LayerNormalization_3 -> Add_4 [label="FLOAT16(b,c,d)"];
56+
I_1 -> Add_4 [label="FLOAT16(d)"];
57+
O_5 [label="Z\\nFLOAT16(d)", fillcolor="#aaaaee"];
58+
Add_4 -> O_5;
59+
}
60+
"""
61+
)
62+
self.maxDiff = None
63+
self.assertEqual(expected.strip("\n "), dot.strip("\n "))
64+
65+
@requires_transformers("4.57")
66+
def test_dot_plot_tiny(self):
67+
data = get_untrained_model_with_inputs("arnir0/Tiny-LLM")
68+
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
69+
with torch_export_patches(patch_transformers=True):
70+
em = to_onnx(model, inputs, dynamic_shapes=ds, exporter="custom")
71+
dot = to_dot(em.model_proto)
72+
name = self.get_dump_file("test_dot_plot_tiny.dot")
73+
with open(name, "w") as f:
74+
f.write(dot)
75+
# dot -Tpng dump_test/test_dot_plot_tiny.dot -o dump_test/test_dot_plot_tiny.png
76+
self.assertIn("-> Add", dot)
77+
78+
79+
if __name__ == "__main__":
80+
unittest.main(verbosity=2)

_unittests/ut_xrun_doc/test_command_lines.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
get_main_parser,
77
get_parser_agg,
88
get_parser_config,
9+
get_parser_dot,
910
get_parser_find,
1011
get_parser_lighten,
1112
get_parser_print,
@@ -23,6 +24,7 @@ def test_main_parser(self):
2324
get_main_parser().print_help()
2425
text = st.getvalue()
2526
self.assertIn("lighten", text)
27+
self.assertIn("dot", text)
2628

2729
def test_parser_lighten(self):
2830
st = StringIO()
@@ -87,6 +89,13 @@ def test_parser_sbs(self):
8789
text = st.getvalue()
8890
self.assertIn("--onnx", text)
8991

92+
def test_parser_dot(self):
93+
st = StringIO()
94+
with redirect_stdout(st):
95+
get_parser_dot().print_help()
96+
text = st.getvalue()
97+
self.assertIn("--run", text)
98+
9099

91100
if __name__ == "__main__":
92101
unittest.main(verbosity=2)

_unittests/ut_xrun_doc/test_command_lines_exe.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,45 @@ def forward(self, x):
162162
sdf = df[(df.ep_target == "placeholder") & (df.onnx_op_type == "initializer")]
163163
self.assertEqual(sdf.shape[0], 4)
164164

165+
@ignore_warnings(UserWarning)
166+
@requires_transformers("4.53")
167+
def test_i_parser_dot(self):
168+
import torch
169+
170+
class Model(torch.nn.Module):
171+
def __init__(self):
172+
super(Model, self).__init__()
173+
self.fc1 = torch.nn.Linear(10, 32) # input size 10 → hidden size 32
174+
self.relu = torch.nn.ReLU()
175+
self.fc2 = torch.nn.Linear(32, 1) # hidden → output
176+
177+
def forward(self, x):
178+
x = self.relu(self.fc1(x))
179+
x = self.fc2(x)
180+
return x
181+
182+
inputs = dict(x=torch.randn((5, 10)))
183+
ds = dict(x={0: "batch"})
184+
onnx_file = self.get_dump_file("test_i_parser_dot.model.onnx")
185+
to_onnx(
186+
Model(),
187+
kwargs=inputs,
188+
dynamic_shapes=ds,
189+
exporter="custom",
190+
filename=onnx_file,
191+
)
192+
193+
output = self.get_dump_file("test_i_parser_dot.dot")
194+
args = ["dot", onnx_file, "-v", "1", "-o", output]
195+
if not self.unit_test_going():
196+
args.extend(["--run", "svg"])
197+
198+
st = StringIO()
199+
with redirect_stdout(st):
200+
main(args)
201+
text = st.getvalue()
202+
print(text)
203+
165204

166205
if __name__ == "__main__":
167206
unittest.main(verbosity=2)

onnx_diagnostic/_command_lines_parser.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,77 @@
1111
from argparse import ArgumentParser, RawTextHelpFormatter, BooleanOptionalAction
1212

1313

14+
def get_parser_dot() -> ArgumentParser:
15+
parser = ArgumentParser(
16+
prog="dot",
17+
description=textwrap.dedent(
18+
"""
19+
Converts a model into a dot file dot can draw into a graph.
20+
"""
21+
),
22+
)
23+
parser.add_argument("input", type=str, help="onnx model to lighten")
24+
parser.add_argument(
25+
"-o",
26+
"--output",
27+
default="",
28+
type=str,
29+
required=False,
30+
help="dot model to output or empty to print out the result",
31+
)
32+
parser.add_argument(
33+
"-v",
34+
"--verbose",
35+
type=int,
36+
default=0,
37+
required=False,
38+
help="verbosity",
39+
)
40+
parser.add_argument(
41+
"-r",
42+
"--run",
43+
default="",
44+
required=False,
45+
help="run dot, in that case, format must be given (svg, png)",
46+
)
47+
return parser
48+
49+
50+
def _cmd_dot(argv: List[Any]):
51+
import subprocess
52+
from .helpers.dot_helper import to_dot
53+
54+
parser = get_parser_dot()
55+
args = parser.parse_args(argv[1:])
56+
if args.verbose:
57+
print(f"-- loads {args.input!r}")
58+
onx = onnx.load(args.input, load_external_data=False)
59+
if args.verbose:
60+
print("-- converts into dot")
61+
dot = to_dot(onx)
62+
if args.output:
63+
if args.verbose:
64+
print(f"-- saves into {args.output}")
65+
with open(args.output, "w") as f:
66+
f.write(dot)
67+
else:
68+
print(dot)
69+
if args.run:
70+
assert args.output, "Cannot run dot without an output file."
71+
cmds = ["dot", f"-T{args.run}", args.output, "-o", f"{args.output}.{args.run}"]
72+
if args.verbose:
73+
print(f"-- run {' '.join(cmds)}")
74+
p = subprocess.Popen(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
75+
res = p.communicate()
76+
out, err = res
77+
if out:
78+
print("--")
79+
print(out)
80+
if err:
81+
print("--")
82+
print(err)
83+
84+
1485
def get_parser_lighten() -> ArgumentParser:
1586
parser = ArgumentParser(
1687
prog="lighten",
@@ -1412,6 +1483,7 @@ def get_main_parser() -> ArgumentParser:
14121483
14131484
agg - aggregates statistics from multiple files
14141485
config - prints a configuration for a model id
1486+
dot - converts an onnx model into dot format
14151487
exportsample - produces a code to export a model
14161488
find - find node consuming or producing a result
14171489
lighten - makes an onnx model lighter by removing the weights,
@@ -1428,6 +1500,7 @@ def get_main_parser() -> ArgumentParser:
14281500
choices=[
14291501
"agg",
14301502
"config",
1503+
"dot",
14311504
"exportsample",
14321505
"find",
14331506
"lighten",
@@ -1446,6 +1519,7 @@ def main(argv: Optional[List[Any]] = None):
14461519
fcts = dict(
14471520
agg=_cmd_agg,
14481521
config=_cmd_config,
1522+
dot=_cmd_dot,
14491523
exportsample=_cmd_export_sample,
14501524
find=_cmd_find,
14511525
lighten=_cmd_lighten,
@@ -1470,6 +1544,7 @@ def main(argv: Optional[List[Any]] = None):
14701544
parsers = dict(
14711545
agg=get_parser_agg,
14721546
config=get_parser_config,
1547+
dot=get_parser_dot,
14731548
exportsample=lambda: get_parser_validate("exportsample"), # type: ignore[operator]
14741549
find=get_parser_find,
14751550
lighten=get_parser_lighten,

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def flatten_unflatten_for_dynamic_shapes(
8080
start = 0
8181
end = 0
8282
subtrees = []
83-
for subspec in spec.children_specs:
83+
for subspec in (spec.children() if hasattr(spec, "children") else spec.children_specs):
8484
end += subspec.num_leaves
8585
value = subspec.unflatten(flat[start:end])
8686
value = flatten_unflatten_for_dynamic_shapes(

0 commit comments

Comments
 (0)