Skip to content

Commit de538cc

Browse files
committed
improve dot
1 parent 70526e6 commit de538cc

File tree

2 files changed

+46
-26
lines changed

2 files changed

+46
-26
lines changed

_unittests/ut_helpers/test_dot_helper.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,18 +43,18 @@ def test_custom_doc_kernels_layer_normalization(self):
4343
digraph {
4444
graph [rankdir=TB, splines=true, overlap=false, nodesep=0.2, ranksep=0.2, fontsize=8];
4545
node [style="rounded,filled", color="#888888", fontcolor="#222222", shape=box];
46-
edge [arrowhead=vee, fontsize=6];
47-
I_0 [label="X", fillcolor="#aaeeaa"];
48-
I_1 [label="W", fillcolor="#aaeeaa"];
49-
I_2 [label="B", fillcolor="#aaeeaa"];
46+
edge [arrowhead=vee, fontsize=7, labeldistance=-5, labelangle=0];
47+
I_0 [label="X\\nFLOAT16(b,c,d)", fillcolor="#aaeeaa"];
48+
I_1 [label="W\\nFLOAT16(d)", fillcolor="#aaeeaa"];
49+
I_2 [label="B\\nFLOAT16(d)", fillcolor="#aaeeaa"];
5050
LayerNormalization_3 [label="LayerNormalization(., ., ., axis=-1)", fillcolor="#cccccc"];
5151
Add_4 [label="Add(., ., axis=-1)", fillcolor="#cccccc"];
52-
I_0 -> LayerNormalization_3;
53-
I_1 -> LayerNormalization_3;
54-
I_2 -> LayerNormalization_3;
52+
I_0 -> LayerNormalization_3 [label="FLOAT16(b,c,d)"];
53+
I_1 -> LayerNormalization_3 [label="FLOAT16(d)"];
54+
I_2 -> LayerNormalization_3 [label="FLOAT16(d)"];
5555
LayerNormalization_3 -> Add_4 [label="FLOAT16(b,c,d)"];
56-
I_1 -> Add_4;
57-
O_5 [label="Z", fillcolor="#aaaaee"];
56+
I_1 -> Add_4 [label="FLOAT16(d)"];
57+
O_5 [label="Z\\nFLOAT16(d)", fillcolor="#aaaaee"];
5858
Add_4 -> O_5;
5959
}
6060
"""

onnx_diagnostic/helpers/dot_helper.py

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Dict, Set
22
import onnx
3+
import onnx.numpy_helper as onh
34
from .onnx_helper import onnx_dtype_name, pretty_onnx
45

56

@@ -44,6 +45,24 @@ def _make_node_label(node: onnx.NodeProto) -> str:
4445
return "".join(els)
4546

4647

48+
def _make_edge_label(value_info: onnx.ValueInfoProto, multi_line: bool = False) -> str:
49+
itype = value_info.type.tensor_type.elem_type
50+
if itype == onnx.TensorProto.UNDEFINED:
51+
return ""
52+
shape = tuple(
53+
d.dim_param if d.dim_param else d.dim_value
54+
for d in value_info.type.tensor_type.shape.dim
55+
)
56+
res = [
57+
str(a)
58+
for a in [("?" if isinstance(s, str) and s.startswith("unk") else s) for s in shape]
59+
]
60+
sshape = ",".join(res)
61+
if multi_line and len(sshape) > 30:
62+
sshape = ",\\n".join(res)
63+
return f"{onnx_dtype_name(itype)}({sshape})"
64+
65+
4766
def to_dot(model: onnx.ModelProto) -> str:
4867
"""
4968
Converts a model into a dot graph.
@@ -103,19 +122,7 @@ def _mkn(obj: object) -> int:
103122

104123
edge_label = {}
105124
for val in model.graph.value_info:
106-
itype = val.type.tensor_type.elem_type
107-
if itype == onnx.TensorProto.UNDEFINED:
108-
continue
109-
shape = tuple(
110-
d.dim_param if d.dim_param else d.dim_value for d in val.type.tensor_type.shape.dim
111-
)
112-
sshape = ",".join(
113-
map(
114-
str,
115-
[("?" if isinstance(s, str) and s.startswith("unk") else s) for s in shape],
116-
)
117-
)
118-
edge_label[val.name] = f"{onnx_dtype_name(itype)}({sshape})"
125+
edge_label[val.name] = _make_edge_label(val, multi_line=True)
119126

120127
rows = [
121128
"digraph {",
@@ -124,7 +131,7 @@ def _mkn(obj: object) -> int:
124131
"ranksep=0.2, fontsize=8];"
125132
),
126133
' node [style="rounded,filled", color="#888888", fontcolor="#222222", shape=box];',
127-
" edge [arrowhead=vee, fontsize=6];",
134+
" edge [arrowhead=vee, fontsize=7, labeldistance=-5, labelangle=0];",
128135
]
129136
inputs = list(model.graph.input)
130137
outputs = list(model.graph.output)
@@ -134,11 +141,23 @@ def _mkn(obj: object) -> int:
134141
for inp in inputs:
135142
if not inp.name:
136143
continue
137-
rows.append(f' I_{_mkn(inp)} [label="{inp.name}", fillcolor="#aaeeaa"];')
144+
lab = _make_edge_label(inp)
145+
rows.append(f' I_{_mkn(inp)} [label="{inp.name}\\n{lab}", fillcolor="#aaeeaa"];')
138146
name_to_ids[inp.name] = f"I_{_mkn(inp)}"
147+
edge_label[inp.name] = _make_edge_label(inp, multi_line=True)
139148
for init in inits:
140-
rows.append(f' i_{_mkn(init)} [label="{init.name}", fillcolor="#cccc00"];')
149+
shape = tuple(init.dims)
150+
if len(shape) == 0 or (len(shape) == 1 and shape[0] < 10):
151+
a = onh.to_array(init)
152+
vals = f" = {a}" if len(shape) == 0 else f"\\n=[{', '.join([str(i) for i in a])}]"
153+
else:
154+
vals = ""
155+
ls = f"{onnx_dtype_name(init.data_type)}({', '.join(map(str,shape))})"
156+
rows.append(
157+
f' i_{_mkn(init)} [label="{init.name}\\n{ls}{vals}", fillcolor="#cccc00"];'
158+
)
141159
name_to_ids[init.name] = f"i_{_mkn(init)}"
160+
edge_label[init.name] = ls
142161
for node in nodes:
143162
color = op_type_colors.get(node.op_type, "#cccccc")
144163
label = _make_node_label(node)
@@ -179,7 +198,8 @@ def _mkn(obj: object) -> int:
179198
for out in outputs:
180199
if not out.name:
181200
continue
182-
rows.append(f' O_{_mkn(out)} [label="{out.name}", fillcolor="#aaaaee"];')
201+
lab = _make_edge_label(inp)
202+
rows.append(f' O_{_mkn(out)} [label="{out.name}\\n{lab}", fillcolor="#aaaaee"];')
183203
edge = name_to_ids[out.name], f"O_{_mkn(out)}"
184204
rows.append(f" {edge[0]} -> {edge[1]};")
185205

0 commit comments

Comments
 (0)