File tree Expand file tree Collapse file tree 1 file changed +6
-4
lines changed Expand file tree Collapse file tree 1 file changed +6
-4
lines changed Original file line number Diff line number Diff line change @@ -1239,9 +1239,12 @@ def get_all_node_inputs(node: onnx.NodeProto) -> Set[str]:
12391239 Returns input and hidden inputs of a node.
12401240 See :func:`get_hidden_inputs`.
12411241 """
1242+ start = set (node .input )
12421243 if node .op_type in {"Scan" , "Loop" , "If" }:
1243- return set (node .input ) | get_hidden_inputs (node )
1244- return set (node .input )
1244+ for att in node .attribute :
1245+ if att .type == onnx .AttributeProto .GRAPH :
1246+ start |= get_hidden_inputs (att .g )
1247+ return start
12451248
12461249
12471250def extract_subset_of_nodes (
@@ -1613,8 +1616,7 @@ def select_model_inputs_outputs(
16131616 if not mod :
16141617 continue
16151618
1616- hidden = get_hidden_inputs ([node ])
1617- node_inputs = list (node .input ) + list (hidden )
1619+ node_inputs = get_all_node_inputs (node )
16181620
16191621 nb += 1
16201622 for inp in node_inputs :
You can’t perform that action at this time.
0 commit comments