11from typing import Dict , Set
22import onnx
3+ import onnx .numpy_helper as onh
34from .onnx_helper import onnx_dtype_name , pretty_onnx
45
56
@@ -44,6 +45,24 @@ def _make_node_label(node: onnx.NodeProto) -> str:
4445 return "" .join (els )
4546
4647
48+ def _make_edge_label (value_info : onnx .ValueInfoProto , multi_line : bool = False ) -> str :
49+ itype = value_info .type .tensor_type .elem_type
50+ if itype == onnx .TensorProto .UNDEFINED :
51+ return ""
52+ shape = tuple (
53+ d .dim_param if d .dim_param else d .dim_value
54+ for d in value_info .type .tensor_type .shape .dim
55+ )
56+ res = [
57+ str (a )
58+ for a in [("?" if isinstance (s , str ) and s .startswith ("unk" ) else s ) for s in shape ]
59+ ]
60+ sshape = "," .join (res )
61+ if multi_line and len (sshape ) > 30 :
62+ sshape = ",\\ n" .join (res )
63+ return f"{ onnx_dtype_name (itype )} ({ sshape } )"
64+
65+
4766def to_dot (model : onnx .ModelProto ) -> str :
4867 """
4968 Converts a model into a dot graph.
@@ -103,19 +122,7 @@ def _mkn(obj: object) -> int:
103122
104123 edge_label = {}
105124 for val in model .graph .value_info :
106- itype = val .type .tensor_type .elem_type
107- if itype == onnx .TensorProto .UNDEFINED :
108- continue
109- shape = tuple (
110- d .dim_param if d .dim_param else d .dim_value for d in val .type .tensor_type .shape .dim
111- )
112- sshape = "," .join (
113- map (
114- str ,
115- [("?" if isinstance (s , str ) and s .startswith ("unk" ) else s ) for s in shape ],
116- )
117- )
118- edge_label [val .name ] = f"{ onnx_dtype_name (itype )} ({ sshape } )"
125+ edge_label [val .name ] = _make_edge_label (val , multi_line = True )
119126
120127 rows = [
121128 "digraph {" ,
@@ -124,7 +131,7 @@ def _mkn(obj: object) -> int:
124131 "ranksep=0.2, fontsize=8];"
125132 ),
126133 ' node [style="rounded,filled", color="#888888", fontcolor="#222222", shape=box];' ,
127- " edge [arrowhead=vee, fontsize=6 ];" ,
134+ " edge [arrowhead=vee, fontsize=7, labeldistance=-5, labelangle=0 ];" ,
128135 ]
129136 inputs = list (model .graph .input )
130137 outputs = list (model .graph .output )
@@ -134,11 +141,23 @@ def _mkn(obj: object) -> int:
134141 for inp in inputs :
135142 if not inp .name :
136143 continue
137- rows .append (f' I_{ _mkn (inp )} [label="{ inp .name } ", fillcolor="#aaeeaa"];' )
144+ lab = _make_edge_label (inp )
145+ rows .append (f' I_{ _mkn (inp )} [label="{ inp .name } \\ n{ lab } ", fillcolor="#aaeeaa"];' )
138146 name_to_ids [inp .name ] = f"I_{ _mkn (inp )} "
147+ edge_label [inp .name ] = _make_edge_label (inp , multi_line = True )
139148 for init in inits :
140- rows .append (f' i_{ _mkn (init )} [label="{ init .name } ", fillcolor="#cccc00"];' )
149+ shape = tuple (init .dims )
150+ if len (shape ) == 0 or (len (shape ) == 1 and shape [0 ] < 10 ):
151+ a = onh .to_array (init )
152+ vals = f" = { a } " if len (shape ) == 0 else f"\\ n=[{ ', ' .join ([str (i ) for i in a ])} ]"
153+ else :
154+ vals = ""
155+ ls = f"{ onnx_dtype_name (init .data_type )} ({ ', ' .join (map (str ,shape ))} )"
156+ rows .append (
157+ f' i_{ _mkn (init )} [label="{ init .name } \\ n{ ls } { vals } ", fillcolor="#cccc00"];'
158+ )
141159 name_to_ids [init .name ] = f"i_{ _mkn (init )} "
160+ edge_label [init .name ] = ls
142161 for node in nodes :
143162 color = op_type_colors .get (node .op_type , "#cccccc" )
144163 label = _make_node_label (node )
@@ -179,7 +198,8 @@ def _mkn(obj: object) -> int:
179198 for out in outputs :
180199 if not out .name :
181200 continue
182- rows .append (f' O_{ _mkn (out )} [label="{ out .name } ", fillcolor="#aaaaee"];' )
201+ lab = _make_edge_label (inp )
202+ rows .append (f' O_{ _mkn (out )} [label="{ out .name } \\ n{ lab } ", fillcolor="#aaaaee"];' )
183203 edge = name_to_ids [out .name ], f"O_{ _mkn (out )} "
184204 rows .append (f" { edge [0 ]} -> { edge [1 ]} ;" )
185205
0 commit comments