Skip to content

Commit 3c9d867

Browse files
committed
handle empty input
1 parent 3a3a7b0 commit 3c9d867

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

onnx_diagnostic/helpers/onnx_helper.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1243,13 +1243,22 @@ def extract_subset_of_nodes(
12431243
inputs = set(k for k in node.input if k)
12441244
while not (inputs <= cut_points) and current_node_index >= 0:
12451245
node = model.graph.node[current_node_index]
1246-
if current_input_index == 0:
1246+
if current_input_index == 0 or not node.input:
12471247
needs = [o for o in node.output if o in intermediate and o not in cut_points]
12481248
if needs:
12491249
selected.add(current_node_index)
1250+
if not node.input:
1251+
current_node_index -= 1
1252+
current_input_index = 0
1253+
continue
12501254
else:
12511255
current_node_index -= 1
1256+
current_input_index = 0
12521257
continue
1258+
assert current_input_index < len(node.input), (
1259+
f"current_input_index={current_input_index} but node.input={node.input}, "
1260+
f"node={pretty_onnx(node)}"
1261+
)
12531262
res = node.input[current_input_index]
12541263
if res not in cut_points:
12551264
intermediate.add(res)
@@ -1294,8 +1303,8 @@ def _mkv_(name, itype, irank):
12941303
oh.make_graph(
12951304
nodes,
12961305
"submodel",
1297-
[_mkv_(n, *type_rank_fn(n)) for n in sorted(not_known)],
1298-
[_mkv_(n, *type_rank_fn(n)) for n in sorted(output_names)],
1306+
[_mkv_(n, *type_rank_fn(n)) for n in sorted(not_known) if n],
1307+
[_mkv_(n, *type_rank_fn(n)) for n in sorted(output_names) if n],
12991308
),
13001309
ir_version=ir_version,
13011310
opset_imports=opset_imports,

0 commit comments

Comments
 (0)