Skip to content

Commit 4515651

Browse files
committed
fix other series of bugs
1 parent 384166b commit 4515651

File tree

6 files changed

+67
-85
lines changed

6 files changed

+67
-85
lines changed
1.39 KB
Binary file not shown.

_unittests/ut_torch_onnx/test_discrepancies.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def test_attention_opset15_in_a_loop(self):
1212
)
1313
sess = self.check_ort(model)
1414
feeds = dict(
15-
c_lift_tensor_0=np.array([0], dtype=np.int64),
15+
c_lifted_tensor_0=np.array([0], dtype=np.int64),
1616
cat_2=np.array(
1717
[
1818
0,
@@ -45,9 +45,12 @@ def test_attention_opset15_in_a_loop(self):
4545
dtype=np.int64,
4646
),
4747
unsqueeze_4=np.random.randn(1, 16, 1292, 80).astype(np.float32),
48+
unsqueeze_5=np.random.randn(1, 16, 1292, 80).astype(np.float32),
49+
unsqueeze_6=np.random.randn(1, 16, 1292, 80).astype(np.float32),
4850
)
4951
got = sess.run(None, feeds)
5052
self.assertEqual(len(got), 1)
53+
self.assertEqual((1, 1292, 16, 80), got[0].shape)
5154

5255

5356
if __name__ == "__main__":

onnx_diagnostic/helpers/dot_helper.py

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,9 @@
1-
from typing import Dict, Set
1+
from typing import Dict
22
import numpy as np
33
import onnx
44
import onnx.numpy_helper as onh
55
from ..reference import ExtendedReferenceEvaluator as Inference
6-
from .onnx_helper import onnx_dtype_name, pretty_onnx
7-
8-
9-
def _get_hidden_inputs(graph: onnx.GraphProto) -> Set[str]:
10-
hidden = set()
11-
memo = (
12-
{i.name for i in graph.initializer}
13-
| {i.values.name for i in graph.sparse_initializer}
14-
| {i.name for i in graph.input}
15-
)
16-
for node in graph.node:
17-
for i in node.input:
18-
if i not in memo:
19-
hidden.add(i)
20-
for att in node.attribute:
21-
if att.type == onnx.AttributeProto.GRAPH and att.g:
22-
hid = _get_hidden_inputs(att.g)
23-
less = set(h for h in hid if h not in memo)
24-
hidden |= less
25-
memo |= set(node.output)
26-
return hidden
6+
from .onnx_helper import onnx_dtype_name, pretty_onnx, get_hidden_inputs
277

288

299
def _make_node_label(node: onnx.NodeProto, tiny_inits: Dict[str, str]) -> str:
@@ -221,7 +201,7 @@ def _mkn(obj: object) -> int:
221201
unique = set()
222202
for att in node.attribute:
223203
if att.type == onnx.AttributeProto.GRAPH:
224-
unique |= _get_hidden_inputs(att.g)
204+
unique |= get_hidden_inputs(att.g)
225205
for i in unique:
226206
edge = name_to_ids[i], _mkn(node) # type: ignore[assignment]
227207
if edge in done:

onnx_diagnostic/helpers/onnx_helper.py

Lines changed: 53 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,6 +1198,30 @@ def shadowing_names(
11981198
return shadow, post_shadow, created
11991199

12001200

1201+
def get_hidden_inputs(graph: onnx.GraphProto) -> Set[str]:
1202+
"""
1203+
Returns the hidden inputs (inputs coming from an upper context)
1204+
used by a subgraph.
1205+
"""
1206+
hidden = set()
1207+
memo = (
1208+
set(i.name for i in graph.initializer)
1209+
| set(i.name for i in graph.sparse_initializer)
1210+
| set(i.name for i in graph.input)
1211+
)
1212+
for node in graph.node:
1213+
for i in node.input:
1214+
if i not in memo:
1215+
hidden.add(i)
1216+
for att in node.attribute:
1217+
if att.type == onnx.AttributeProto.GRAPH and att.g:
1218+
hid = get_hidden_inputs(att.g)
1219+
less = set(h for h in hid if h not in memo)
1220+
hidden |= less
1221+
memo |= set(node.output)
1222+
return hidden
1223+
1224+
12011225
def extract_subset_of_nodes(
12021226
model: ModelProto,
12031227
name: str,
@@ -1240,30 +1264,45 @@ def extract_subset_of_nodes(
12401264
current_input_index = 0
12411265
intermediate = {name}
12421266
cut_points -= {name}
1267+
cached = {}
12431268
inputs = set(k for k in node.input if k)
12441269
while not (inputs <= cut_points) and current_node_index >= 0:
12451270
node = model.graph.node[current_node_index]
1246-
if current_input_index == 0 or not node.input:
1271+
# node inputs including hidden ones
1272+
if current_node_index in cached:
1273+
node_inputs = cached[current_node_index]
1274+
else:
1275+
node_inputs = set(i for i in node.input if i)
1276+
if node.op_type in {"Scan", "If", "Loop"}:
1277+
# there are hidden inputs
1278+
for att in node.attribute:
1279+
if att.type == onnx.AttributeProto.GRAPH:
1280+
node_inputs |= get_hidden_inputs(att.g)
1281+
node_inputs = list(node_inputs)
1282+
cached[current_node_index] = node_inputs
1283+
# processing
1284+
if current_input_index == 0 or not node_inputs:
12471285
needs = [o for o in node.output if o in intermediate and o not in cut_points]
12481286
if needs:
12491287
selected.add(current_node_index)
1250-
if not node.input:
1288+
if not node_inputs:
12511289
current_node_index -= 1
12521290
current_input_index = 0
12531291
continue
12541292
else:
12551293
current_node_index -= 1
12561294
current_input_index = 0
12571295
continue
1258-
assert current_input_index < len(node.input), (
1259-
f"current_input_index={current_input_index} but node.input={node.input}, "
1296+
# more intermediate results
1297+
assert current_input_index < len(node_inputs), (
1298+
f"current_input_index={current_input_index} but node_inputs={node_inputs}, "
12601299
f"node={pretty_onnx(node)}"
12611300
)
1262-
res = node.input[current_input_index]
1301+
res = node_inputs[current_input_index]
12631302
if res not in cut_points:
12641303
intermediate.add(res)
12651304
current_input_index += 1
1266-
if current_input_index >= len(node.input):
1305+
if current_input_index >= len(node_inputs):
12671306
current_node_index -= 1
12681307
current_input_index = 0
12691308

@@ -1296,8 +1335,14 @@ def _mkv_(name, itype, irank):
12961335

12971336
not_known: Set[str] = set()
12981337
for node in nodes[::-1]:
1299-
not_known -= set(node.output)
1300-
not_known |= set(node.input)
1338+
not_known -= {o for o in node.output if o}
1339+
not_known |= {i for i in node.input if i}
1340+
if node.op_type in {"Scan", "If", "Loop"}:
1341+
# there are hidden inputs
1342+
for att in node.attribute:
1343+
if att.type == onnx.AttributeProto.GRAPH:
1344+
print("++++", get_hidden_inputs(att.g))
1345+
not_known |= get_hidden_inputs(att.g)
13011346

13021347
model = oh.make_model(
13031348
oh.make_graph(

onnx_diagnostic/reference/ort_evaluator.py

Lines changed: 6 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,11 @@
1818
import onnxruntime
1919
from ..helpers import string_type
2020
from ..helpers.onnx_helper import (
21-
pretty_onnx,
21+
get_hidden_inputs,
2222
dtype_to_tensor_dtype,
23-
to_array_extended,
2423
np_dtype_to_tensor_dtype,
24+
to_array_extended,
25+
pretty_onnx,
2526
)
2627
from ..helpers.torch_helper import onnx_dtype_to_torch_dtype, torch_dtype_to_onnx_dtype
2728
from ..helpers.ort_session import (
@@ -472,39 +473,15 @@ def enumerate_nodes(self, nodes: List[NodeProto]) -> Iterator[NodeProto]:
472473
yield from self.enumerate_nodes(att.g.node)
473474
yield node
474475

475-
@classmethod
476-
def _get_hidden_inputs(cls, graph: GraphProto) -> Set[str]:
477-
"""
478-
Returns the hidden inputs (inputs coming from an upper context)
479-
used by a subgraph.
480-
"""
481-
hidden = set()
482-
memo = (
483-
{i.name for i in graph.initializer}
484-
| {i.name for i in graph.sparse_initializer}
485-
| {i.name for i in graph.input}
486-
)
487-
for node in graph.node:
488-
for i in node.input:
489-
if i not in memo:
490-
hidden.add(i)
491-
for att in node.attribute:
492-
if att.type == AttributeProto.GRAPH and att.g:
493-
hid = cls._get_hidden_inputs(att.g)
494-
less = set(h for h in hid if h not in memo)
495-
hidden |= less
496-
memo |= set(node.output)
497-
return hidden
498-
499476
@classmethod
500477
def _get_hidden_node_inputs(cls, node: NodeProto) -> Set[str]:
501-
"""Calls multiple _get_hidden_inputs on every attribute."""
478+
"""Calls multiple get_hidden_inputs on every attribute."""
502479
if node.op_type not in {"Loop", "Scan", "If"}:
503480
return set()
504481
hidden = set()
505482
for att in node.attribute:
506483
if att.type == AttributeProto.GRAPH:
507-
hidden |= cls._get_hidden_inputs(att.g)
484+
hidden |= get_hidden_inputs(att.g)
508485
return hidden - (hidden & set(node.input))
509486

510487
def _get_sess(
@@ -624,7 +601,7 @@ def _get_sess_init_subgraph(
624601
value = oh.make_tensor_value_info(i, dtype_to_tensor_dtype(it.dtype), it.shape)
625602
vinputs.append(value)
626603

627-
reduced_set = self._get_hidden_inputs(g)
604+
reduced_set = get_hidden_inputs(g)
628605
for i, v in context.items():
629606
if i in reduced_set and i not in unique_names:
630607
unique_names.add(i)

onnx_diagnostic/torch_onnx/runtime_info.py

Lines changed: 1 addition & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import torch
55
from ..api import TensorLike
66
from ..helpers import string_type
7+
from ..helpers.onnx_helper import get_hidden_inputs
78

89

910
class RuntimeValueKind(enum.IntEnum):
@@ -151,30 +152,6 @@ def is_initializer(self) -> bool:
151152
return self.kind == RuntimeValueKind.INITIALIZER
152153

153154

154-
def get_hidden_inputs(graph: onnx.GraphProto) -> Set[str]:
155-
"""
156-
Returns the hidden inputs (inputs coming from an upper context)
157-
used by a subgraph.
158-
"""
159-
hidden = set()
160-
memo = (
161-
set(i.name for i in graph.initializer)
162-
| set(i.name for i in graph.sparse_initializer)
163-
| set(i.name for i in graph.input)
164-
)
165-
for node in graph.node:
166-
for i in node.input:
167-
if i not in memo:
168-
hidden.add(i)
169-
for att in node.attribute:
170-
if att.type == onnx.AttributeProto.GRAPH and att.g:
171-
hid = get_hidden_inputs(att.g)
172-
less = set(h for h in hid if h not in memo)
173-
hidden |= less
174-
memo |= set(node.output)
175-
return hidden
176-
177-
178155
def set_is_shape(
179156
node: onnx.NodeProto, values: Dict[str, RuntimeValue], drop: Optional[Set[str]] = None
180157
) -> List[str]:

0 commit comments

Comments
 (0)