Skip to content

Commit bce84d2

Browse files
committed
fix bool inputs
1 parent 40e6698 commit bce84d2

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

_unittests/ut_reference/test_onnxruntime_evaluator.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,34 @@ def test_constant_bool_array(self):
145145
self.assertIn(got.dtype, (torch.bool, np.bool_))
146146
self.assertEqual(got[0], True)
147147

148+
def test_constant_bool_input(self):
149+
node = oh.make_model(
150+
oh.make_graph(
151+
[oh.make_node("Identity", ["bin"], ["bout"])],
152+
"test",
153+
[oh.make_tensor_value_info("bin", onnx.TensorProto.BOOL, [1])],
154+
[oh.make_tensor_value_info("bin", onnx.TensorProto.BOOL, [1])],
155+
),
156+
ir_version=10,
157+
opset_imports=[oh.make_opsetid("", 18)],
158+
)
159+
feeds = dict(bin=np.array([True], dtype=np.bool_))
160+
ref = ExtendedReferenceEvaluator(node)
161+
162+
got = ref.run(None, feeds)[0]
163+
self.assertEqual(got.dtype, np.bool_)
164+
self.assertEqual(got[0], True)
165+
166+
ref = OnnxruntimeEvaluator(node)
167+
got = ref.run(None, feeds)[0]
168+
self.assertEqual(got.dtype, np.bool_)
169+
self.assertEqual(got[0], True)
170+
171+
feeds = dict(bin=torch.tensor([True], dtype=torch.bool))
172+
got = ref.run(None, feeds)[0]
173+
self.assertEqual(got.dtype, torch.bool)
174+
self.assertEqual(got[0], True)
175+
148176

149177
if __name__ == "__main__":
150178
unittest.main(verbosity=2)

onnx_diagnostic/helpers/ort_session.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,14 @@ def run_dlpack(
484484
assert hasattr(v, "__dlpack__"), f"class {type(v)} should be serialized"
485485
if not v.is_contiguous():
486486
v = v.contiguous()
487-
new_feeds[k] = ORTC.OrtValue.from_dlpack(v.__dlpack__(), v.dtype == torch.bool)
487+
if v.dtype == torch.bool:
488+
# It does not work with dlpack
489+
# unless onnxruntime updates the version it is using.
490+
new_feeds[k] = ORTC.OrtValue.ortvalue_from_numpy_with_onnx_type(
491+
v.detach().numpy(), onnx.TensorProto.BOOL
492+
)
493+
else:
494+
new_feeds[k] = ORTC.OrtValue.from_dlpack(v.__dlpack__(), False)
488495
if self.nvtx:
489496
self.torch.cuda.nvtx.range_push("run_with_ort_values")
490497
ort_outputs = self.sess._sess.run_with_ort_values(

0 commit comments

Comments
 (0)