Skip to content

Commit 66f77c2

Browse files
authored
Improves graph rendering (#110)
* Adds input/output to the text rendering * mypy * improve graph rendering
1 parent 0a4d2c1 commit 66f77c2

File tree

4 files changed

+146
-50
lines changed

4 files changed

+146
-50
lines changed

_unittests/ut_helpers/test_graph_helper.py

Lines changed: 31 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import textwrap
22
import unittest
3+
import numpy as np
34
import onnx
45
import onnx.helper as oh
6+
import onnx.numpy_helper as onh
57
from onnx_diagnostic.ext_test_case import ExtTestCase
68
from onnx_diagnostic.helpers.graph_helper import GraphRendering
79

@@ -126,36 +128,35 @@ def test_text_rendering(self):
126128
)
127129
graph = GraphRendering(proto)
128130
text = textwrap.dedent(graph.text_rendering(prefix="|")).strip("\n")
129-
print()
130-
print(text)
131-
expected = textwrap.dedent(
132-
"""
133-
|
134-
|
135-
|
136-
|
137-
| X Y
138-
| | |
139-
| +------+-----+-----------+
140-
| | |
141-
| Add Neg
142-
| | |
143-
| +-----+-----------+
144-
| |
145-
| Mul
146-
| |
147-
| +-----------+
148-
| |
149-
| Mul
150-
| |
151-
| +----------------------+
152-
| |
153-
| Z
154-
|
155-
|
156-
"""
157-
).strip("\n")
158-
self.assertEqual(expected, text)
131+
self.assertIn("└-------┬---┼---┘", text)
132+
133+
def test_text_rendering_more(self):
134+
proto = oh.make_model(
135+
oh.make_graph(
136+
[
137+
oh.make_node("Add", ["X", "Y"], ["xy"]),
138+
oh.make_node("Neg", ["Y"], ["ny"]),
139+
oh.make_node("Mul", ["xy", "ny"], ["a"]),
140+
oh.make_node("Div", ["xy", "two"], ["b"]),
141+
oh.make_node("Add", ["b", "Y"], ["by"]),
142+
oh.make_node("Mod", ["a", "ny"], ["ay"]),
143+
oh.make_node("Sub", ["ay", "by"], ["Z"]),
144+
],
145+
"-nd-",
146+
[
147+
oh.make_tensor_value_info("X", TFLOAT, ["a", "b", "c"]),
148+
oh.make_tensor_value_info("Y", TFLOAT, ["a", "b", "c"]),
149+
],
150+
[oh.make_tensor_value_info("Z", TFLOAT, ["a", "b", "c"])],
151+
[onh.from_array(np.array([2], dtype=np.float32), name="two")],
152+
),
153+
opset_imports=[oh.make_opsetid("", 18)],
154+
ir_version=9,
155+
)
156+
onnx.checker.check_model(proto)
157+
graph = GraphRendering(proto)
158+
text = textwrap.dedent(graph.text_rendering(prefix="|")).strip("\n")
159+
self.assertIn(" └-------┬---┼---┴-------┐", text)
159160

160161

161162
if __name__ == "__main__":

_unittests/ut_xrun_doc/test_command_lines_exe.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,13 @@ def dummy_path(self):
1414
)
1515

1616
def test_parser_print(self):
17-
st = StringIO()
18-
with redirect_stdout(st):
19-
main(["print", "raw", self.dummy_path])
20-
text = st.getvalue()
21-
self.assertIn("Add", text)
17+
for fmt in ["raw", "text", "pretty", "printer"]:
18+
with self.subTest(format=fmt):
19+
st = StringIO()
20+
with redirect_stdout(st):
21+
main(["print", fmt, self.dummy_path])
22+
text = st.getvalue()
23+
self.assertIn("Add", text)
2224

2325
def test_parser_stats(self):
2426
output = self.get_dump_file("test_parser_stats.xlsx")

