Skip to content

Commit 03a2587

Browse files
authored
improves dot rendering (#334)
* improves dot * improves dot rendering
1 parent 2bc3c09 commit 03a2587

File tree

2 files changed

+90
-2
lines changed

2 files changed

+90
-2
lines changed

_unittests/ut_helpers/test_dot_helper.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,63 @@ def test_custom_doc_kernels_layer_normalization(self):
6262
self.maxDiff = None
6363
self.assertEqual(expected.strip("\n "), dot.strip("\n "))
6464

65+
def test_custom_doc_kernels_layer_normalization_constant(self):
66+
TFLOAT16 = onnx.TensorProto.FLOAT16
67+
model = oh.make_model(
68+
oh.make_graph(
69+
[
70+
oh.make_node(
71+
"LayerNormalization",
72+
["X", "W", "B"],
73+
["ln"],
74+
axis=-1,
75+
epsilon=9.999999974752427e-7,
76+
),
77+
oh.make_node("Constant", [], ["cst"], value_float=[1]),
78+
oh.make_node("Cast", ["cst"], ["cst16"], to=onnx.TensorProto.FLOAT16),
79+
oh.make_node("Add", ["ln", "cst16"], ["Z"], axis=-1),
80+
],
81+
"dummy",
82+
[
83+
oh.make_tensor_value_info("X", TFLOAT16, ["b", "c", "d"]),
84+
oh.make_tensor_value_info("W", TFLOAT16, ["d"]),
85+
oh.make_tensor_value_info("B", TFLOAT16, ["d"]),
86+
],
87+
[oh.make_tensor_value_info("Z", TFLOAT16, ["b", "c", "d"])],
88+
),
89+
ir_version=9,
90+
opset_imports=[oh.make_opsetid("", 18)],
91+
)
92+
dot = to_dot(model)
93+
expected = (
94+
textwrap.dedent(
95+
"""
96+
digraph {
97+
graph [rankdir=TB, splines=true, overlap=false, nodesep=0.2, ranksep=0.2, fontsize=8];
98+
node [style="rounded,filled", color="#888888", fontcolor="#222222", shape=box];
99+
edge [arrowhead=vee, fontsize=7, labeldistance=-5, labelangle=0];
100+
I_0 [label="X\\nFLOAT16(b,c,d)", fillcolor="#aaeeaa"];
101+
I_1 [label="W\\nFLOAT16(d)", fillcolor="#aaeeaa"];
102+
I_2 [label="B\\nFLOAT16(d)", fillcolor="#aaeeaa"];
103+
LayerNormalization_3 [label="LayerNormalization(., ., ., axis=-1)", fillcolor="#cccccc"];
104+
Cast_4 [label="Cast([1.0], to=FLOAT16)", fillcolor="#cccccc"];
105+
Add_5[label="Add(.,.,axis=-1)",fillcolor="#cccccc"];
106+
I_0 -> LayerNormalization_3 [label="FLOAT16(b,c,d)"];
107+
I_1 -> LayerNormalization_3 [label="FLOAT16(d)"];
108+
I_2 -> LayerNormalization_3 [label="FLOAT16(d)"];
109+
LayerNormalization_3 -> Add_5 [label="FLOAT16(b,c,d)"];
110+
Cast_4->Add_5[label="FLOAT16()"];
111+
O_6 [label="Z\\nFLOAT16(b,c,d)", fillcolor="#aaaaee"];
112+
Add_5 -> O_6;
113+
}
114+
"""
115+
)
116+
.strip("\n")
117+
.replace(" ", "")
118+
)
119+
self.maxDiff = None
120+
self.assertEqual(expected, dot.strip("\n").replace(" ", ""))
121+
65122
@requires_transformers("4.57")
66123
def test_dot_plot_tiny(self):
67124
data = get_untrained_model_with_inputs("arnir0/Tiny-LLM")

onnx_diagnostic/helpers/dot_helper.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from typing import Dict, Set
2+
import numpy as np
23
import onnx
34
import onnx.numpy_helper as onh
5+
from ..reference import ExtendedReferenceEvaluator as Inference
46
from .onnx_helper import onnx_dtype_name, pretty_onnx
57

68

@@ -25,7 +27,7 @@ def _get_hidden_inputs(graph: onnx.GraphProto) -> Set[str]:
2527

2628

2729
def _make_node_label(node: onnx.NodeProto, tiny_inits: Dict[str, str]) -> str:
28-
els = [f"{node.domain}.\\n{node.op_type}" if node.domain else node.op_type, "("]
30+
els = [f"{node.domain}.\\n{node.op_type}" if node.domain else node.op_type, "\\n("]
2931
ee = [tiny_inits.get(i, ".") if i else "" for i in node.input]
3032
for att in node.attribute:
3133
if att.name == "to":
@@ -42,7 +44,10 @@ def _make_node_label(node: onnx.NodeProto, tiny_inits: Dict[str, str]) -> str:
4244
els.append(")")
4345
if node.op_type == "Constant":
4446
els.extend([" -> ", node.output[0]])
45-
return "".join(els)
47+
res = "".join(els)
48+
if len(res) < 40:
49+
return res.replace("\\n(", "(")
50+
return res
4651

4752

4853
def _make_edge_label(value_info: onnx.ValueInfoProto, multi_line: bool = False) -> str:
@@ -142,14 +147,37 @@ def _mkn(obj: object) -> int:
142147
inits = list(model.graph.initializer)
143148
tiny_inits = {}
144149
name_to_ids = {}
150+
145151
for inp in inputs:
146152
if not inp.name:
147153
continue
148154
lab = _make_edge_label(inp)
149155
rows.append(f' I_{_mkn(inp)} [label="{inp.name}\\n{lab}", fillcolor="#aaeeaa"];')
150156
name_to_ids[inp.name] = f"I_{_mkn(inp)}"
151157
edge_label[inp.name] = _make_edge_label(inp, multi_line=True)
158+
159+
# Small constant --> initializer
160+
for node in nodes:
161+
if node.op_type != "Constant":
162+
continue
163+
skip = False
164+
for att in node.attribute:
165+
if att.name == "value" and (
166+
len(att.t.dims) > 1 or np.prod(tuple(att.t.dims)) > 10
167+
):
168+
skip = True
169+
break
170+
if skip:
171+
continue
172+
173+
sess = Inference(node)
174+
value = sess.run(None, {})[0]
175+
inits.append(onh.from_array(value, name=node.output[0]))
176+
152177
for init in inits:
178+
if init.name in name_to_ids:
179+
# hide optional inputs
180+
continue
153181
shape = tuple(init.dims)
154182
if len(shape) == 0 or (len(shape) == 1 and shape[0] < 10):
155183
a = onh.to_array(init)
@@ -161,7 +189,10 @@ def _mkn(obj: object) -> int:
161189
rows.append(f' i_{_mkn(init)} [label="{init.name}\\n{ls}", fillcolor="#cccc00"];')
162190
name_to_ids[init.name] = f"i_{_mkn(init)}"
163191
edge_label[init.name] = ls
192+
164193
for node in nodes:
194+
if node.op_type == "Constant" and node.output[0] in tiny_inits:
195+
continue
165196
color = op_type_colors.get(node.op_type, "#cccccc")
166197
label = _make_node_label(node, tiny_inits)
167198
rows.append(f' {node.op_type}_{_mkn(node)} [label="{label}", fillcolor="{color}"];')

0 commit comments

Comments
 (0)