Skip to content

Commit 3371069

Browse files
committed
Add OnnxruntimeEvaluator
1 parent 898ca10 commit 3371069

File tree

9 files changed

+690
-0
lines changed

9 files changed

+690
-0
lines changed

_doc/api/reference/index.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,20 @@ onnx_diagnostic.reference
1313

1414
evaluator
1515
quantized_tensor
16+
ort_evaluator
1617

1718
ExtendedReferenceEvaluator
1819
++++++++++++++++++++++++++
1920

2021
.. autoclass:: onnx_diagnostic.reference.ExtendedReferenceEvaluator
2122
:members:
2223

24+
OnnxruntimeEvaluator
25+
++++++++++++++++++++
26+
27+
.. autoclass:: onnx_diagnostic.reference.OnnxruntimeEvaluator
28+
:members:
29+
2330
Other functions
2431
+++++++++++++++
2532

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
2+
onnx_diagnostic.reference.ort_evaluator
3+
=======================================
4+
5+
.. automodule:: onnx_diagnostic.reference.ort_evaluator
6+
:members:
7+
:no-undoc-members:
8+
:exclude-members: OnnxruntimeEvaluator
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
"""
2+
.. _l-plot-failing-onnxruntime-evaluator:
3+
4+
Running OnnxruntimeEvaluator on a failing model
5+
===============================================
6+
7+
Example :ref:`l-plot-failing-reference-evaluator` demonstrated
8+
how to run a python runtime on a model but it may be very slow sometimes
9+
and it could show some discrepancies if the only provider is not CPU.
10+
Let's use :class:`OnnxruntimeEvaluator <onnx_diagnostic.reference.OnnxruntimeEvaluator>`.
11+
It splits the model into nodes and runs them independently until it succeeds
12+
or fails. This class converts every node into a model based on the types
13+
discovered during the execution. It relies on :class:`InferenceSessionForTorch
14+
<onnx_diagnostic.ort_session.InferenceSessionForTorch>` or
15+
:class:`InferenceSessionForNumpy
16+
<onnx_diagnostic.ort_session.InferenceSessionForNumpy>`
17+
for the execution. This example uses torch tensor and
18+
bfloat16.
19+
20+
A failing model
21+
+++++++++++++++
22+
23+
The issue here is an operator ``Cast`` trying to convert a result
24+
into a non-existing type.
25+
"""
26+
27+
import onnx
28+
import onnx.helper as oh
29+
import torch
30+
import onnxruntime
31+
from onnx_diagnostic.helpers import from_array_extended
32+
from onnx_diagnostic.reference import OnnxruntimeEvaluator
33+
34+
TBFLOAT16 = onnx.TensorProto.BFLOAT16
35+
36+
model = oh.make_model(
37+
oh.make_graph(
38+
[
39+
oh.make_node("Mul", ["X", "Y"], ["xy"], name="n0"),
40+
oh.make_node("Sigmoid", ["xy"], ["sy"], name="n1"),
41+
oh.make_node("Add", ["sy", "one"], ["C"], name="n2"),
42+
oh.make_node("Cast", ["C"], ["X999"], to=999, name="failing"),
43+
oh.make_node("CastLike", ["X999", "Y"], ["Z"], name="n4"),
44+
],
45+
"nd",
46+
[
47+
oh.make_tensor_value_info("X", TBFLOAT16, ["a", "b", "c"]),
48+
oh.make_tensor_value_info("Y", TBFLOAT16, ["a", "b", "c"]),
49+
],
50+
[oh.make_tensor_value_info("Z", TBFLOAT16, ["a", "b", "c"])],
51+
[from_array_extended(torch.tensor([1], dtype=torch.bfloat16), name="one")],
52+
),
53+
opset_imports=[oh.make_opsetid("", 18)],
54+
ir_version=9,
55+
)
56+
57+
# %%
58+
# We check it is failing.
59+
60+
try:
61+
onnxruntime.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])
62+
except onnxruntime.capi.onnxruntime_pybind11_state.Fail as e:
63+
print(e)
64+
65+
66+
# %%
67+
# OnnxruntimeEvaluator
68+
# ++++++++++++++++++++++++++
69+
#
70+
# This class runs the model node by node with :epkg:`onnxruntime`
71+
# and exposes the same ``run`` API as :class:`onnx.reference.ReferenceEvaluator`.
72+
# `verbose=10` tells the class to print as much as possible,
73+
# `verbose=0` prints nothing. Intermediate values for more or less verbosity.
74+
75+
ref = OnnxruntimeEvaluator(model, verbose=10)
76+
feeds = dict(
77+
X=torch.rand((3, 4), dtype=torch.blofat16), Y=torch.rand((3, 4), dtype=torch.blofat16)
78+
)
79+
try:
80+
ref.run(None, feeds)
81+
except Exception as e:
82+
print("ERROR", type(e), e)
83+
84+
# %%
85+
# We can see it runs until it reaches `Cast` and stops.
86+
# The error message is not always obvious to interpret.
87+
# It gets improved from time to time.
88+
# This runtime is useful when it fails for a numerical reason.
89+
# It is possible to insert prints in the python code to print
90+
# more information or debug if needed.

