Skip to content

Commit c0e2e46

Browse files
committed
improves dot rendering
1 parent c6ddf5d commit c0e2e46

File tree

2 files changed

+64
-4
lines changed

2 files changed

+64
-4
lines changed

_unittests/ut_helpers/test_dot_helper.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,63 @@ def test_custom_doc_kernels_layer_normalization(self):
6262
self.maxDiff = None
6363
self.assertEqual(expected.strip("\n "), dot.strip("\n "))
6464

65+
def test_custom_doc_kernels_layer_normalization_constant(self):
66+
TFLOAT16 = onnx.TensorProto.FLOAT16
67+
model = oh.make_model(
68+
oh.make_graph(
69+
[
70+
oh.make_node(
71+
"LayerNormalization",
72+
["X", "W", "B"],
73+
["ln"],
74+
axis=-1,
75+
epsilon=9.999999974752427e-7,
76+
),
77+
oh.make_node("Constant", [], ["cst"], value_float=[1]),
78+
oh.make_node("Cast", ["cst"], ["cst16"], to=onnx.TensorProto.FLOAT16),
79+
oh.make_node("Add", ["ln", "cst16"], ["Z"], axis=-1),
80+
],
81+
"dummy",
82+
[
83+
oh.make_tensor_value_info("X", TFLOAT16, ["b", "c", "d"]),
84+
oh.make_tensor_value_info("W", TFLOAT16, ["d"]),
85+
oh.make_tensor_value_info("B", TFLOAT16, ["d"]),
86+
],
87+
[oh.make_tensor_value_info("Z", TFLOAT16, ["b", "c", "d"])],
88+
),
89+
ir_version=9,
90+
opset_imports=[oh.make_opsetid("", 18)],
91+
)
92+
dot = to_dot(model)
93+
expected = (
94+
textwrap.dedent(
95+
"""
96+
digraph {
97+
graph [rankdir=TB, splines=true, overlap=false, nodesep=0.2, ranksep=0.2, fontsize=8];
98+
node [style="rounded,filled", color="#888888", fontcolor="#222222", shape=box];
99+
edge [arrowhead=vee, fontsize=7, labeldistance=-5, labelangle=0];
100+
I_0 [label="X\\nFLOAT16(b,c,d)", fillcolor="#aaeeaa"];
101+
I_1 [label="W\\nFLOAT16(d)", fillcolor="#aaeeaa"];
102+
I_2 [label="B\\nFLOAT16(d)", fillcolor="#aaeeaa"];
103+
LayerNormalization_3 [label="LayerNormalization(., ., ., axis=-1)", fillcolor="#cccccc"];
104+
Cast_4 [label="Cast([1.0], to=FLOAT16)", fillcolor="#cccccc"];
105+
Add_5[label="Add(.,.,axis=-1)",fillcolor="#cccccc"];
106+
I_0 -> LayerNormalization_3 [label="FLOAT16(b,c,d)"];
107+
I_1 -> LayerNormalization_3 [label="FLOAT16(d)"];
108+
I_2 -> LayerNormalization_3 [label="FLOAT16(d)"];
109+
LayerNormalization_3 -> Add_5 [label="FLOAT16(b,c,d)"];
110+
Cast_4->Add_5[label="FLOAT16()"];
111+
O_6 [label="Z\\nFLOAT16(b,c,d)", fillcolor="#aaaaee"];
112+
Add_5 -> O_6;
113+
}
114+
"""
115+
)
116+
.strip("\n")
117+
.replace(" ", "")
118+
)
119+
self.maxDiff = None
120+
self.assertEqual(expected, dot.strip("\n").replace(" ", ""))
121+
65122
@requires_transformers("4.57")
66123
def test_dot_plot_tiny(self):
67124
data = get_untrained_model_with_inputs("arnir0/Tiny-LLM")

onnx_diagnostic/helpers/dot_helper.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def _get_hidden_inputs(graph: onnx.GraphProto) -> Set[str]:
2727

2828

2929
def _make_node_label(node: onnx.NodeProto, tiny_inits: Dict[str, str]) -> str:
30-
els = [f"{node.domain}.\\n{node.op_type}" if node.domain else node.op_type, "("]
30+
els = [f"{node.domain}.\\n{node.op_type}" if node.domain else node.op_type, "\\n("]
3131
ee = [tiny_inits.get(i, ".") if i else "" for i in node.input]
3232
for att in node.attribute:
3333
if att.name == "to":
@@ -44,7 +44,10 @@ def _make_node_label(node: onnx.NodeProto, tiny_inits: Dict[str, str]) -> str:
4444
els.append(")")
4545
if node.op_type == "Constant":
4646
els.extend([" -> ", node.output[0]])
47-
return "".join(els)
47+
res = "".join(els)
48+
if len(res) < 40:
49+
return res.replace("\\n(", "(")
50+
return res
4851

4952

5053
def _make_edge_label(value_info: onnx.ValueInfoProto, multi_line: bool = False) -> str:
@@ -172,7 +175,7 @@ def _mkn(obj: object) -> int:
172175
inits.append(onh.from_array(value, name=node.output[0]))
173176

174177
for init in inits:
175-
if init.name in inputs:
178+
if init.name in name_to_ids:
176179
# hide optional inputs
177180
continue
178181
shape = tuple(init.dims)
@@ -188,7 +191,7 @@ def _mkn(obj: object) -> int:
188191
edge_label[init.name] = ls
189192

190193
for node in nodes:
191-
if node.op_type == "Constant" and node.output[0] in name_to_ids:
194+
if node.op_type == "Constant" and node.output[0] in tiny_inits:
192195
continue
193196
color = op_type_colors.get(node.op_type, "#cccccc")
194197
label = _make_node_label(node, tiny_inits)

0 commit comments

Comments
 (0)