11from typing import Dict , List , Optional , Sequence , Set , Tuple , Union
22import onnx
3+ import onnx .helper as oh
34
45
56class GraphRendering :
@@ -146,6 +147,34 @@ def start_names(self) -> List[onnx.NodeProto]:
146147 )
147148 return [* input_names , * init_names ]
148149
150+ @property
151+ def input_names (self ) -> List [str ]:
152+ "Returns the list of input names."
153+ return (
154+ self .proto .input
155+ if isinstance (self .proto , onnx .FunctionProto )
156+ else [
157+ i .name
158+ for i in (
159+ self .proto if isinstance (self .proto , onnx .GraphProto ) else self .proto .graph
160+ ).input
161+ ]
162+ )
163+
164+ @property
165+ def output_names (self ) -> List [str ]:
166+ "Returns the list of output names."
167+ return (
168+ self .proto .output
169+ if isinstance (self .proto , onnx .FunctionProto )
170+ else [
171+ i .name
172+ for i in (
173+ self .proto if isinstance (self .proto , onnx .GraphProto ) else self .proto .graph
174+ ).output
175+ ]
176+ )
177+
149178 @classmethod
150179 def build_node_edges (cls , nodes : Sequence [onnx .NodeProto ]) -> Set [Tuple [int , int ]]:
151180 """Builds the list of edges between nodes."""
@@ -210,7 +239,7 @@ def text_edge(
210239 else :
211240 middle = (p1 [1 ] + p2 [1 ]) // 2
212241 a , b = (p1 [1 ] + 1 , middle - 1 ) if p1 [1 ] < middle else (middle + 1 , p1 [1 ] - 1 )
213- grid [p1 [0 ] + 2 ][a : b + 1 ] = ["-" * (b - a + 1 )]
242+ grid [p1 [0 ] + 2 ][a : b + 1 ] = ["-" ] * (b - a + 1 )
214243 a , b = (p1 [1 ] + 1 , middle - 1 ) if p1 [1 ] < middle else (middle + 1 , p1 [1 ] - 1 )
215244 grid [p1 [0 ] + 2 ][a : b + 1 ] = ["-" ] * (b - a + 1 )
216245
@@ -223,7 +252,6 @@ def text_edge(
223252 def text_rendering (self , prefix = "" ) -> str :
224253 """
225254 Renders a model in text.
226- It only renders nodes.
227255
228256 .. runpython::
229257 :showcode:
@@ -257,8 +285,14 @@ def text_rendering(self, prefix="") -> str:
257285 text = textwrap.dedent(graph.text_rendering()).strip("\\ n")
258286 print(text)
259287 """
260- nodes = self .nodes
261- existing = self .start_names
288+ nodes = [
289+ * [oh .make_node (i , ["BEGIN" ], [i ]) for i in self .input_names ],
290+ * self .nodes ,
291+ * [oh .make_node (i , [i ], ["END" ]) for i in self .output_names ],
292+ ]
293+ exist = set (self .start_names ) - set (self .input_names )
294+ exist |= {"BEGIN" }
295+ existing = sorted (exist )
262296 order = self .computation_order (nodes , existing )
263297 positions = self .graph_positions (nodes , order , existing )
264298 text_pos = self .text_positions (nodes , positions )
@@ -269,8 +303,10 @@ def text_rendering(self, prefix="") -> str:
269303
270304 for n1 , n2 in edges :
271305 self .text_edge (grid , text_pos [n1 ], text_pos [n2 ])
306+ assert len (set (len (g ) for g in grid )) == 1 , f"lengths={ [len (g ) for g in grid ]} "
272307 for node , pos in zip (nodes , text_pos ):
273308 self .text_grid (grid , pos , node .op_type )
309+ assert len (set (len (g ) for g in grid )) == 1 , f"lengths={ [len (g ) for g in grid ]} "
274310
275311 return "\n " .join (
276312 f"{ prefix } { line .rstrip ()} " for line in ["" .join (line ) for line in grid ]
0 commit comments