Skip to content

Commit a6044a2

Browse files
committed
Adds a helper to convert an onnx model into dot
1 parent 093c104 commit a6044a2

File tree

7 files changed

+277
-2
lines changed

7 files changed

+277
-2
lines changed

_doc/api/helpers/dot_helper.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
2+
onnx_diagnostic.helpers.dot_helper
3+
==================================
4+
5+
.. automodule:: onnx_diagnostic.helpers.dot_helper
6+
:members:
7+
:no-undoc-members:

_doc/api/helpers/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ onnx_diagnostic.helpers
1111
cache_helper
1212
config_helper
1313
doc_helper
14+
dot_helper
1415
fake_tensor_helper
1516
graph_helper
1617
helper

_doc/recipes/plot_dynamic_shapes_json.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def flatten_unflatten_like_dynamic_shapes(obj):
7474
start = 0
7575
end = 0
7676
subtrees = []
77-
for subspec in spec.children_specs:
77+
for subspec in spec.children():
7878
end += subspec.num_leaves
7979
value = subspec.unflatten(flat[start:end])
8080
value = flatten_unflatten_like_dynamic_shapes(value)
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import textwrap
2+
import unittest
3+
import onnx
4+
import onnx.helper as oh
5+
from onnx_diagnostic.ext_test_case import ExtTestCase
6+
from onnx_diagnostic.helpers.dot_helper import to_dot
7+
from onnx_diagnostic.export.api import to_onnx
8+
from onnx_diagnostic.torch_export_patches import torch_export_patches
9+
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
10+
11+
12+
class TestDotHelper(ExtTestCase):
13+
def test_custom_doc_kernels_layer_normalization(self):
14+
TFLOAT16 = onnx.TensorProto.FLOAT16
15+
model = oh.make_model(
16+
oh.make_graph(
17+
[
18+
oh.make_node(
19+
"LayerNormalization",
20+
["X", "W", "B"],
21+
["ln"],
22+
axis=-1,
23+
epsilon=9.999999974752427e-7,
24+
),
25+
oh.make_node(
26+
"Add", ["ln", "W"], ["Z"], axis=-1, epsilon=9.999999974752427e-7
27+
),
28+
],
29+
"dummy",
30+
[
31+
oh.make_tensor_value_info("X", TFLOAT16, ["b", "c", "d"]),
32+
oh.make_tensor_value_info("W", TFLOAT16, ["d"]),
33+
oh.make_tensor_value_info("B", TFLOAT16, ["d"]),
34+
],
35+
[oh.make_tensor_value_info("Z", TFLOAT16, ["b", "c", "d"])],
36+
),
37+
ir_version=9,
38+
opset_imports=[oh.make_opsetid("", 18)],
39+
)
40+
dot = to_dot(model)
41+
expected = textwrap.dedent(
42+
"""
43+
digraph {
44+
graph [rankdir=TB, splines=true, overlap=false, nodesep=0.2, ranksep=0.2, fontsize=8];
45+
node [style="rounded,filled", color="#888888", fontcolor="#222222", shape=box];
46+
edge [arrowhead=vee, fontsize=6];
47+
I_0 [label="X", fillcolor="#aaeeaa"];
48+
I_1 [label="W", fillcolor="#aaeeaa"];
49+
I_2 [label="B", fillcolor="#aaeeaa"];
50+
LayerNormalization_3 [label="LayerNormalization(., ., ., axis=-1)", fillcolor="#cccccc"];
51+
Add_4 [label="Add(., ., axis=-1)", fillcolor="#cccccc"];
52+
I_0 -> LayerNormalization_3;
53+
I_1 -> LayerNormalization_3;
54+
I_2 -> LayerNormalization_3;
55+
LayerNormalization_3 -> Add_4 [label="FLOAT16(b,c,d)"];
56+
I_1 -> Add_4;
57+
O_5 [label="Z", fillcolor="#aaaaee"];
58+
Add_4 -> O_5;
59+
}
60+
"""
61+
)
62+
self.maxDiff = None
63+
self.assertEqual(expected.strip("\n "), dot.strip("\n "))
64+
65+
def test_dot_plot_tiny(self):
66+
data = get_untrained_model_with_inputs("arnir0/Tiny-LLM")
67+
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
68+
with torch_export_patches(patch_transformers=True):
69+
em = to_onnx(model, inputs, dynamic_shapes=ds, exporter="custom")
70+
dot = to_dot(em.model_proto)
71+
name = self.get_dump_file("test_dot_plot_tiny.dot")
72+
with open(name, "w") as f:
73+
f.write(dot)
74+
# dot -Tpng dump_test/test_dot_plot_tiny.dot -o dump_test/test_dot_plot_tiny.png
75+
self.assertIn("-> Add", dot)
76+
77+
78+
if __name__ == "__main__":
79+
unittest.main(verbosity=2)

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def flatten_unflatten_for_dynamic_shapes(
8080
start = 0
8181
end = 0
8282
subtrees = []
83-
for subspec in spec.children_specs:
83+
for subspec in spec.children():
8484
end += subspec.num_leaves
8585
value = subspec.unflatten(flat[start:end])
8686
value = flatten_unflatten_for_dynamic_shapes(
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
from typing import Set
2+
import onnx
3+
from .onnx_helper import onnx_dtype_name, pretty_onnx
4+
5+
6+
def _get_hidden_inputs(graph: onnx.GraphProto) -> Set[str]:
7+
hidden = set()
8+
memo = (
9+
{i.name for i in graph.initializer}
10+
| {i.values.name for i in graph.sparse_initializer}
11+
| {i.name for i in graph.input}
12+
)
13+
for node in graph.node:
14+
for i in node.input:
15+
if i not in memo:
16+
hidden.add(i)
17+
for att in node.attribute:
18+
if att.type == onnx.AttributeProto.GRAPH and att.g:
19+
hid = _get_hidden_inputs(att.g)
20+
less = set(h for h in hid if h not in memo)
21+
hidden |= less
22+
memo |= set(node.output)
23+
return hidden
24+
25+
26+
def _make_node_label(node: onnx.NodeProto) -> str:
27+
els = [f"{node.domain}.\\n{node.op_type}" if node.domain else node.op_type, "("]
28+
ee = ["." if i else "" for i in node.input]
29+
for att in node.attribute:
30+
if att.name == "to":
31+
ee.append(f"{att.name}={onnx_dtype_name(att.i)}")
32+
elif att.name in {"to", "axis", "value_int", "stash_type"}:
33+
ee.append(f"{att.name}={att.i}")
34+
elif att.name in {"value_float"}:
35+
ee.append(f"{att.name}={att.f}")
36+
elif att.name in {"value_floats"}:
37+
ee.append(f"{att.name}={att.floats}")
38+
elif att.name in {"value_ints", "perm"}:
39+
ee.append(f"{att.name}={att.ints}")
40+
els.append(", ".join(ee))
41+
els.append(")")
42+
if node.op_type == "Constant":
43+
els.extend([" -> ", node.output[0]])
44+
return "".join(els)
45+
46+
47+
def to_dot(model: onnx.ModelProto) -> str:
48+
"""
49+
Converts a model into a dot graph.
50+
Here is an example:
51+
52+
.. gdot::
53+
:script: DOT-SECTION
54+
:process:
55+
56+
from onnx_diagnostic.helpers.dot_helper import to_dot
57+
from onnx_diagnostic.export.api import to_onnx
58+
from onnx_diagnostic.torch_export_patches import torch_export_patches
59+
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
60+
61+
data = get_untrained_model_with_inputs("arnir0/Tiny-LLM")
62+
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
63+
with torch_export_patches(patch_transformers=True):
64+
em = to_onnx(model, inputs, dynamic_shapes=ds, exporter="custom")
65+
dot = to_dot(em.model_proto)
66+
print("DOT-SECTION", dot)
67+
68+
Or this one obtained with :func:`torch.onnx.export`.
69+
70+
.. gdot::
71+
:script: DOT-SECTION
72+
:process:
73+
74+
from onnx_diagnostic.helpers.dot_helper import to_dot
75+
from onnx_diagnostic.export.api import to_onnx
76+
from onnx_diagnostic.torch_export_patches import torch_export_patches
77+
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
78+
79+
data = get_untrained_model_with_inputs("arnir0/Tiny-LLM")
80+
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
81+
with torch_export_patches(patch_transformers=True):
82+
em = to_onnx(model, kwargs=inputs, dynamic_shapes=ds, exporter="onnx-dynamo")
83+
dot = to_dot(em.model_proto)
84+
print("DOT-SECTION", dot)
85+
"""
86+
_unique = {}
87+
88+
def _mkn(obj: object) -> int:
89+
id_obj = id(obj)
90+
if id_obj in _unique:
91+
return _unique[id_obj]
92+
i = len(_unique)
93+
_unique[id_obj] = i
94+
return i
95+
96+
model = onnx.shape_inference.infer_shapes(model)
97+
98+
op_type_colors = {
99+
"Shape": "#eeeeee",
100+
"MatMul": "#ee9999",
101+
"Transpose": "#ee99ee",
102+
}
103+
104+
edge_label = {}
105+
for val in model.graph.value_info:
106+
itype = val.type.tensor_type.elem_type
107+
if itype == onnx.TensorProto.UNDEFINED:
108+
continue
109+
shape = tuple(
110+
d.dim_param if d.dim_param else d.dim_value for d in val.type.tensor_type.shape.dim
111+
)
112+
sshape = ",".join(
113+
map(
114+
str,
115+
[("?" if isinstance(s, str) and s.startswith("unk") else s) for s in shape],
116+
)
117+
)
118+
edge_label[val.name] = f"{onnx_dtype_name(itype)}({sshape})"
119+
120+
rows = [
121+
"digraph {",
122+
(
123+
" graph [rankdir=TB, splines=true, overlap=false, nodesep=0.2, "
124+
"ranksep=0.2, fontsize=8];"
125+
),
126+
' node [style="rounded,filled", color="#888888", fontcolor="#222222", shape=box];',
127+
" edge [arrowhead=vee, fontsize=6];",
128+
]
129+
inputs = list(model.graph.input)
130+
outputs = list(model.graph.output)
131+
nodes = list(model.graph.node)
132+
inits = list(model.graph.initializer)
133+
name_to_ids = {}
134+
for inp in inputs:
135+
if not inp.name:
136+
continue
137+
rows.append(f' I_{_mkn(inp)} [label="{inp.name}", fillcolor="#aaeeaa"];')
138+
name_to_ids[inp.name] = f"I_{_mkn(inp)}"
139+
for init in inits:
140+
rows.append(f' i_{_mkn(init)} [label="{init.name}", fillcolor="#cccc00"];')
141+
name_to_ids[init.name] = f"i_{_mkn(init)}"
142+
for node in nodes:
143+
color = op_type_colors.get(node.op_type, "#cccccc")
144+
label = _make_node_label(node)
145+
rows.append(f' {node.op_type}_{_mkn(node)} [label="{label}", fillcolor="{color}"];')
146+
name_to_ids.update({o: f"{node.op_type}_{_mkn(node)}" for o in node.output if o})
147+
148+
# nodes
149+
done = set()
150+
for node in nodes:
151+
names = list(node.input)
152+
for i in names:
153+
if not i:
154+
continue
155+
if i not in name_to_ids:
156+
raise ValueError(f"Unable to find {i!r}\n{pretty_onnx(model)}")
157+
edge = name_to_ids[i], f"{node.op_type}_{_mkn(node)}"
158+
if edge in done:
159+
continue
160+
done.add(edge)
161+
lab = edge_label.get(i, "")
162+
if lab:
163+
ls = ",".join([f'label="{lab}"'])
164+
lab = f" [{ls}]"
165+
rows.append(f" {edge[0]} -> {edge[1]}{lab};")
166+
if node.op_type in {"Scan", "Loop", "If"}:
167+
unique = set()
168+
for att in node.attribute:
169+
if att.type == onnx.AttributeProto.GRAPH:
170+
unique |= _get_hidden_inputs(att.g)
171+
for i in unique:
172+
edge = name_to_ids[i], _mkn(node)
173+
if edge in done:
174+
continue
175+
done.add(edge)
176+
rows.append(f" {edge[0]} -> {edge[1]} [style=dotted];")
177+
178+
# outputs
179+
for out in outputs:
180+
if not out.name:
181+
continue
182+
rows.append(f' O_{_mkn(out)} [label="{out.name}", fillcolor="#aaaaee"];')
183+
edge = name_to_ids[out.name], f"O_{_mkn(out)}"
184+
rows.append(f" {edge[0]} -> {edge[1]};")
185+
186+
rows.append("}")
187+
return "\n".join(rows)

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ select = [
172172
"_scripts/compare_model_execution.py" = ["E402", "F401"]
173173
"_doc/technical/plot_*.py" = ["E402", "B018", "PIE808", "RUF015", "SIM105", "SIM117"]
174174
"_unittests/*/test*.py" = ["B008", "B904", "PIE808", "SIM117", "SIM105", "UP008"]
175+
"_unittests/ut_helpers/test_dot_helper.py" = ["E501"]
175176
"_unittests/ut_tasks/try_export.py" = ["B008", "B904", "E501", "PIE808", "SIM117", "SIM105", "UP008"]
176177
"onnx_diagnostic/export/__init__.py" = ["F401"]
177178
"onnx_diagnostic/helpers/__init__.py" = ["F401"]

0 commit comments

Comments
 (0)