Skip to content

Commit f4fea2c

Browse files
committed
improve rendering
1 parent de538cc commit f4fea2c

File tree

2 files changed

+21
-14
lines changed

2 files changed

+21
-14
lines changed

_unittests/ut_helpers/test_dot_helper.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import unittest
33
import onnx
44
import onnx.helper as oh
5-
from onnx_diagnostic.ext_test_case import ExtTestCase
5+
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers
66
from onnx_diagnostic.helpers.dot_helper import to_dot
77
from onnx_diagnostic.export.api import to_onnx
88
from onnx_diagnostic.torch_export_patches import torch_export_patches
@@ -62,6 +62,7 @@ def test_custom_doc_kernels_layer_normalization(self):
6262
self.maxDiff = None
6363
self.assertEqual(expected.strip("\n "), dot.strip("\n "))
6464

65+
@requires_transformers("4.57")
6566
def test_dot_plot_tiny(self):
6667
data = get_untrained_model_with_inputs("arnir0/Tiny-LLM")
6768
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]

onnx_diagnostic/helpers/dot_helper.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@ def _get_hidden_inputs(graph: onnx.GraphProto) -> Set[str]:
2424
return hidden
2525

2626

27-
def _make_node_label(node: onnx.NodeProto) -> str:
27+
def _make_node_label(node: onnx.NodeProto, tiny_inits: Dict[str, str]) -> str:
2828
els = [f"{node.domain}.\\n{node.op_type}" if node.domain else node.op_type, "("]
29-
ee = ["." if i else "" for i in node.input]
29+
ee = [tiny_inits.get(i, ".") if i else "" for i in node.input]
3030
for att in node.attribute:
3131
if att.name == "to":
3232
ee.append(f"{att.name}={onnx_dtype_name(att.i)}")
33-
elif att.name in {"to", "axis", "value_int", "stash_type"}:
33+
elif att.name in {"to", "axis", "value_int", "stash_type", "start", "end"}:
3434
ee.append(f"{att.name}={att.i}")
3535
elif att.name in {"value_float"}:
3636
ee.append(f"{att.name}={att.f}")
@@ -115,9 +115,12 @@ def _mkn(obj: object) -> int:
115115
model = onnx.shape_inference.infer_shapes(model)
116116

117117
op_type_colors = {
118-
"Shape": "#eeeeee",
118+
"Shape": "#d2a81f",
119119
"MatMul": "#ee9999",
120120
"Transpose": "#ee99ee",
121+
"Reshape": "#eeeeee",
122+
"Squeeze": "#eeeeee",
123+
"Unsqueeze": "#eeeeee",
121124
}
122125

123126
edge_label = {}
@@ -137,6 +140,7 @@ def _mkn(obj: object) -> int:
137140
outputs = list(model.graph.output)
138141
nodes = list(model.graph.node)
139142
inits = list(model.graph.initializer)
143+
tiny_inits = {}
140144
name_to_ids = {}
141145
for inp in inputs:
142146
if not inp.name:
@@ -150,17 +154,19 @@ def _mkn(obj: object) -> int:
150154
if len(shape) == 0 or (len(shape) == 1 and shape[0] < 10):
151155
a = onh.to_array(init)
152156
vals = f" = {a}" if len(shape) == 0 else f"\\n=[{', '.join([str(i) for i in a])}]"
157+
tiny_inits[init.name] = (
158+
str(a) if len(shape) == 0 else f"[{', '.join([str(i) for i in a])}]"
159+
)
153160
else:
154-
vals = ""
155-
ls = f"{onnx_dtype_name(init.data_type)}({', '.join(map(str,shape))})"
156-
rows.append(
157-
f' i_{_mkn(init)} [label="{init.name}\\n{ls}{vals}", fillcolor="#cccc00"];'
158-
)
159-
name_to_ids[init.name] = f"i_{_mkn(init)}"
160-
edge_label[init.name] = ls
161+
ls = f"{onnx_dtype_name(init.data_type)}({', '.join(map(str,shape))})"
162+
rows.append(
163+
f' i_{_mkn(init)} [label="{init.name}\\n{ls}{vals}", fillcolor="#cccc00"];'
164+
)
165+
name_to_ids[init.name] = f"i_{_mkn(init)}"
166+
edge_label[init.name] = ls
161167
for node in nodes:
162168
color = op_type_colors.get(node.op_type, "#cccccc")
163-
label = _make_node_label(node)
169+
label = _make_node_label(node, tiny_inits)
164170
rows.append(f' {node.op_type}_{_mkn(node)} [label="{label}", fillcolor="{color}"];')
165171
name_to_ids.update({o: f"{node.op_type}_{_mkn(node)}" for o in node.output if o})
166172

@@ -169,7 +175,7 @@ def _mkn(obj: object) -> int:
169175
for node in nodes:
170176
names = list(node.input)
171177
for i in names:
172-
if not i:
178+
if not i or i in tiny_inits:
173179
continue
174180
if i not in name_to_ids:
175181
raise ValueError(f"Unable to find {i!r}\n{pretty_onnx(model)}")

0 commit comments

Comments
 (0)