11from typing import Dict , Set
2+ import numpy as np
23import onnx
34import onnx .numpy_helper as onh
5+ from ..reference import ExtendedReferenceEvaluator as Inference
46from .onnx_helper import onnx_dtype_name , pretty_onnx
57
68
@@ -142,14 +144,37 @@ def _mkn(obj: object) -> int:
142144 inits = list (model .graph .initializer )
143145 tiny_inits = {}
144146 name_to_ids = {}
147+
145148 for inp in inputs :
146149 if not inp .name :
147150 continue
148151 lab = _make_edge_label (inp )
149152 rows .append (f' I_{ _mkn (inp )} [label="{ inp .name } \\ n{ lab } ", fillcolor="#aaeeaa"];' )
150153 name_to_ids [inp .name ] = f"I_{ _mkn (inp )} "
151154 edge_label [inp .name ] = _make_edge_label (inp , multi_line = True )
155+
156+ # Small constant --> initializer
157+ for node in nodes :
158+ if node .op_type != "Constant" :
159+ continue
160+ skip = False
161+ for att in node .attribute :
162+ if att .name == "value" and (
163+ len (att .t .dims ) > 1 or np .prod (tuple (att .t .dims )) > 10
164+ ):
165+ skip = True
166+ break
167+ if skip :
168+ continue
169+
170+ sess = Inference (node )
171+ value = sess .run (None , {})[0 ]
172+ inits .append (onh .from_array (value , name = node .output [0 ]))
173+
152174 for init in inits :
175+ if init .name in inputs :
176+ # hide optional inputs
177+ continue
153178 shape = tuple (init .dims )
154179 if len (shape ) == 0 or (len (shape ) == 1 and shape [0 ] < 10 ):
155180 a = onh .to_array (init )
@@ -161,7 +186,10 @@ def _mkn(obj: object) -> int:
161186 rows .append (f' i_{ _mkn (init )} [label="{ init .name } \\ n{ ls } ", fillcolor="#cccc00"];' )
162187 name_to_ids [init .name ] = f"i_{ _mkn (init )} "
163188 edge_label [init .name ] = ls
189+
164190 for node in nodes :
191+ if node .op_type == "Constant" and node .output [0 ] in name_to_ids :
192+ continue
165193 color = op_type_colors .get (node .op_type , "#cccccc" )
166194 label = _make_node_label (node , tiny_inits )
167195 rows .append (f' { node .op_type } _{ _mkn (node )} [label="{ label } ", fillcolor="{ color } "];' )
0 commit comments