
Commit 525ae26

Add OnnxruntimeEvaluator (#9)
* Add OnnxruntimeEvaluator
* ci
* mypy
* fix with torch
* fix bfloat16
* doc
* mypy
* onnx
* backend
* fix a few things
* back
* fix a few things
* disable some test again
* disable some test again
1 parent 898ca10 commit 525ae26

19 files changed: +1415 −60 lines

.github/workflows/ci.yml

Lines changed: 9 additions & 2 deletions
@@ -56,12 +56,19 @@ jobs:
         run: |
           pip install pytest
           export PYTHONPATH=.
-          UNITTEST_GOING=1 pytest --durations=10 _unittests --ignore _unittests/ut_reference/test_backend_extended_reference_evaluator.py
+          UNITTEST_GOING=1 pytest --durations=10 _unittests --ignore _unittests/ut_reference/test_backend_extended_reference_evaluator.py --ignore _unittests/ut_reference/test_backend_onnxruntime_evaluator.py
           export PYTHONPATH=

-      - name: run backend tests
+      - name: run backend tests python
         run: |
           pip install pytest
           export PYTHONPATH=.
           UNITTEST_GOING=1 pytest --durations=10 _unittests/ut_reference/test_backend_extended_reference_evaluator.py
           export PYTHONPATH=
+
+      - name: run backend tests onnxruntime
+        run: |
+          pip install pytest
+          export PYTHONPATH=.
+          UNITTEST_GOING=1 pytest --durations=10 _unittests/ut_reference/test_backend_onnxruntime_evaluator.py --maxfail=15
+          export PYTHONPATH=

.github/workflows/documentation.yml

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ jobs:
           pip install pytest
           pip install pytest-cov
           export PYTHONPATH=.
-          UNITTEST_GOING=1 pytest --cov=./onnx_diagnostic/ --cov-report=xml --durations=10 _unittests --ignore _unittests/ut_reference/test_backend_extended_reference_evaluator.py
+          UNITTEST_GOING=1 pytest --cov=./onnx_diagnostic/ --cov-report=xml --durations=10 _unittests --ignore _unittests/ut_reference/test_backend_extended_reference_evaluator.py --ignore _unittests/ut_reference/test_backend_onnxruntime_evaluator.py
           export PYTHONPATH=

       - name: Upload coverage reports to Codecov

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ Change Logs
 0.2.0
 +++++

+* :pr:`9`: adds ``OnnxruntimeEvaluator``
 * :pr:`8`: adds ``ExtendedReferenceEvaluator``
 * :pr:`7`: improves function ``investigate_onnxruntime_issue``

README.rst

Lines changed: 0 additions & 4 deletions
@@ -11,10 +11,6 @@ onnx-diagnostic: investigate onnx models
 .. image:: https://badge.fury.io/py/onnx-diagnostic.svg
     :target: http://badge.fury.io/py/onnx-diagnostic

-.. image:: http://img.shields.io/github/issues/sdpython/onnx-diagnostic.png
-    :alt: GitHub Issues
-    :target: https://github.com/sdpython/onnx-diagnostic/issues
-
 .. image:: https://img.shields.io/badge/license-MIT-blue.svg
     :alt: MIT License
     :target: https://opensource.org/license/MIT/

_doc/api/reference/index.rst

Lines changed: 7 additions & 0 deletions
@@ -13,13 +13,20 @@ onnx_diagnostic.reference

     evaluator
     quantized_tensor
+    ort_evaluator

 ExtendedReferenceEvaluator
 ++++++++++++++++++++++++++

 .. autoclass:: onnx_diagnostic.reference.ExtendedReferenceEvaluator
     :members:

+OnnxruntimeEvaluator
+++++++++++++++++++++
+
+.. autoclass:: onnx_diagnostic.reference.OnnxruntimeEvaluator
+    :members:
+
 Other functions
 +++++++++++++++


_doc/api/reference/ort_evaluator.rst

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+
+onnx_diagnostic.reference.ort_evaluator
+=======================================
+
+.. automodule:: onnx_diagnostic.reference.ort_evaluator
+    :members:
+    :no-undoc-members:
+    :exclude-members: OnnxruntimeEvaluator
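The two documentation entries above expose the new class in the API reference. For readers scanning the commit, here is a minimal, self-contained sketch of what calling it looks like. The constructor and run() calls mirror the example added later in this commit; the tiny Add model and the float32 inputs are illustration choices, and the result is simply printed without assuming a particular return type.

import onnx.helper as oh
import torch
from onnx import TensorProto
from onnx_diagnostic.reference import OnnxruntimeEvaluator

# a tiny valid model: Z = X + Y in float32
model = oh.make_model(
    oh.make_graph(
        [oh.make_node("Add", ["X", "Y"], ["Z"])],
        "demo",
        [
            oh.make_tensor_value_info("X", TensorProto.FLOAT, ["a", "b"]),
            oh.make_tensor_value_info("Y", TensorProto.FLOAT, ["a", "b"]),
        ],
        [oh.make_tensor_value_info("Z", TensorProto.FLOAT, ["a", "b"])],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
    ir_version=9,
)

# verbose=0 prints nothing, verbose=10 prints as much as possible
ref = OnnxruntimeEvaluator(model, verbose=0)
feeds = dict(X=torch.rand((3, 4)), Y=torch.rand((3, 4)))
print(ref.run(None, feeds))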

_doc/conf.py

Lines changed: 5 additions & 2 deletions
@@ -104,11 +104,12 @@
     ("py:class", "False"),
     ("py:class", "True"),
     ("py:class", "Argument"),
-    ("py:class", "onnxscript.ir.Tuple"),
-    ("py:class", "pipeline.Pipeline"),
     ("py:class", "default=sklearn.utils.metadata_routing.UNCHANGED"),
     ("py:class", "ModelProto"),
     ("py:class", "Module"),
+    ("py:class", "np.ndarray"),
+    ("py:class", "onnxscript.ir.Tuple"),
+    ("py:class", "pipeline.Pipeline"),
     ("py:class", "torch.fx.passes.operator_support.OperatorSupport"),
     ("py:class", "torch.fx.proxy.TracerBase"),
     ("py:class", "torch.utils._pytree.Context"),
@@ -177,6 +178,7 @@
     "GraphModule": "https://pytorch.org/docs/stable/fx.html#torch.fx.GraphModule",
     "HuggingFace": "https://huggingface.co/docs/hub/en/index",
     "Linux": "https://www.linux.org/",
+    "ml_dtypes": "https://github.com/jax-ml/ml_dtypes",
     "monai": "https://monai.io/",
     "numpy": "https://numpy.org/",
     "onnx": "https://onnx.ai/onnx/",
@@ -186,6 +188,7 @@
     "onnxrt backend": "https://pytorch.org/docs/stable/onnx_dynamo_onnxruntime_backend.html",
     "onnxruntime": "https://onnxruntime.ai/",
     "onnxruntime-training": "https://onnxruntime.ai/docs/get-started/training-on-device.html",
+    "onnxruntime kernels": "https://onnxruntime.ai/docs/reference/operators/OperatorKernels.html",
     "onnx-array-api": "https://sdpython.github.io/doc/onnx-array-api/dev/",
     "onnx-diagnostic": "https://sdpython.github.io/doc/onnx-diagnostic/dev/",
     "onnx-extended": "https://sdpython.github.io/doc/onnx-extended/dev/",
Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
+"""
+.. _l-plot-failing-onnxruntime-evaluator:
+
+Running OnnxruntimeEvaluator on a failing model
+===============================================
+
+Example :ref:`l-plot-failing-reference-evaluator` demonstrated
+how to run a python runtime on a model, but it may be very slow
+and it may show discrepancies when the intended provider is not CPU.
+Let's use :class:`OnnxruntimeEvaluator <onnx_diagnostic.reference.OnnxruntimeEvaluator>`.
+It splits the model into nodes and runs them independently until it succeeds
+or fails. This class converts every node into a model based on the types
+discovered during the execution. It relies on :class:`InferenceSessionForTorch
+<onnx_diagnostic.ort_session.InferenceSessionForTorch>` or
+:class:`InferenceSessionForNumpy
+<onnx_diagnostic.ort_session.InferenceSessionForNumpy>`
+for the execution. This example uses torch tensors and
+bfloat16.
+
+A failing model
++++++++++++++++
+
+The issue here is an operator ``Cast`` trying to convert a result
+into a non-existing type.
+"""
+
+import onnx
+import onnx.helper as oh
+import torch
+import onnxruntime
+from onnx_diagnostic.ext_test_case import has_cuda
+from onnx_diagnostic.helpers import from_array_extended
+from onnx_diagnostic.reference import OnnxruntimeEvaluator
+
+TBFLOAT16 = onnx.TensorProto.BFLOAT16
+
+model = oh.make_model(
+    oh.make_graph(
+        [
+            oh.make_node("Mul", ["X", "Y"], ["xy"], name="n0"),
+            oh.make_node("Sigmoid", ["xy"], ["sy"], name="n1"),
+            oh.make_node("Add", ["sy", "one"], ["C"], name="n2"),
+            oh.make_node("Cast", ["C"], ["X999"], to=999, name="failing"),
+            oh.make_node("CastLike", ["X999", "Y"], ["Z"], name="n4"),
+        ],
+        "nd",
+        [
+            oh.make_tensor_value_info("X", TBFLOAT16, ["a", "b", "c"]),
+            oh.make_tensor_value_info("Y", TBFLOAT16, ["a", "b", "c"]),
+        ],
+        [oh.make_tensor_value_info("Z", TBFLOAT16, ["a", "b", "c"])],
+        [from_array_extended(torch.tensor([1], dtype=torch.bfloat16), name="one")],
+    ),
+    opset_imports=[oh.make_opsetid("", 18)],
+    ir_version=9,
+)
+
+# %%
+# We check it is failing.
+
+try:
+    onnxruntime.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])
+except onnxruntime.capi.onnxruntime_pybind11_state.Fail as e:
+    print(e)
+
+
+# %%
+# OnnxruntimeEvaluator
+# ++++++++++++++++++++++++++
+#
+# This class mimics :class:`onnx.reference.ReferenceEvaluator` but relies
+# on :epkg:`onnxruntime` to run every node, including operators outside
+# the standard that are defined by :epkg:`onnxruntime`.
+# `verbose=10` tells the class to print as much as possible,
+# `verbose=0` prints nothing. Intermediate values give more or less verbosity.
+
+ref = OnnxruntimeEvaluator(model, verbose=10)
+feeds = dict(
+    X=torch.rand((3, 4), dtype=torch.bfloat16), Y=torch.rand((3, 4), dtype=torch.bfloat16)
+)
+try:
+    ref.run(None, feeds)
+except Exception as e:
+    print("ERROR", type(e), e)
+
+
+# %%
+# :epkg:`onnxruntime` may not support bfloat16 on CPU.
+# See :epkg:`onnxruntime kernels`.
+
+if has_cuda():
+    ref = OnnxruntimeEvaluator(model, providers="cuda", verbose=10)
+    feeds = dict(
+        X=torch.rand((3, 4), dtype=torch.bfloat16), Y=torch.rand((3, 4), dtype=torch.bfloat16)
+    )
+    try:
+        ref.run(None, feeds)
+    except Exception as e:
+        print("ERROR", type(e), e)
+
+# %%
+# We can see it run until it reaches ``Cast`` and stops.
+# The error message is not always obvious to interpret.
+# It gets improved from time to time.
+# This runtime is useful when a model fails for a numerical reason.
+# It is possible to insert prints in the python code to print
+# more information or to debug if needed.
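The docstring of the example above describes the mechanism: the evaluator splits the model into nodes, builds a model per node from the types discovered during execution, and runs each one with onnxruntime. The sketch below is a deliberately simplified illustration of that idea, not the library's implementation: it accepts plain numpy inputs, runs on CPU only, ignores subgraphs and custom domains, and leaves each one-node model's outputs untyped on the assumption that onnxruntime's type inference fills them in.

from typing import Dict, List
import numpy as np
import onnx
import onnx.helper as oh
from onnx.numpy_helper import to_array
import onnxruntime


def run_node_by_node(model: onnx.ModelProto, feeds: Dict[str, np.ndarray]) -> List[np.ndarray]:
    """Runs ``model`` one node at a time, each node wrapped into its own model."""
    # values known so far: graph inputs and initializers
    results = dict(feeds)
    for init in model.graph.initializer:
        results[init.name] = to_array(init)

    for node in model.graph.node:
        # type the node inputs from the values computed so far
        inputs = [
            oh.make_tensor_value_info(
                name,
                oh.np_dtype_to_tensor_dtype(results[name].dtype),
                results[name].shape,
            )
            for name in node.input
            if name  # skip optional (empty) inputs
        ]
        # outputs are left untyped, onnxruntime infers their type
        outputs = [oh.make_empty_tensor_value_info(name) for name in node.output]
        one_node = oh.make_model(
            oh.make_graph([node], f"one_node_{node.op_type}", inputs, outputs),
            opset_imports=model.opset_import,
            ir_version=model.ir_version,
        )
        # if this call raises, the failing node is identified
        sess = onnxruntime.InferenceSession(
            one_node.SerializeToString(), providers=["CPUExecutionProvider"]
        )
        computed = sess.run(None, {i.name: results[i.name] for i in inputs})
        results.update(zip(node.output, computed))

    return [results[o.name] for o in model.graph.output]

On the failing model above, a loop like this would execute Mul, Sigmoid and Add, then stop at the one-node model built for the invalid Cast, which is exactly the kind of localization the evaluator provides.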

_doc/index.rst

Lines changed: 1 addition & 8 deletions
@@ -8,18 +8,10 @@ onnx-diagnostic: investigate onnx models
 .. image:: https://badge.fury.io/py/onnx-diagnostic.svg
     :target: http://badge.fury.io/py/onnx-diagnostic

-.. image:: http://img.shields.io/github/issues/sdpython/onnx-diagnostic.png
-    :alt: GitHub Issues
-    :target: https://github.com/sdpython/onnx-diagnostic/issues
-
 .. image:: https://img.shields.io/badge/license-MIT-blue.svg
     :alt: MIT License
     :target: https://opensource.org/license/MIT/

-.. image:: https://img.shields.io/github/repo-size/sdpython/onnx-diagnostic
-    :target: https://github.com/sdpython/onnx-diagnostic/
-    :alt: size
-
 .. image:: https://img.shields.io/badge/code%20style-black-000000.svg
     :target: https://github.com/psf/black

@@ -51,6 +43,7 @@ Source are `sdpython/onnx-diagnostic
 * :ref:`l-plot-sxport-with-dynamio-shapes-auto`
 * :ref:`l-plot-tiny-llm-export`
 * :ref:`l-plot-failing-reference-evaluator`
+* :ref:`l-plot-failing-onnxruntime-evaluator`
 * :ref:`l-plot-failing-model-extract`

 **Some Usefuls Tools**

_unittests/ut_reference/test_array_tensor.py

Lines changed: 1 addition & 2 deletions
@@ -2,9 +2,8 @@
 import numpy as np
 from onnx import TensorProto
 from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info
-from onnx.reference.op_run import to_array_extended
 from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings
-from onnx_diagnostic.helpers import from_array_extended
+from onnx_diagnostic.helpers import from_array_extended, to_array_extended
 from onnx_diagnostic.reference import ExtendedReferenceEvaluator

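The test now imports to_array_extended from onnx_diagnostic.helpers instead of onnx.reference.op_run. Below is a hedged sketch of how the two helpers might be used together, assuming the relocated to_array_extended mirrors the function it replaces (TensorProto in, array out); the from_array_extended call is taken from the example added by this commit.

import torch
from onnx_diagnostic.helpers import from_array_extended, to_array_extended

# from_array_extended(tensor, name=...) appears in the example added by this
# commit and returns a TensorProto usable as an initializer.
proto = from_array_extended(torch.tensor([1.0, 2.0], dtype=torch.bfloat16), name="one")
# assumption: to_array_extended converts the TensorProto back to an array,
# with bfloat16 likely surfacing through ml_dtypes
array = to_array_extended(proto)
print(proto.data_type, array.dtype)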