_doc/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ Source are `sdpython/onnx-diagnostic
5151
* :ref:`l-plot-sxport-with-dynamio-shapes-auto`
5252
* :ref:`l-plot-tiny-llm-export`
5353
* :ref:`l-plot-failing-reference-evaluator`
54+
* :ref:`l-plot-failing-onnxruntime-evaluator`
5455
* :ref:`l-plot-failing-model-extract`
5556

5657
**Some Useful Tools**
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
import unittest
2+
from typing import Optional
3+
import numpy as np
4+
from onnx import ModelProto, TensorProto
5+
from onnx.checker import check_model
6+
import onnx.helper as oh
7+
import onnx.numpy_helper as onh
8+
import torch
9+
from onnx_diagnostic.ext_test_case import (
10+
ExtTestCase,
11+
hide_stdout,
12+
ignore_warnings,
13+
requires_cuda,
14+
)
15+
from onnx_diagnostic.reference import ExtendedReferenceEvaluator, OnnxruntimeEvaluator
16+
17+
TFLOAT = TensorProto.FLOAT
18+
19+
20+
class TestOnnxruntimeEvaluatoruator(ExtTestCase):
21+
def _range(self, *shape, bias: Optional[float] = None):
22+
n = np.prod(shape)
23+
x = np.arange(n).astype(np.float32) / n
24+
if bias:
25+
x = x + bias
26+
return x.reshape(tuple(shape)).astype(np.float32)
27+
28+
def _get_model(self) -> ModelProto:
29+
model = oh.make_model(
30+
oh.make_graph(
31+
[
32+
oh.make_node("Unsqueeze", ["X", "zero"], ["xu1"]),
33+
oh.make_node("Unsqueeze", ["xu1", "un"], ["xu2"]),
34+
oh.make_node("Reshape", ["xu2", "shape1"], ["xm1"]),
35+
oh.make_node("Reshape", ["Y", "shape2"], ["xm2c"]),
36+
oh.make_node("Cast", ["xm2c"], ["xm2"], to=1),
37+
oh.make_node("MatMul", ["xm1", "xm2"], ["xm"]),
38+
oh.make_node("Reshape", ["xm", "shape3"], ["Z"]),
39+
],
40+
"dummy",
41+
[
42+
oh.make_tensor_value_info("X", TFLOAT, [32, 128]),
43+
oh.make_tensor_value_info("Y", TFLOAT, [3, 5, 128, 64]),
44+
],
45+
[oh.make_tensor_value_info("Z", TFLOAT, [3, 5, 32, 64])],
46+
[
47+
onh.from_array(np.array([0], dtype=np.int64), name="zero"),
48+
onh.from_array(np.array([1], dtype=np.int64), name="un"),
49+
onh.from_array(np.array([1, 32, 128], dtype=np.int64), name="shape1"),
50+
onh.from_array(np.array([15, 128, 64], dtype=np.int64), name="shape2"),
51+
onh.from_array(np.array([3, 5, 32, 64], dtype=np.int64), name="shape3"),
52+
],
53+
),
54+
ir_version=9,
55+
opset_imports=[oh.make_opsetid("", 18)],
56+
)
57+
check_model(model)
58+
return model
59+
60+
@ignore_warnings(DeprecationWarning)
61+
def test_ort_eval(self):
62+
model = self._get_model()
63+
64+
feeds = {"X": self._range(32, 128), "Y": self._range(3, 5, 128, 64)}
65+
ref = ExtendedReferenceEvaluator(model, verbose=10)
66+
expected, out, _ = self.capture(lambda: ref.run(None, feeds)[0])
67+
self.assertIn("Reshape(xm, shape3) -> Z", out)
68+
69+
ort_eval = OnnxruntimeEvaluator(model, verbose=10, opsets=20)
70+
got, out, _ = self.capture(lambda: ort_eval.run(None, feeds)[0])
71+
self.assertEqualArray(expected, got, atol=1e-4)
72+
self.assertIn("Reshape(xm, shape3) -> Z", out)
73+
74+
@ignore_warnings(DeprecationWarning)
75+
@requires_cuda()
76+
@hide_stdout()
77+
def test_ort_eval_cuda(self):
78+
model = self._get_model()
79+
80+
feeds = {"X": self._range(32, 128), "Y": self._range(3, 5, 128, 64)}
81+
ref = ExtendedReferenceEvaluator(model, verbose=10)
82+
expected = ref.run(None, feeds)[0]
83+
84+
ort_eval = OnnxruntimeEvaluator(model, verbose=10, opsets=20, providers="cuda")
85+
got = ort_eval.run(None, feeds)[0]
86+
self.assertEqualArray(expected, got, atol=1e-1)
87+
88+
@ignore_warnings(DeprecationWarning)
89+
@hide_stdout()
90+
def test_ort_eval_node_proto(self):
91+
model = self._get_model()
92+
93+
feeds = {"X": self._range(32, 128), "zero": np.array([0], dtype=np.int64)}
94+
ref = ExtendedReferenceEvaluator(model.graph.node[0], verbose=10)
95+
expected = ref.run(None, feeds)
96+
97+
ort_eval = OnnxruntimeEvaluator(model.graph.node[0], verbose=10, opsets=20)
98+
got = ort_eval.run(None, feeds)
99+
self.assertEqualArrayAny(expected, got, atol=1e-4)
100+
self.assertIsInstance(expected[0], np.ndarray)
101+
self.assertIsInstance(got[0], np.ndarray)
102+
103+
@ignore_warnings(DeprecationWarning)
104+
@hide_stdout()
105+
def test_ort_eval_node_proto_torch(self):
106+
model = self._get_model()
107+
108+
feeds_np = {"X": self._range(32, 128), "zero": np.array([0], dtype=np.int64)}
109+
feeds = {k: torch.from_numpy(v) for k, v in feeds_np.items()}
110+
ref = ExtendedReferenceEvaluator(model.graph.node[0], verbose=10)
111+
expected = ref.run(None, feeds_np)
112+
113+
ort_eval = OnnxruntimeEvaluator(model.graph.node[0], verbose=10, opsets=20)
114+
got = ort_eval.run(None, feeds)
115+
self.assertIsInstance(got[0], torch.Tensor)
116+
self.assertEqualArray(expected[0], got[0], atol=1e-4)
117+
118+
@hide_stdout()
119+
def test_local_function(self):
120+
new_domain = "custom"
121+
122+
linear_regression = oh.make_function(
123+
new_domain,
124+
"LinearRegression",
125+
["x", "a", "b"],
126+
["y"],
127+
[
128+
oh.make_node("MatMul", ["x", "a"], ["xa"]),
129+
oh.make_node("Add", ["xa", "b"], ["y"]),
130+
],
131+
[oh.make_opsetid("", 14)],
132+
[],
133+
)
134+
135+
graph = oh.make_graph(
136+
[
137+
oh.make_node("LinearRegression", ["X", "A", "B"], ["Y1"], domain=new_domain),
138+
oh.make_node("Abs", ["Y1"], ["Y"]),
139+
],
140+
"example",
141+
[
142+
oh.make_tensor_value_info("X", TFLOAT, [None, None]),
143+
oh.make_tensor_value_info("A", TFLOAT, [None, None]),
144+
oh.make_tensor_value_info("B", TFLOAT, [None, None]),
145+
],
146+
[oh.make_tensor_value_info("Y", TFLOAT, None)],
147+
)
148+
149+
onnx_model = oh.make_model(
150+
graph,
151+
opset_imports=[oh.make_opsetid("", 14), oh.make_opsetid(new_domain, 1)],
152+
functions=[linear_regression],
153+
ir_version=10,
154+
)
155+
feeds = {
156+
"X": np.random.randn(3, 3).astype(np.float32),
157+
"A": np.random.randn(3, 3).astype(np.float32),
158+
"B": np.random.randn(3, 3).astype(np.float32),
159+
}
160+
ref = ExtendedReferenceEvaluator(onnx_model)
161+
ort_eval = OnnxruntimeEvaluator(onnx_model, verbose=10, opsets=20)
162+
expected = ref.run(None, feeds)
163+
got = ort_eval.run(None, feeds)
164+
self.assertEqualArray(expected[0], got[0])
165+
166+
167+
if __name__ == "__main__":
168+
unittest.main(verbosity=2)

_unittests/ut_xrun_doc/test_helpers.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,27 @@ def test_rename_dynamic_expression(self):
407407
text = rename_dynamic_expression("a * 10 - a", {"a": "x"})
408408
self.assertEqual(text, "x * 10 - x")
409409

410+
def test_from_tensor(self):
411+
for dt in {
412+
torch.float32,
413+
torch.float64,
414+
torch.bfloat16,
415+
torch.float16,
416+
torch.int32,
417+
torch.int64,
418+
torch.int8,
419+
torch.int16,
420+
torch.uint8,
421+
torch.uint16,
422+
torch.uint32,
423+
torch.uint64,
424+
}:
425+
t = torch.rand((4, 3)).to(dt)  # rand only yields floats; cast to the dtype under test
426+
proto = from_array_extended(t)
427+
self.assertIsInstance(proto, onnx.TensorProto)
428+
convert_endian(proto)
429+
dtype_to_tensor_dtype(dt)
430+
410431

411432
if __name__ == "__main__":
412433
unittest.main(verbosity=2)

onnx_diagnostic/helpers.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -717,6 +717,13 @@ def from_array_extended(tensor: npt.ArrayLike, name: Optional[str] = None) -> Te
717717
:param name: name
718718
:return: TensorProto
719719
"""
720+
try:
721+
import torch
722+
except ImportError:
723+
torch = None
724+
if torch is not None and isinstance(tensor, torch.Tensor):
725+
raise NotImplementedError()
726+
720727
from onnx.reference.ops.op_cast import (
721728
bfloat16,
722729
float8e4m3fn,
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
from .evaluator import ExtendedReferenceEvaluator
2+
from .ort_evaluator import OnnxruntimeEvaluator

0 commit comments

Comments
 (0)