Skip to content

Commit c13f243

Browse files
committed
better OnnxruntimeEvaluator (subgraph hidden inputs, Loop support)
1 parent 8760e20 commit c13f243

File tree

2 files changed

+34
-9
lines changed

2 files changed

+34
-9
lines changed

onnx_diagnostic/helpers/torch_test_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ def forward(self, x, y):
241241
dump_file,
242242
save_as_external_data=True,
243243
all_tensors_to_one_file=True,
244-
location=f"{os.path.split(dump_file)[-1]}.weight",
244+
location=f"{os.path.split(dump_file)[-1]}.data",
245245
)
246246
if verbose:
247247
print("-- done dump stored objects")

onnx_diagnostic/reference/ort_evaluator.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
1+
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
22
import numpy as np
33
from onnx import (
4+
AttributeProto,
45
GraphProto,
56
FunctionProto,
67
ModelProto,
@@ -250,7 +251,7 @@ def run(
250251
inputs = [(results[i] if i != "" else None) for i in node.input]
251252
if node.op_type == "If" and node.domain == "":
252253
outputs = self._run_if(node, inputs, results)
253-
elif node.op_type == "Scan" and node.domain == "":
254+
elif node.op_type in {"Scan", "Loop"} and node.domain == "":
254255
outputs = self._run_scan(node, inputs, results)
255256
elif self._is_local_function(node):
256257
outputs = self._run_local(node, inputs, results)
@@ -302,6 +303,27 @@ def _make_model_proto(
302303

303304
return onx
304305

306+
@classmethod
def _get_hidden_inputs(cls, graph: GraphProto) -> Set[str]:
    """
    Returns the hidden inputs (inputs coming from an upper context)
    used by a subgraph.

    A name is *hidden* when a node consumes it although it is neither an
    initializer of the subgraph nor the output of a previous node.
    Nested subgraphs (``If``/``Loop``/``Scan`` bodies) are inspected
    recursively.

    :param graph: subgraph (:class:`onnx.GraphProto`) to inspect
    :return: set of input names coming from the outer context

    .. note::
        ``graph.input`` is deliberately not added to ``memo`` here —
        NOTE(review): confirm subgraph formal inputs should indeed be
        reported as hidden when consumed before being produced.
    """
    hidden: Set[str] = set()
    # Names already defined inside the subgraph: initializers first,
    # then every node output as the graph is walked in topological order.
    memo = {init.name for init in graph.initializer}
    memo |= {init.name for init in graph.sparse_initializer}
    for node in graph.node:
        for name in node.input:
            if name not in memo:
                hidden.add(name)
        for att in node.attribute:
            # Recurse into nested subgraphs carried by graph attributes.
            if att.type == AttributeProto.GRAPH and att.g:
                # Only names unknown at this level escape upwards.
                hidden |= {h for h in cls._get_hidden_inputs(att.g) if h not in memo}
        # Outputs of this node become known for the following nodes.
        memo |= set(node.output)
    return hidden
326+
305327
def _get_sess(
306328
self, node: Union[ModelProto, NodeProto], inputs: List[Any]
307329
) -> Tuple[ModelProto, _InferenceSession]:
@@ -341,7 +363,7 @@ def _get_sess(
341363
return onx, sess
342364

343365
def _get_sess_init_subgraph(
344-
self, node: NodeProto, inputs: List[Any], context: Dict[str, Any]
366+
self, node: NodeProto, inputs: List[Any], context: Dict[str, Any], g: GraphProto
345367
) -> List[Any]:
346368
unique_names = set()
347369
vinputs = []
@@ -352,8 +374,9 @@ def _get_sess_init_subgraph(
352374
value = oh.make_tensor_value_info(i, dtype_to_tensor_dtype(it.dtype), it.shape)
353375
vinputs.append(value)
354376

377+
reduced_set = self._get_hidden_inputs(g)
355378
for i, v in context.items():
356-
if i not in unique_names:
379+
if i in reduced_set and i not in unique_names:
357380
unique_names.add(i)
358381
value = oh.make_tensor_value_info(i, dtype_to_tensor_dtype(v.dtype), v.shape)
359382
vinputs.append(value)
@@ -362,13 +385,12 @@ def _get_sess_init_subgraph(
362385
def _get_sess_if(
363386
self, node: NodeProto, branch: str, inputs: List[Any], context: Dict[str, Any]
364387
) -> Tuple[ModelProto, "OnnxruntimeEvaluator"]:
365-
vinputs = self._get_sess_init_subgraph(node, inputs, context)
366-
367388
g = None
368389
for att in node.attribute:
369390
if att.name == branch:
370391
g = att.g
371392
assert g, f"Missing attribute {branch!r}"
393+
vinputs = self._get_sess_init_subgraph(node, inputs, context, g)
372394

373395
voutputs = g.output
374396

@@ -439,20 +461,21 @@ def _run_if(
439461
self._cache[key] = onx, sess = self._get_sess_if(node, name, inputs, results)
440462

441463
assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
464+
input_names = [i.name for i in sess.get_inputs()]
465+
feeds = {name: results[name] for name in input_names}
442466
outputs = sess.run(None, feeds)
443467
assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
444468
return outputs
445469

446470
def _get_sess_scan(
447471
self, node: NodeProto, branch: str, inputs: List[Any], context: Dict[str, Any]
448472
) -> Tuple[ModelProto, "OnnxruntimeEvaluator"]:
449-
vinputs = self._get_sess_init_subgraph(node, inputs, context)
450-
451473
g = None
452474
for att in node.attribute:
453475
if att.name == branch:
454476
g = att.g
455477
assert g, f"Missing attribute {branch!r}"
478+
vinputs = self._get_sess_init_subgraph(node, inputs, context, g)
456479

457480
voutputs = []
458481
for name, goutput in zip(node.output, g.output):
@@ -492,6 +515,8 @@ def _run_scan(
492515
self._cache[key] = onx, sess = self._get_sess_scan(node, name, inputs, results)
493516

494517
assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
518+
input_names = [i.name for i in sess.get_inputs()]
519+
feeds = {name: results[name] for name in input_names}
495520
outputs = sess.run(None, feeds)
496521
assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
497522
return outputs

0 commit comments

Comments
 (0)