Skip to content

Commit ccb38db

Browse files
committed
adds missing method
1 parent 5262b20 commit ccb38db

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

_unittests/ut_reference/test_ort_evaluator.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
onnx_dtype_to_np_dtype,
2121
)
2222
from onnx_diagnostic.reference import ExtendedReferenceEvaluator, OnnxruntimeEvaluator
23+
from onnx_diagnostic.helpers.ort_session import _InferenceSession
2324

2425
TFLOAT = TensorProto.FLOAT
2526

@@ -78,6 +79,18 @@ def test_ort_eval(self):
7879
self.assertEqualArray(expected, got, atol=1e-4)
7980
self.assertIn("Reshape(xm, shape3) -> Z", out)
8081

82+
@ignore_warnings(DeprecationWarning)
83+
def test__inference(self):
84+
model = self._get_model()
85+
86+
feeds = {"X": self._range(32, 128), "Y": self._range(3, 5, 128, 64)}
87+
ref = ExtendedReferenceEvaluator(model)
88+
expected = ref.run(None, feeds)[0]
89+
90+
ort_eval = _InferenceSession(model)
91+
got = ort_eval.run(None, feeds)[0]
92+
self.assertEqualArray(expected, got, atol=1e-4)
93+
8194
@ignore_warnings(DeprecationWarning)
8295
@requires_cuda()
8396
@hide_stdout()

onnx_diagnostic/helpers/ort_session.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,19 @@ def __init__(
155155

156156
self._torch_from_dlpack = _from_dlpack
157157

158+
def run(
159+
self,
160+
output_names: Optional[List[str]],
161+
feeds: Union[Dict[str, np.ndarray], Dict[str, "OrtValue"]], # noqa: F821
162+
) -> Union[List[np.ndarray], List["OrtValue"]]: # noqa: F821
163+
"""Calls :meth:`onnxruntime.InferenceSession.run`."""
164+
if any(isinstance(t, np.ndarray) for t in feeds.values()):
165+
return self.sess.run(output_names, feeds)
166+
ort_outputs = self.sess._sess.run_with_ort_values(
167+
feeds, output_names or self.output_names, self.run_options
168+
)
169+
return ort_outputs
170+
158171

159172
class InferenceSessionForNumpy(_InferenceSession):
160173
"""

0 commit comments

Comments
 (0)