Skip to content

Commit 40e6698

Browse files
committed
fix bug in onnxruntime
1 parent c13f243 commit 40e6698

File tree

3 files changed

+98
-18
lines changed

3 files changed

+98
-18
lines changed

_unittests/ut_reference/test_onnxruntime_evaluator.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
import unittest
2+
import numpy as np
23
import onnx
4+
import onnx.helper as oh
35
import torch
46
import onnxruntime
57
from onnx_diagnostic.ext_test_case import ExtTestCase
8+
from onnx_diagnostic.helpers.onnx_helper import from_array_extended
69
from onnx_diagnostic.reference import OnnxruntimeEvaluator, ExtendedReferenceEvaluator
710

811
try:
@@ -96,6 +99,52 @@ def false_fn(x, y):
9699
for e, g in zip(expected, got):
97100
self.assertEqualArray(e, g, atol=1e-5)
98101

102+
def test_constant_bool(self):
103+
node = oh.make_node(
104+
"Constant",
105+
[],
106+
["cbool"],
107+
value=from_array_extended(np.array(True, dtype=np.bool_)),
108+
)
109+
ref = ExtendedReferenceEvaluator(node)
110+
got = ref.run(None, {})[0]
111+
self.assertEqual(got.dtype, np.bool_)
112+
self.assertEqual(got, True)
113+
ref = OnnxruntimeEvaluator(node)
114+
got = ref.run(None, {})[0]
115+
self.assertEqual(len(ref._cache), 1)
116+
values = list(ref._cache.values())
117+
_, sess = values[0]
118+
got2 = sess.run(None, {})[0]
119+
self.assertIn(got2.dtype, (torch.bool, np.bool_))
120+
self.assertEqual(got2, True)
121+
122+
self.assertIn(got.dtype, (torch.bool, np.bool_))
123+
self.assertEqual(got, True)
124+
125+
def test_constant_bool_array(self):
126+
node = oh.make_node(
127+
"Constant",
128+
[],
129+
["cbool"],
130+
value=from_array_extended(np.array([True], dtype=np.bool_)),
131+
)
132+
ref = ExtendedReferenceEvaluator(node)
133+
got = ref.run(None, {})[0]
134+
self.assertEqual(got.dtype, np.bool_)
135+
self.assertEqual(got[0], True)
136+
ref = OnnxruntimeEvaluator(node)
137+
got = ref.run(None, {})[0]
138+
self.assertEqual(len(ref._cache), 1)
139+
values = list(ref._cache.values())
140+
_, sess = values[0]
141+
got2 = sess.run(None, {})[0]
142+
self.assertIn(got2.dtype, (torch.bool, np.bool_))
143+
self.assertEqual(got2[0], True)
144+
145+
self.assertIn(got.dtype, (torch.bool, np.bool_))
146+
self.assertEqual(got[0], True)
147+
99148

100149
if __name__ == "__main__":
101150
unittest.main(verbosity=2)

onnx_diagnostic/helpers/ort_session.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ def __init__(
154154
)
155155

156156
self._torch_from_dlpack = _from_dlpack
157+
self.sess_bool_outputs = [i.type == "tensor(bool)" for i in sess.get_outputs()]
157158

