Skip to content

Commit c6ddf5d

Browse files
committed
improves dot
1 parent 2bc3c09 commit c6ddf5d

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

onnx_diagnostic/helpers/dot_helper.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from typing import Dict, Set
2+
import numpy as np
23
import onnx
34
import onnx.numpy_helper as onh
5+
from ..reference import ExtendedReferenceEvaluator as Inference
46
from .onnx_helper import onnx_dtype_name, pretty_onnx
57

68

@@ -142,14 +144,37 @@ def _mkn(obj: object) -> int:
142144
inits = list(model.graph.initializer)
143145
tiny_inits = {}
144146
name_to_ids = {}
147+
145148
for inp in inputs:
146149
if not inp.name:
147150
continue
148151
lab = _make_edge_label(inp)
149152
rows.append(f' I_{_mkn(inp)} [label="{inp.name}\\n{lab}", fillcolor="#aaeeaa"];')
150153
name_to_ids[inp.name] = f"I_{_mkn(inp)}"
151154
edge_label[inp.name] = _make_edge_label(inp, multi_line=True)
155+
156+
# Small constant --> initializer
157+
for node in nodes:
158+
if node.op_type != "Constant":
159+
continue
160+
skip = False
161+
for att in node.attribute:
162+
if att.name == "value" and (
163+
len(att.t.dims) > 1 or np.prod(tuple(att.t.dims)) > 10
164+
):
165+
skip = True
166+
break
167+
if skip:
168+
continue
169+
170+
sess = Inference(node)
171+
value = sess.run(None, {})[0]
172+
inits.append(onh.from_array(value, name=node.output[0]))
173+
152174
for init in inits:
175+
if init.name in inputs:
176+
# hide optional inputs
177+
continue
153178
shape = tuple(init.dims)
154179
if len(shape) == 0 or (len(shape) == 1 and shape[0] < 10):
155180
a = onh.to_array(init)
@@ -161,7 +186,10 @@ def _mkn(obj: object) -> int:
161186
rows.append(f' i_{_mkn(init)} [label="{init.name}\\n{ls}", fillcolor="#cccc00"];')
162187
name_to_ids[init.name] = f"i_{_mkn(init)}"
163188
edge_label[init.name] = ls
189+
164190
for node in nodes:
191+
if node.op_type == "Constant" and node.output[0] in name_to_ids:
192+
continue
165193
color = op_type_colors.get(node.op_type, "#cccccc")
166194
label = _make_node_label(node, tiny_inits)
167195
rows.append(f' {node.op_type}_{_mkn(node)} [label="{label}", fillcolor="{color}"];')

0 commit comments

Comments
 (0)