Skip to content

Commit 4d9b611

Browse files
committed
rendering
1 parent f909dd6 commit 4d9b611

File tree

5 files changed

+431
-1
lines changed

5 files changed

+431
-1
lines changed

_doc/api/helpers/graph_helper.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
2+
onnx_diagnostic.helpers.graph_helper
3+
====================================
4+
5+
.. automodule:: onnx_diagnostic.helpers.graph_helper
6+
:members:
7+
:no-undoc-members:

_doc/api/helpers/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ onnx_diagnostic.helpers
1010
bench_run
1111
cache_helper
1212
config_helper
13+
graph_helper
1314
helper
1415
memory_peak
1516
mini_onnx_builder
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
import textwrap
2+
import unittest
3+
import onnx
4+
import onnx.helper as oh
5+
from onnx_diagnostic.ext_test_case import ExtTestCase
6+
from onnx_diagnostic.helpers.graph_helper import GraphRendering
7+
8+
TFLOAT = onnx.TensorProto.FLOAT
9+
10+
11+
class TestGraphHelper(ExtTestCase):
12+
def test_computation_order(self):
13+
proto = oh.make_model(
14+
oh.make_graph(
15+
[
16+
oh.make_node("Sigmoid", ["Y"], ["sy"]),
17+
oh.make_node("Mul", ["Y", "sy"], ["ysy"]),
18+
oh.make_node("Mul", ["X", "ysy"], ["final"]),
19+
],
20+
"-nd-",
21+
[
22+
oh.make_tensor_value_info("X", TFLOAT, [1, "b", "c"]),
23+
oh.make_tensor_value_info("Y", TFLOAT, ["a", "b", "c"]),
24+
],
25+
[oh.make_tensor_value_info("final", TFLOAT, ["a", "b", "c"])],
26+
),
27+
opset_imports=[oh.make_opsetid("", 18)],
28+
ir_version=9,
29+
)
30+
order = GraphRendering.computation_order(
31+
proto.graph.node, [i.name for i in [*proto.graph.input, *proto.graph.initializer]]
32+
)
33+
self.assertEqual([1, 2, 3], order)
34+
35+
def test_graph_positions1(self):
36+
proto = oh.make_model(
37+
oh.make_graph(
38+
[
39+
oh.make_node("Sigmoid", ["Y"], ["sy"]),
40+
oh.make_node("Mul", ["Y", "sy"], ["ysy"]),
41+
oh.make_node("Mul", ["X", "ysy"], ["final"]),
42+
],
43+
"-nd-",
44+
[
45+
oh.make_tensor_value_info("X", TFLOAT, [1, "b", "c"]),
46+
oh.make_tensor_value_info("Y", TFLOAT, ["a", "b", "c"]),
47+
],
48+
[oh.make_tensor_value_info("final", TFLOAT, ["a", "b", "c"])],
49+
),
50+
opset_imports=[oh.make_opsetid("", 18)],
51+
ir_version=9,
52+
)
53+
existing = [i.name for i in [*proto.graph.input, *proto.graph.initializer]]
54+
order = GraphRendering.computation_order(proto.graph.node, existing)
55+
positions = GraphRendering.graph_positions(proto.graph.node, order, existing)
56+
self.assertEqual([(1, 0), (2, 0), (3, 0)], positions)
57+
58+
def test_graph_positions2(self):
59+
proto = oh.make_model(
60+
oh.make_graph(
61+
[
62+
oh.make_node("Add", ["X", "Y"], ["xy"]),
63+
oh.make_node("Neg", ["Y"], ["ny"]),
64+
oh.make_node("Mul", ["xy", "ny"], ["Z"]),
65+
],
66+
"-nd-",
67+
[
68+
oh.make_tensor_value_info("X", TFLOAT, ["a", "b", "c"]),
69+
oh.make_tensor_value_info("Y", TFLOAT, ["a", "b", "c"]),
70+
],
71+
[oh.make_tensor_value_info("Z", TFLOAT, ["a", "b", "c"])],
72+
),
73+
opset_imports=[oh.make_opsetid("", 18)],
74+
ir_version=9,
75+
)
76+
existing = [i.name for i in [*proto.graph.input, *proto.graph.initializer]]
77+
order = GraphRendering.computation_order(proto.graph.node, existing)
78+
positions = GraphRendering.graph_positions(proto.graph.node, order, existing)
79+
self.assertEqual([(1, 0), (1, 1), (2, 0)], positions)
80+
81+
def test_text_positionss(self):
82+
proto = oh.make_model(
83+
oh.make_graph(
84+
[
85+
oh.make_node("Add", ["X", "Y"], ["xy"]),
86+
oh.make_node("Neg", ["Y"], ["ny"]),
87+
oh.make_node("Mul", ["xy", "ny"], ["a"]),
88+
oh.make_node("Mul", ["a", "Y"], ["Z"]),
89+
],
90+
"-nd-",
91+
[
92+
oh.make_tensor_value_info("X", TFLOAT, ["a", "b", "c"]),
93+
oh.make_tensor_value_info("Y", TFLOAT, ["a", "b", "c"]),
94+
],
95+
[oh.make_tensor_value_info("Z", TFLOAT, ["a", "b", "c"])],
96+
),
97+
opset_imports=[oh.make_opsetid("", 18)],
98+
ir_version=9,
99+
)
100+
existing = [i.name for i in [*proto.graph.input, *proto.graph.initializer]]
101+
order = GraphRendering.computation_order(proto.graph.node, existing)
102+
self.assertEqual([1, 1, 2, 3], order)
103+
positions = GraphRendering.graph_positions(proto.graph.node, order, existing)
104+
self.assertEqual([(1, 0), (1, 1), (2, 0), (3, 0)], positions)
105+
text_pos = GraphRendering.text_positions(proto.graph.node, positions)
106+
self.assertEqual([(4, 4), (4, 20), (8, 12), (12, 20)], text_pos)
107+
108+
def test_text_rendering(self):
109+
proto = oh.make_model(
110+
oh.make_graph(
111+
[
112+
oh.make_node("Add", ["X", "Y"], ["xy"]),
113+
oh.make_node("Neg", ["Y"], ["ny"]),
114+
oh.make_node("Mul", ["xy", "ny"], ["a"]),
115+
oh.make_node("Mul", ["a", "Y"], ["Z"]),
116+
],
117+
"-nd-",
118+
[
119+
oh.make_tensor_value_info("X", TFLOAT, ["a", "b", "c"]),
120+
oh.make_tensor_value_info("Y", TFLOAT, ["a", "b", "c"]),
121+
],
122+
[oh.make_tensor_value_info("Z", TFLOAT, ["a", "b", "c"])],
123+
),
124+
opset_imports=[oh.make_opsetid("", 18)],
125+
ir_version=9,
126+
)
127+
graph = GraphRendering(proto)
128+
text = textwrap.dedent(graph.text_rendering(prefix="|")).strip("\n")
129+
expected = textwrap.dedent(
130+
"""
131+
|
132+
|
133+
|
134+
|
135+
| Add Neg
136+
| | |
137+
| +-------+-------+
138+
| |
139+
| Mul
140+
| |
141+
| +-------+
142+
| |
143+
| Mul
144+
|
145+
|
146+
"""
147+
).strip("\n")
148+
self.assertEqual(expected, text)
149+
150+
151+
if __name__ == "__main__":
152+
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)