158159
def run(
159160
self,
@@ -166,7 +167,19 @@ def run(
166167
ort_outputs = self.sess._sess.run_with_ort_values(
167168
feeds, output_names or self.output_names, self.run_options
168169
)
169-
return ort_outputs
170+
return self._post_process_inplace(ort_outputs)
171+
172+
def _post_process_inplace(self, outputs):
173+
for i in range(len(outputs)):
174+
o = outputs[i]
175+
if self.sess_bool_outputs[i]:
176+
if isinstance(o, np.ndarray):
177+
if o.dtype != np.bool_:
178+
outputs[i] = o.astype(np.bool_)
179+
else:
180+
if o.dtype != torch.bool:
181+
outputs[i] = o.to(torch.bool)
182+
return outputs
170183

171184

172185
class InferenceSessionForNumpy(_InferenceSession):
@@ -221,7 +234,7 @@ def run(
221234
"""Calls :meth:`onnxruntime.InferenceSession.run`."""
222235
# sess.run does not support blfoat16
223236
# res = self.sess.run(output_names, feeds)
224-
return list(self.run_dlpack(output_names, feeds))
237+
return self._post_process_inplace(list(self.run_dlpack(output_names, feeds)))
225238

226239
def run_dlpack(
227240
self, output_names: Optional[List[str]], feeds: Dict[str, npt.ArrayLike]
@@ -421,7 +434,7 @@ def run( # type: ignore
421434
if self.use_training_api:
422435
inputs = [feeds[i] for i in self.input_names]
423436
return self.run_training_api(*inputs, output_names=output_names)
424-
return self.run_dlpack(output_names, feeds)
437+
return self._post_process_inplace(list(self.run_dlpack(output_names, feeds)))
425438

426439
def run_training_api(
427440
self, *inputs, output_names: Optional[List[str]] = None

onnx_diagnostic/reference/ort_evaluator.py

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
ValueInfoProto,
1111
helper as oh,
1212
load,
13+
save as onnx_save,
14+
shape_inference as shi,
1315
)
1416
from onnx.defs import onnx_opset_version
1517
import onnxruntime
@@ -20,6 +22,7 @@
2022
InferenceSessionForNumpy,
2123
_InferenceSession,
2224
)
25+
from .evaluator import ExtendedReferenceEvaluator
2326

2427
PROTO = (FunctionProto, ModelProto, GraphProto, NodeProto)
2528
Proto = Union[FunctionProto, ModelProto, GraphProto, NodeProto]
@@ -304,7 +307,7 @@ def _make_model_proto(
304307
return onx
305308

306309
@classmethod
307-
def _get_hidden_inputs(self, graph: GraphProto) -> Set[str, Any]:
310+
def _get_hidden_inputs(self, graph: GraphProto) -> Set[str]:
308311
"""
309312
Returns the hidden inputs (inputs coming from an upper context)
310313
used by a subgraph.
@@ -331,18 +334,34 @@ def _get_sess(
331334
onx = node
332335
else:
333336
assert isinstance(node, NodeProto), f"Unexpected type {type(node)} for node"
334-
unique_names = set()
335-
vinputs = []
336-
for i, it in zip(node.input, inputs):
337-
if i == "" or i in unique_names:
338-
continue
339-
unique_names.add(i)
340-
value = oh.make_tensor_value_info(i, dtype_to_tensor_dtype(it.dtype), it.shape)
341-
vinputs.append(value)
337+
if node.op_type == "Constant":
338+
# We force the type to be a boolean.
339+
ref = ExtendedReferenceEvaluator(node)
340+
cst = ref.run(None, {})[0]
341+
vinputs = []
342+
voutputs = [
343+
oh.make_tensor_value_info(
344+
node.output[0], dtype_to_tensor_dtype(cst.dtype), cst.shape
345+
)
346+
]
347+
else:
348+
unique_names = set()
349+
vinputs = []
350+
for i, it in zip(node.input, inputs):
351+
if i == "" or i in unique_names:
352+
continue
353+
unique_names.add(i)
354+
value = oh.make_tensor_value_info(
355+
i, dtype_to_tensor_dtype(it.dtype), it.shape
356+
)
357+
vinputs.append(value)
358+
359+
# no need to run shape inference
360+
voutputs = [oh.make_value_info(o, TypeProto()) for o in node.output]
342361

343-
# no need to run shape inference
344-
voutputs = [oh.make_value_info(o, TypeProto()) for o in node.output]
345362
onx = self._make_model_proto([node], vinputs, voutputs)
363+
# That helps fixing bugs.
364+
onx = shi.infer_shapes(onx)
346365

347366
cls = (
348367
InferenceSessionForNumpy
@@ -356,6 +375,7 @@ def _get_sess(
356375
onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph,
357376
onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument,
358377
) as e:
378+
onnx_save(onx, "_debug_OnnxruntimeEvaluator_last_failure.onnx")
359379
raise RuntimeError(
360380
f"Unable to infer a session with inputs\n{string_type(inputs)}"
361381
f"\ndue to {e}\n{pretty_onnx(onx)}"
@@ -461,8 +481,7 @@ def _run_if(
461481
self._cache[key] = onx, sess = self._get_sess_if(node, name, inputs, results)
462482

463483
assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
464-
input_names = [i.name for i in sess.get_inputs()]
465-
feeds = {name: results[name] for name in input_names}
484+
feeds = {name: results[name] for name in sess.input_names}
466485
outputs = sess.run(None, feeds)
467486
assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
468487
return outputs
@@ -515,8 +534,7 @@ def _run_scan(
515534
self._cache[key] = onx, sess = self._get_sess_scan(node, name, inputs, results)
516535

517536
assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
518-
input_names = [i.name for i in sess.get_inputs()]
519-
feeds = {name: results[name] for name in input_names}
537+
feeds = {name: results[name] for name in sess.input_names}
520538
outputs = sess.run(None, feeds)
521539
assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
522540
return outputs

0 commit comments

Comments
 (0)