|
| 1 | +import pprint |
1 | 2 | from typing import Dict, List, Optional, Sequence, Set, Tuple, Union |
2 | 3 | import onnx |
3 | 4 | import onnx.helper as oh |
@@ -45,8 +46,12 @@ def computation_order( |
45 | 46 | f"Missing input in node {i_node} type={node.op_type}: " |
46 | 47 | f"{[i for i in node.input if i not in number]}" |
47 | 48 | ) |
48 | | - mx = max(number[i] for i in node.input) + 1 |
49 | | - results[i_node] = mx |
| 49 | + if node.input: |
| 50 | + mx = max(number[i] for i in node.input) + 1 |
| 51 | + results[i_node] = mx |
| 52 | + else: |
| 53 | + # A constant |
| 54 | + mx = max(number.values()) if number else 0 |
50 | 55 | for i in node.output: |
51 | 56 | number[i] = mx |
52 | 57 | return results |
@@ -80,7 +85,7 @@ def graph_positions( |
80 | 85 | indices = [i for i, o in enumerate(order) if o == row] |
81 | 86 | assert indices, f"indices cannot be empty for row={row}, order={order}" |
82 | 87 | ns = [nodes[i] for i in indices] |
83 | | - mx = [max(names.get(i, 0) for i in n.input) for n in ns] |
| 88 | + mx = [(max(names.get(i, 0) for i in n.input) if n.input else 0) for n in ns] |
84 | 89 | mix = [(m, i) for i, m in enumerate(mx)] |
85 | 90 | mix.sort() |
86 | 91 | for c, (_m, i) in enumerate(mix): |
@@ -108,14 +113,18 @@ def text_positions( |
108 | 113 | for i, (_row, col) in enumerate(new_positions): |
109 | 114 | size = len(nodes[i].op_type) + 5 |
110 | 115 | column_size[col] = max(column_size[col], size) |
| 116 | + assert column_size[col] < 200, ( |
| 117 | + f"column_size[{col}]={column_size[col]}, this is quite big, i={i}, " |
| 118 | + f"nodes[i].op_type={nodes[i].op_type}" |
| 119 | + ) |
111 | 120 |
|
112 | 121 | # cumulated |
113 | 122 | sort = sorted(column_size.items()) |
114 | 123 | cumul = dict(sort[:1]) |
115 | 124 | results = {sort[0][0]: sort[0][1] // 2} |
116 | 125 | for col, size in sort[1:]: |
117 | | - c = sum(cumul.values()) |
118 | | - cumul[col] = c |
| 126 | + c = max(cumul.values()) |
| 127 | + cumul[col] = c + size |
119 | 128 | results[col] = c + size // 2 |
120 | 129 | return [(row, results[col]) for row, col in new_positions] |
121 | 130 |
|
@@ -190,6 +199,49 @@ def build_node_edges(cls, nodes: Sequence[onnx.NodeProto]) -> Set[Tuple[int, int |
190 | 199 | edges.add(edge) |
191 | 200 | return edges |
192 | 201 |
|
| 202 | + ADD_RULES = { |
| 203 | + ("┴", "┘"): "┴", |
| 204 | + ("┴", "└"): "┴", |
| 205 | + ("┬", "┐"): "┬", |
| 206 | + ("┬", "┌"): "┬", |
| 207 | + ("-", "└"): "┴", |
| 208 | + ("-", "|"): "┼", |
| 209 | + ("-", "┐"): "┬", |
| 210 | + ("┐", "-"): "┬", |
| 211 | + ("┘", "-"): "┴", |
| 212 | + ("┴", "-"): "┴", |
| 213 | + ("-", "┘"): "┴", |
| 214 | + ("┌", "-"): "┬", |
| 215 | + ("┬", "-"): "┬", |
| 216 | + ("-", "┌"): "┬", |
| 217 | + ("|", "-"): "┼", |
| 218 | + ("└", "-"): "┴", |
| 219 | + ("|", "└"): "├", |
| 220 | + ("|", "┘"): "┤", |
| 221 | + ("┐", "|"): "┤", |
| 222 | + ("┬", "|"): "┼", |
| 223 | + ("|", "┐"): "┤", |
| 224 | + ("|", "┌"): "├", |
| 225 | + ("├", "-"): "┼", |
| 226 | + ("└", "|"): "├", |
| 227 | + ("┤", "┐"): "┤", |
| 228 | + ("┤", "|"): "┤", |
| 229 | + ("├", "|"): "├", |
| 230 | + ("┴", "┌"): "┼", |
| 231 | + ("┐", "┌"): "┬", |
| 232 | + ("┌", "┐"): "┬", |
| 233 | + ("┌", "|"): "┼", |
| 234 | + ("┴", "┐"): "┼", |
| 235 | + ("┐", "└"): "┼", |
| 236 | + ("┬", "┘"): "┼", |
| 237 | + ("├", "└"): "├", |
| 238 | + ("┤", "┌"): "┼", |
| 239 | + ("┘", "|"): "┤", |
| 240 | + ("┴", "|"): "┼", |
| 241 | + ("┤", "-"): "┼", |
| 242 | + ("┘", "└"): "┴", |
| 243 | + } |
| 244 | + |
193 | 245 | @classmethod |
194 | 246 | def text_grid(cls, grid: List[List[str]], position: Tuple[int, int], text: str): |
195 | 247 | """ |
@@ -230,24 +282,44 @@ def text_edge( |
230 | 282 | assert ( |
231 | 283 | 0 <= p2[1] < min(len(g) for g in grid) |
232 | 284 | ), f"p2={p2}, sizes={[len(g) for g in grid]}" |
233 | | - grid[p1[0] + 1][p1[1]] = "|" |
234 | | - grid[p1[0] + 2][p1[1]] = "+" |
| 285 | + |
| 286 | + def add(s1, s2): |
| 287 | + assert s2 != " ", f"s1={s1!r}, s2={s2!r}" |
| 288 | + if s1 == " " or s1 == s2: |
| 289 | + return s2 |
| 290 | + if s1 == "┼" or s2 == "┼": |
| 291 | + return "┼" |
| 292 | + if (s1, s2) in cls.ADD_RULES: |
| 293 | + return cls.ADD_RULES[s1, s2] |
| 294 | + raise NotImplementedError(f"Unable to add: ({s1!r},{s2!r}): '',") |
| 295 | + |
| 296 | + def place(grid, x, y, symbol): |
| 297 | + grid[x][y] = add(grid[x][y], symbol) |
| 298 | + |
| 299 | + place(grid, p1[0] + 1, p1[1], "|") |
| 300 | + place(grid, p1[0] + 2, p1[1], "└" if p1[1] < p2[1] else "┘") |
235 | 301 |
|
236 | 302 | if p1[0] + 2 == p2[0] - 2: |
237 | 303 | a, b = (p1[1] + 1, p2[1] - 1) if p1[1] < p2[1] else (p2[1] + 1, p1[1] - 1) |
238 | | - grid[p1[0] + 2][a : b + 1] = ["-"] * (b - a + 1) |
| 304 | + for i in range(a, b + 1): |
| 305 | + place(grid, p1[0] + 2, i, "-") |
239 | 306 | else: |
240 | 307 | middle = (p1[1] + p2[1]) // 2 |
241 | 308 | a, b = (p1[1] + 1, middle - 1) if p1[1] < middle else (middle + 1, p1[1] - 1) |
242 | | - grid[p1[0] + 2][a : b + 1] = ["-"] * (b - a + 1) |
| 309 | + for i in range(a, b + 1): |
| 310 | + place(grid, p1[0] + 2, i, "-") |
243 | 311 | a, b = (p1[1] + 1, middle - 1) if p1[1] < middle else (middle + 1, p1[1] - 1) |
244 | | - grid[p1[0] + 2][a : b + 1] = ["-"] * (b - a + 1) |
| 312 | + for i in range(a, b + 1): |
| 313 | + place(grid, p1[0] + 2, i, "-") |
| 314 | + |
| 315 | + place(grid, p1[0] + 2, middle, "┐" if p1[1] < p2[1] else "┌") |
| 316 | + place(grid, p2[0] - 2, middle, "└" if p1[1] < p2[1] else "┘") |
245 | 317 |
|
246 | | - grid[p1[0] + 2][middle] = "+" |
247 | | - grid[p2[0] - 2][middle] = "+" |
| 318 | + for i in range(p1[0] + 2 + 1, p2[0] - 2): |
| 319 | + place(grid, i, middle, "|") |
248 | 320 |
|
249 | | - grid[p2[0] - 2][p2[1]] = "+" |
250 | | - grid[p2[0] - 1][p2[1]] = "|" |
| 321 | + place(grid, p2[0] - 2, p2[1], "┐" if p1[1] < p2[1] else "┌") |
| 322 | + place(grid, p2[0] - 1, p2[1], "|") |
251 | 323 |
|
252 | 324 | def text_rendering(self, prefix="") -> str: |
253 | 325 | """ |
@@ -298,6 +370,7 @@ def text_rendering(self, prefix="") -> str: |
298 | 370 | text_pos = self.text_positions(nodes, positions) |
299 | 371 | edges = self.build_node_edges(nodes) |
300 | 372 | max_len = max(col for _, col in text_pos) + max(len(n.op_type) for n in nodes) |
| 373 | + assert max_len < 1e6, f"max_len={max_len}, text_pos=\n{pprint.pformat(text_pos)}" |
301 | 374 | max_row = max(row for row, _ in text_pos) + 2 |
302 | 375 | grid = [[" " for i in range(max_len + 1)] for _ in range(max_row + 1)] |
303 | 376 |
|
|
0 commit comments