onnx_diagnostic/_command_lines_parser.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,23 @@ def get_parser_print() -> ArgumentParser:
126126
"""
127127
),
128128
epilog="To show a model.",
129+
formatter_class=RawTextHelpFormatter,
129130
)
130131
parser.add_argument(
131-
"fmt", choices=["pretty", "raw"], help="Format to use.", default="pretty"
132+
"fmt",
133+
choices=["pretty", "raw", "text", "printer"],
134+
default="pretty",
135+
help=textwrap.dedent(
136+
"""
137+
Prints out a model on the standard output.
138+
raw - just prints the model with print(...)
139+
printer - onnx.printer.to_text(...)
140+
pretty - an improved rendering
141+
text - uses GraphRendering
142+
""".strip(
143+
"\n"
144+
)
145+
),
132146
)
133147
parser.add_argument("input", type=str, help="onnx model to load")
134148
return parser
@@ -144,6 +158,12 @@ def _cmd_print(argv: List[Any]):
144158
from .helpers.onnx_helper import pretty_onnx
145159

146160
print(pretty_onnx(onx))
161+
elif args.fmt == "printer":
162+
print(onnx.printer.to_text(onx))
163+
elif args.fmt == "text":
164+
from .helpers.graph_helper import GraphRendering
165+
166+
print(GraphRendering(onx).text_rendering())
147167
else:
148168
raise ValueError(f"Unexpected value fmt={args.fmt!r}")
149169

onnx_diagnostic/helpers/graph_helper.py

Lines changed: 87 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import pprint
12
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
23
import onnx
34
import onnx.helper as oh
@@ -45,8 +46,12 @@ def computation_order(
4546
f"Missing input in node {i_node} type={node.op_type}: "
4647
f"{[i for i in node.input if i not in number]}"
4748
)
48-
mx = max(number[i] for i in node.input) + 1
49-
results[i_node] = mx
49+
if node.input:
50+
mx = max(number[i] for i in node.input) + 1
51+
results[i_node] = mx
52+
else:
53+
# A constant
54+
mx = max(number.values()) if number else 0
5055
for i in node.output:
5156
number[i] = mx
5257
return results
@@ -80,7 +85,7 @@ def graph_positions(
8085
indices = [i for i, o in enumerate(order) if o == row]
8186
assert indices, f"indices cannot be empty for row={row}, order={order}"
8287
ns = [nodes[i] for i in indices]
83-
mx = [max(names.get(i, 0) for i in n.input) for n in ns]
88+
mx = [(max(names.get(i, 0) for i in n.input) if n.input else 0) for n in ns]
8489
mix = [(m, i) for i, m in enumerate(mx)]
8590
mix.sort()
8691
for c, (_m, i) in enumerate(mix):
@@ -108,14 +113,18 @@ def text_positions(
108113
for i, (_row, col) in enumerate(new_positions):
109114
size = len(nodes[i].op_type) + 5
110115
column_size[col] = max(column_size[col], size)
116+
assert column_size[col] < 200, (
117+
f"column_size[{col}]={column_size[col]}, this is quite big, i={i}, "
118+
f"nodes[i].op_type={nodes[i].op_type}"
119+
)
111120

112121
# cumulated
113122
sort = sorted(column_size.items())
114123
cumul = dict(sort[:1])
115124
results = {sort[0][0]: sort[0][1] // 2}
116125
for col, size in sort[1:]:
117-
c = sum(cumul.values())
118-
cumul[col] = c
126+
c = max(cumul.values())
127+
cumul[col] = c + size
119128
results[col] = c + size // 2
120129
return [(row, results[col]) for row, col in new_positions]
121130

@@ -190,6 +199,49 @@ def build_node_edges(cls, nodes: Sequence[onnx.NodeProto]) -> Set[Tuple[int, int
190199
edges.add(edge)
191200
return edges
192201

202+
ADD_RULES = {
203+
("┴", "┘"): "┴",
204+
("┴", "└"): "┴",
205+
("┬", "┐"): "┬",
206+
("┬", "┌"): "┬",
207+
("-", "└"): "┴",
208+
("-", "|"): "┼",
209+
("-", "┐"): "┬",
210+
("┐", "-"): "┬",
211+
("┘", "-"): "┴",
212+
("┴", "-"): "┴",
213+
("-", "┘"): "┴",
214+
("┌", "-"): "┬",
215+
("┬", "-"): "┬",
216+
("-", "┌"): "┬",
217+
("|", "-"): "┼",
218+
("└", "-"): "┴",
219+
("|", "└"): "├",
220+
("|", "┘"): "┤",
221+
("┐", "|"): "┤",
222+
("┬", "|"): "┼",
223+
("|", "┐"): "┤",
224+
("|", "┌"): "├",
225+
("├", "-"): "┼",
226+
("└", "|"): "├",
227+
("┤", "┐"): "┤",
228+
("┤", "|"): "┤",
229+
("├", "|"): "├",
230+
("┴", "┌"): "┼",
231+
("┐", "┌"): "┬",
232+
("┌", "┐"): "┬",
233+
("┌", "|"): "┼",
234+
("┴", "┐"): "┼",
235+
("┐", "└"): "┼",
236+
("┬", "┘"): "┼",
237+
("├", "└"): "├",
238+
("┤", "┌"): "┼",
239+
("┘", "|"): "┤",
240+
("┴", "|"): "┼",
241+
("┤", "-"): "┼",
242+
("┘", "└"): "┴",
243+
}
244+
193245
@classmethod
194246
def text_grid(cls, grid: List[List[str]], position: Tuple[int, int], text: str):
195247
"""
@@ -230,24 +282,44 @@ def text_edge(
230282
assert (
231283
0 <= p2[1] < min(len(g) for g in grid)
232284
), f"p2={p2}, sizes={[len(g) for g in grid]}"
233-
grid[p1[0] + 1][p1[1]] = "|"
234-
grid[p1[0] + 2][p1[1]] = "+"
285+
286+
def add(s1, s2):
287+
assert s2 != " ", f"s1={s1!r}, s2={s2!r}"
288+
if s1 == " " or s1 == s2:
289+
return s2
290+
if s1 == "┼" or s2 == "┼":
291+
return "┼"
292+
if (s1, s2) in cls.ADD_RULES:
293+
return cls.ADD_RULES[s1, s2]
294+
raise NotImplementedError(f"Unable to add: ({s1!r},{s2!r}): '',")
295+
296+
def place(grid, x, y, symbol):
297+
grid[x][y] = add(grid[x][y], symbol)
298+
299+
place(grid, p1[0] + 1, p1[1], "|")
300+
place(grid, p1[0] + 2, p1[1], "└" if p1[1] < p2[1] else "┘")
235301

236302
if p1[0] + 2 == p2[0] - 2:
237303
a, b = (p1[1] + 1, p2[1] - 1) if p1[1] < p2[1] else (p2[1] + 1, p1[1] - 1)
238-
grid[p1[0] + 2][a : b + 1] = ["-"] * (b - a + 1)
304+
for i in range(a, b + 1):
305+
place(grid, p1[0] + 2, i, "-")
239306
else:
240307
middle = (p1[1] + p2[1]) // 2
241308
a, b = (p1[1] + 1, middle - 1) if p1[1] < middle else (middle + 1, p1[1] - 1)
242-
grid[p1[0] + 2][a : b + 1] = ["-"] * (b - a + 1)
309+
for i in range(a, b + 1):
310+
place(grid, p1[0] + 2, i, "-")
243311
a, b = (p1[1] + 1, middle - 1) if p1[1] < middle else (middle + 1, p1[1] - 1)
244-
grid[p1[0] + 2][a : b + 1] = ["-"] * (b - a + 1)
312+
for i in range(a, b + 1):
313+
place(grid, p1[0] + 2, i, "-")
314+
315+
place(grid, p1[0] + 2, middle, "┐" if p1[1] < p2[1] else "┌")
316+
place(grid, p2[0] - 2, middle, "└" if p1[1] < p2[1] else "┘")
245317

246-
grid[p1[0] + 2][middle] = "+"
247-
grid[p2[0] - 2][middle] = "+"
318+
for i in range(p1[0] + 2 + 1, p2[0] - 2):
319+
place(grid, i, middle, "|")
248320

249-
grid[p2[0] - 2][p2[1]] = "+"
250-
grid[p2[0] - 1][p2[1]] = "|"
321+
place(grid, p2[0] - 2, p2[1], "┐" if p1[1] < p2[1] else "┌")
322+
place(grid, p2[0] - 1, p2[1], "|")
251323

252324
def text_rendering(self, prefix="") -> str:
253325
"""
@@ -298,6 +370,7 @@ def text_rendering(self, prefix="") -> str:
298370
text_pos = self.text_positions(nodes, positions)
299371
edges = self.build_node_edges(nodes)
300372
max_len = max(col for _, col in text_pos) + max(len(n.op_type) for n in nodes)
373+
assert max_len < 1e6, f"max_len={max_len}, text_pos=\n{pprint.pformat(text_pos)}"
301374
max_row = max(row for row, _ in text_pos) + 2
302375
grid = [[" " for i in range(max_len + 1)] for _ in range(max_row + 1)]
303376

0 commit comments

Comments
 (0)