Skip to content

Commit ce5c1bf

Browse files
committed
fin
1 parent d306f2e commit ce5c1bf

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

onnx_diagnostic/helpers/onnx_helper.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1239,9 +1239,12 @@ def get_all_node_inputs(node: onnx.NodeProto) -> Set[str]:
12391239
Returns input and hidden inputs of a node.
12401240
See :func:`get_hidden_inputs`.
12411241
"""
1242+
start = set(node.input)
12421243
if node.op_type in {"Scan", "Loop", "If"}:
1243-
return set(node.input) | get_hidden_inputs(node)
1244-
return set(node.input)
1244+
for att in node.attribute:
1245+
if att.type == onnx.AttributeProto.GRAPH:
1246+
start |= get_hidden_inputs(att.g)
1247+
return start
12451248

12461249

12471250
def extract_subset_of_nodes(
@@ -1613,8 +1616,7 @@ def select_model_inputs_outputs(
16131616
if not mod:
16141617
continue
16151618

1616-
hidden = get_hidden_inputs([node])
1617-
node_inputs = list(node.input) + list(hidden)
1619+
node_inputs = get_all_node_inputs(node)
16181620

16191621
nb += 1
16201622
for inp in node_inputs:

0 commit comments

Comments
 (0)