@@ -24,13 +24,13 @@ def _get_hidden_inputs(graph: onnx.GraphProto) -> Set[str]:
2424 return hidden
2525
2626
27- def _make_node_label (node : onnx .NodeProto ) -> str :
27+ def _make_node_label (node : onnx .NodeProto , tiny_inits : Dict [ str , str ] ) -> str :
2828 els = [f"{ node .domain } .\\ n{ node .op_type } " if node .domain else node .op_type , "(" ]
29- ee = ["." if i else "" for i in node .input ]
29+ ee = [tiny_inits . get ( i , "." ) if i else "" for i in node .input ]
3030 for att in node .attribute :
3131 if att .name == "to" :
3232 ee .append (f"{ att .name } ={ onnx_dtype_name (att .i )} " )
33- elif att .name in {"to" , "axis" , "value_int" , "stash_type" }:
33+ elif att .name in {"to" , "axis" , "value_int" , "stash_type" , "start" , "end" }:
3434 ee .append (f"{ att .name } ={ att .i } " )
3535 elif att .name in {"value_float" }:
3636 ee .append (f"{ att .name } ={ att .f } " )
@@ -115,9 +115,12 @@ def _mkn(obj: object) -> int:
115115 model = onnx .shape_inference .infer_shapes (model )
116116
117117 op_type_colors = {
118- "Shape" : "#eeeeee " ,
118+ "Shape" : "#d2a81f " ,
119119 "MatMul" : "#ee9999" ,
120120 "Transpose" : "#ee99ee" ,
121+ "Reshape" : "#eeeeee" ,
122+ "Squeeze" : "#eeeeee" ,
123+ "Unsqueeze" : "#eeeeee" ,
121124 }
122125
123126 edge_label = {}
@@ -137,6 +140,7 @@ def _mkn(obj: object) -> int:
137140 outputs = list (model .graph .output )
138141 nodes = list (model .graph .node )
139142 inits = list (model .graph .initializer )
143+ tiny_inits = {}
140144 name_to_ids = {}
141145 for inp in inputs :
142146 if not inp .name :
@@ -150,17 +154,19 @@ def _mkn(obj: object) -> int:
150154 if len (shape ) == 0 or (len (shape ) == 1 and shape [0 ] < 10 ):
151155 a = onh .to_array (init )
152156 vals = f" = { a } " if len (shape ) == 0 else f"\\ n=[{ ', ' .join ([str (i ) for i in a ])} ]"
157+ tiny_inits [init .name ] = (
158+ str (a ) if len (shape ) == 0 else f"[{ ', ' .join ([str (i ) for i in a ])} ]"
159+ )
153160 else :
154- vals = ""
155- ls = f"{ onnx_dtype_name (init .data_type )} ({ ', ' .join (map (str ,shape ))} )"
156- rows .append (
157- f' i_{ _mkn (init )} [label="{ init .name } \\ n{ ls } { vals } ", fillcolor="#cccc00"];'
158- )
159- name_to_ids [init .name ] = f"i_{ _mkn (init )} "
160- edge_label [init .name ] = ls
161+ ls = f"{ onnx_dtype_name (init .data_type )} ({ ', ' .join (map (str ,shape ))} )"
162+ rows .append (
163+ f' i_{ _mkn (init )} [label="{ init .name } \\ n{ ls } { vals } ", fillcolor="#cccc00"];'
164+ )
165+ name_to_ids [init .name ] = f"i_{ _mkn (init )} "
166+ edge_label [init .name ] = ls
161167 for node in nodes :
162168 color = op_type_colors .get (node .op_type , "#cccccc" )
163- label = _make_node_label (node )
169+ label = _make_node_label (node , tiny_inits )
164170 rows .append (f' { node .op_type } _{ _mkn (node )} [label="{ label } ", fillcolor="{ color } "];' )
165171 name_to_ids .update ({o : f"{ node .op_type } _{ _mkn (node )} " for o in node .output if o })
166172
@@ -169,7 +175,7 @@ def _mkn(obj: object) -> int:
169175 for node in nodes :
170176 names = list (node .input )
171177 for i in names :
172- if not i :
178+ if not i or i in tiny_inits :
173179 continue
174180 if i not in name_to_ids :
175181 raise ValueError (f"Unable to find { i !r} \n { pretty_onnx (model )} " )
0 commit comments