Skip to content

Commit 43b14f0

Browse files
committed
mypy
1 parent e521470 commit 43b14f0

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

_unittests/ut_reference/test_onnxruntime_evaluator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def test_constant_bool(self):
110110
got = ref.run(None, {})[0]
111111
self.assertEqual(got.dtype, np.bool_)
112112
self.assertEqual(got, True)
113-
ref = OnnxruntimeEvaluator(node)
113+
ref = OnnxruntimeEvaluator(node, opsets=21)
114114
got = ref.run(None, {})[0]
115115
self.assertEqual(len(ref._cache), 1)
116116
values = list(ref._cache.values())
@@ -133,7 +133,7 @@ def test_constant_bool_array(self):
133133
got = ref.run(None, {})[0]
134134
self.assertEqual(got.dtype, np.bool_)
135135
self.assertEqual(got[0], True)
136-
ref = OnnxruntimeEvaluator(node)
136+
ref = OnnxruntimeEvaluator(node, opsets=21)
137137
got = ref.run(None, {})[0]
138138
self.assertEqual(len(ref._cache), 1)
139139
values = list(ref._cache.values())
@@ -163,7 +163,7 @@ def test_constant_bool_input(self):
163163
self.assertEqual(got.dtype, np.bool_)
164164
self.assertEqual(got[0], True)
165165

166-
ref = OnnxruntimeEvaluator(node)
166+
ref = OnnxruntimeEvaluator(node, opsets=21)
167167
got = ref.run(None, feeds)[0]
168168
self.assertEqual(got.dtype, np.bool_)
169169
self.assertEqual(got[0], True)

onnx_diagnostic/reference/ort_evaluator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ def _get_sess(
338338
# We force the type to be a boolean.
339339
ref = ExtendedReferenceEvaluator(node)
340340
cst = ref.run(None, {})[0]
341-
vinputs = []
341+
vinputs: List[ValueInfoProto] = []
342342
voutputs = [
343343
oh.make_tensor_value_info(
344344
node.output[0], dtype_to_tensor_dtype(cst.dtype), cst.shape

0 commit comments

Comments
 (0)