@@ -1243,13 +1243,22 @@ def extract_subset_of_nodes(
12431243 inputs = set (k for k in node .input if k )
12441244 while not (inputs <= cut_points ) and current_node_index >= 0 :
12451245 node = model .graph .node [current_node_index ]
1246- if current_input_index == 0 :
1246+ if current_input_index == 0 or not node . input :
12471247 needs = [o for o in node .output if o in intermediate and o not in cut_points ]
12481248 if needs :
12491249 selected .add (current_node_index )
1250+ if not node .input :
1251+ current_node_index -= 1
1252+ current_input_index = 0
1253+ continue
12501254 else :
12511255 current_node_index -= 1
1256+ current_input_index = 0
12521257 continue
1258+ assert current_input_index < len (node .input ), (
1259+ f"current_input_index={ current_input_index } but node.input={ node .input } , "
1260+ f"node={ pretty_onnx (node )} "
1261+ )
12531262 res = node .input [current_input_index ]
12541263 if res not in cut_points :
12551264 intermediate .add (res )
@@ -1294,8 +1303,8 @@ def _mkv_(name, itype, irank):
12941303 oh .make_graph (
12951304 nodes ,
12961305 "submodel" ,
1297- [_mkv_ (n , * type_rank_fn (n )) for n in sorted (not_known )],
1298- [_mkv_ (n , * type_rank_fn (n )) for n in sorted (output_names )],
1306+ [_mkv_ (n , * type_rank_fn (n )) for n in sorted (not_known ) if n ],
1307+ [_mkv_ (n , * type_rank_fn (n )) for n in sorted (output_names ) if n ],
12991308 ),
13001309 ir_version = ir_version ,
13011310 opset_imports = opset_imports ,
0 commit comments