Skip to content

Commit 7cca20c

Browse files
committed
ci
1 parent 3371069 commit 7cca20c

File tree

5 files changed

+268
-8
lines changed

5 files changed

+268
-8
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,12 @@ jobs:
5656
run: |
5757
pip install pytest
5858
export PYTHONPATH=.
59-
UNITTEST_GOING=1 pytest --durations=10 _unittests --ignore _unittests/ut_reference/test_backend_extended_reference_evaluator.py
59+
UNITTEST_GOING=1 pytest --durations=10 _unittests --ignore _unittests/ut_reference/test_backend_extended_reference_evaluator.py --ignore _unittests/ut_reference/test_backend_onnxruntime_evaluator.py
6060
export PYTHONPATH=
6161
6262
- name: run backend tests
6363
run: |
6464
pip install pytest
6565
export PYTHONPATH=.
66-
UNITTEST_GOING=1 pytest --durations=10 _unittests/ut_reference/test_backend_extended_reference_evaluator.py
66+
UNITTEST_GOING=1 pytest --durations=10 _unittests/ut_reference/test_backend_extended_reference_evaluator.py _unittests/ut_reference/test_backend_onnxruntime_evaluator.py
6767
export PYTHONPATH=

.github/workflows/documentation.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ jobs:
5959
pip install pytest
6060
pip install pytest-cov
6161
export PYTHONPATH=.
62-
UNITTEST_GOING=1 pytest --cov=./onnx_diagnostic/ --cov-report=xml --durations=10 _unittests --ignore _unittests/ut_reference/test_backend_extended_reference_evaluator.py
62+
UNITTEST_GOING=1 pytest --cov=./onnx_diagnostic/ --cov-report=xml --durations=10 _unittests --ignore _unittests/ut_reference/test_backend_extended_reference_evaluator.py --ignore _unittests/ut_reference/test_backend_onnxruntime_evaluator.py
6363
export PYTHONPATH=
6464
6565
- name: Upload coverage reports to Codecov

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.2.0
55
+++++
66

7+
* :pr:`9`: adds ``OnnxruntimeEvaluator``
78
* :pr:`8`: adds ``ExtendedReferenceEvaluator``
89
* :pr:`7`: improves function ``investigate_onnxruntime_issue``
910

Lines changed: 257 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,257 @@
1+
import sys
2+
import unittest
3+
import warnings
4+
from typing import Any
5+
import numpy
6+
import onnx.backend.base
7+
import onnx.backend.test
8+
import onnx.shape_inference
9+
import onnx.version_converter
10+
from onnx import ModelProto
11+
from onnx.backend.base import Device, DeviceType
12+
from onnx.defs import onnx_opset_version
13+
from onnx_diagnostic.reference import OnnxruntimeEvaluator
14+
15+
16+
class OnnxruntimeEvaluatorBackendRep(onnx.backend.base.BackendRep):
17+
def __init__(self, session):
18+
self._session = session
19+
20+
def run(self, inputs, **kwargs):
21+
if isinstance(inputs, numpy.ndarray):
22+
inputs = [inputs]
23+
if isinstance(inputs, list):
24+
if len(inputs) == len(self._session.input_names):
25+
feeds = dict(zip(self._session.input_names, inputs))
26+
else:
27+
feeds = {}
28+
pos_inputs = 0
29+
for inp, tshape in zip(self._session.input_names, self._session.input_types):
30+
shape = tuple(d.dim_value for d in tshape.tensor_type.shape.dim)
31+
if shape == inputs[pos_inputs].shape:
32+
feeds[inp] = inputs[pos_inputs]
33+
pos_inputs += 1
34+
if pos_inputs >= len(inputs):
35+
break
36+
elif isinstance(inputs, dict):
37+
feeds = inputs
38+
else:
39+
raise TypeError(f"Unexpected input type {type(inputs)!r}.")
40+
outs = self._session.run(None, feeds)
41+
return outs
42+
43+
44+
class OnnxruntimeEvaluatorBackend(onnx.backend.base.Backend):
45+
@classmethod
46+
def is_opset_supported(cls, model): # pylint: disable=unused-argument
47+
return True, ""
48+
49+
@classmethod
50+
def supports_device(cls, device: str) -> bool:
51+
d = Device(device)
52+
return d.type == DeviceType.CPU # type: ignore[no-any-return]
53+
54+
@classmethod
55+
def create_inference_session(cls, model):
56+
return OnnxruntimeEvaluator(model)
57+
58+
@classmethod
59+
def prepare(
60+
cls, model: Any, device: str = "CPU", **kwargs: Any
61+
) -> OnnxruntimeEvaluatorBackendRep:
62+
if isinstance(model, OnnxruntimeEvaluator):
63+
return OnnxruntimeEvaluatorBackendRep(model)
64+
if isinstance(model, (str, bytes, ModelProto)):
65+
inf = cls.create_inference_session(model)
66+
return cls.prepare(inf, device, **kwargs)
67+
raise TypeError(f"Unexpected type {type(model)} for model.")
68+
69+
@classmethod
70+
def run_model(cls, model, inputs, device=None, **kwargs):
71+
rep = cls.prepare(model, device, **kwargs)
72+
with warnings.catch_warnings():
73+
warnings.simplefilter("ignore")
74+
return rep.run(inputs, **kwargs)
75+
76+
@classmethod
77+
def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
78+
raise NotImplementedError("Unable to run the model node by node.")
79+
80+
81+
dft_atol = 1e-3 if sys.platform != "linux" else 1e-5
82+
backend_test = onnx.backend.test.BackendTest(
83+
OnnxruntimeEvaluatorBackend,
84+
__name__,
85+
test_kwargs={
86+
"test_dft": {"atol": dft_atol},
87+
"test_dft_axis": {"atol": dft_atol},
88+
"test_dft_axis_opset19": {"atol": dft_atol},
89+
"test_dft_inverse": {"atol": dft_atol},
90+
"test_dft_inverse_opset19": {"atol": dft_atol},
91+
"test_dft_opset19": {"atol": dft_atol},
92+
},
93+
)
94+
95+
96+
# The following tests are too slow with the reference implementation (Conv).
97+
backend_test.exclude(
98+
"(test_bvlc_alexnet"
99+
"|test_densenet121"
100+
"|test_inception_v1"
101+
"|test_inception_v2"
102+
"|test_resnet50"
103+
"|test_shufflenet"
104+
"|test_squeezenet"
105+
"|test_vgg19"
106+
"|test_zfnet512)"
107+
)
108+
109+
# The following tests cannot pass because they consists in generating random number.
110+
backend_test.exclude("(test_bernoulli)")
111+
112+
# The following tests are not supported.
113+
backend_test.exclude(
114+
"(test_gradient"
115+
"|test_if_opt"
116+
"|test_loop16_seq_none"
117+
"|test_range_float_type_positive_delta_expanded"
118+
"|test_range_int32_type_negative_delta_expanded"
119+
"|test_scan_sum)"
120+
)
121+
122+
# The following tests fail due to discrepancies (small but still higher than 1e-7).
123+
backend_test.exclude("test_adam_multiple") # 1e-2
124+
125+
126+
if onnx_opset_version() < 19:
127+
backend_test.exclude(
128+
"(test_argm[ai][nx]_default_axis_example"
129+
"|test_argm[ai][nx]_default_axis_random"
130+
"|test_argm[ai][nx]_keepdims_example"
131+
"|test_argm[ai][nx]_keepdims_random"
132+
"|test_argm[ai][nx]_negative_axis_keepdims_example"
133+
"|test_argm[ai][nx]_negative_axis_keepdims_random"
134+
"|test_argm[ai][nx]_no_keepdims_example"
135+
"|test_argm[ai][nx]_no_keepdims_random"
136+
"|test_col2im_pads"
137+
"|test_gru_batchwise"
138+
"|test_gru_defaults"
139+
"|test_gru_seq_length"
140+
"|test_gru_with_initial_bias"
141+
"|test_layer_normalization_2d_axis1_expanded"
142+
"|test_layer_normalization_2d_axis_negative_1_expanded"
143+
"|test_layer_normalization_3d_axis1_epsilon_expanded"
144+
"|test_layer_normalization_3d_axis2_epsilon_expanded"
145+
"|test_layer_normalization_3d_axis_negative_1_epsilon_expanded"
146+
"|test_layer_normalization_3d_axis_negative_2_epsilon_expanded"
147+
"|test_layer_normalization_4d_axis1_expanded"
148+
"|test_layer_normalization_4d_axis2_expanded"
149+
"|test_layer_normalization_4d_axis3_expanded"
150+
"|test_layer_normalization_4d_axis_negative_1_expanded"
151+
"|test_layer_normalization_4d_axis_negative_2_expanded"
152+
"|test_layer_normalization_4d_axis_negative_3_expanded"
153+
"|test_layer_normalization_default_axis_expanded"
154+
"|test_logsoftmax_large_number_expanded"
155+
"|test_lstm_batchwise"
156+
"|test_lstm_defaults"
157+
"|test_lstm_with_initial_bias"
158+
"|test_lstm_with_peepholes"
159+
"|test_mvn"
160+
"|test_mvn_expanded"
161+
"|test_softmax_large_number_expanded"
162+
"|test_operator_reduced_mean"
163+
"|test_operator_reduced_mean_keepdim)"
164+
)
165+
166+
if onnx_opset_version() < 21:
167+
backend_test.exclude(
168+
"(test_averagepool_2d_dilations"
169+
"|test_if*"
170+
"|test_loop*"
171+
"|test_scan*"
172+
"|test_sequence_map*"
173+
")"
174+
)
175+
176+
# The following tests are using types not supported by NumPy.
177+
# They could be if method to_array is extended to support custom
178+
# types the same as the reference implementation does
179+
# (see onnx.reference.op_run.to_array_extended).
180+
backend_test.exclude(
181+
"(test_cast_FLOAT_to_BFLOAT16"
182+
"|test_cast_BFLOAT16_to_FLOAT"
183+
"|test_cast_BFLOAT16_to_FLOAT"
184+
"|test_castlike_BFLOAT16_to_FLOAT"
185+
"|test_castlike_FLOAT_to_BFLOAT16"
186+
"|test_castlike_FLOAT_to_BFLOAT16_expanded"
187+
"|test_cast_no_saturate_"
188+
"|_to_FLOAT8"
189+
"|_FLOAT8"
190+
"|test_quantizelinear_e4m3fn"
191+
"|test_quantizelinear_e5m2"
192+
")"
193+
)
194+
195+
# Disable test about float 8
196+
backend_test.exclude(
197+
"(test_castlike_BFLOAT16*"
198+
"|test_cast_BFLOAT16*"
199+
"|test_cast_no_saturate*"
200+
"|test_cast_FLOAT_to_FLOAT8*"
201+
"|test_cast_FLOAT16_to_FLOAT8*"
202+
"|test_cast_FLOAT8_to_*"
203+
"|test_castlike_BFLOAT16*"
204+
"|test_castlike_no_saturate*"
205+
"|test_castlike_FLOAT_to_FLOAT8*"
206+
"|test_castlike_FLOAT16_to_FLOAT8*"
207+
"|test_castlike_FLOAT8_to_*"
208+
"|test_quantizelinear_e*)"
209+
)
210+
211+
# Disable test about INT 4
212+
backend_test.exclude(
213+
"(test_cast_FLOAT_to_INT4*"
214+
"|test_cast_FLOAT16_to_INT4*"
215+
"|test_cast_INT4_to_*"
216+
"|test_castlike_INT4_to_*"
217+
"|test_cast_FLOAT_to_UINT4*"
218+
"|test_cast_FLOAT16_to_UINT4*"
219+
"|test_cast_UINT4_to_*"
220+
"|test_castlike_UINT4_to_*)"
221+
)
222+
223+
backend_test.exclude("(test_regex_full_match*)")
224+
225+
backend_test.exclude("(test_scatter_with_axis*|test_scatter_without_axis*)")
226+
227+
if onnx_opset_version() < 21:
228+
# The following tests fail due to a bug in the backend test comparison.
229+
backend_test.exclude(
230+
"(test_cast_FLOAT_to_STRING|test_castlike_FLOAT_to_STRING|test_strnorm)"
231+
)
232+
233+
# The following tests fail due to a shape mismatch.
234+
backend_test.exclude(
235+
"(test_center_crop_pad_crop_axes_hwc_expanded|test_lppool_2d_dilations)"
236+
)
237+
238+
# The following tests fail due to a type mismatch.
239+
backend_test.exclude("(test_eyelike_without_dtype)")
240+
241+
242+
# import all test cases at global scope to make them visible to python.unittest
243+
globals().update(backend_test.test_cases)
244+
245+
if __name__ == "__main__":
246+
res = unittest.main(verbosity=2, exit=False)
247+
tests_run = res.result.testsRun
248+
errors = len(res.result.errors)
249+
skipped = len(res.result.skipped)
250+
unexpected_successes = len(res.result.unexpectedSuccesses)
251+
expected_failures = len(res.result.expectedFailures)
252+
print("---------------------------------")
253+
print(
254+
f"tests_run={tests_run} errors={errors} skipped={skipped} "
255+
f"unexpected_successes={unexpected_successes} "
256+
f"expected_failures={expected_failures}"
257+
)

onnx_diagnostic/reference/ort_evaluator.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
from ..helpers import pretty_onnx, dtype_to_tensor_dtype, string_type
1717
from ..ort_session import InferenceSessionForTorch, InferenceSessionForNumpy, _InferenceSession
1818

19+
Proto = Union[FunctionProto, ModelProto, GraphProto, NodeProto]
20+
1921

2022
class OnnxruntimeEvaluator:
2123
"""
@@ -41,7 +43,7 @@ class OnnxruntimeEvaluator:
4143

4244
def __init__(
4345
self,
44-
proto: Union[str, ModelProto, FunctionProto, GraphProto, NodeProto],
46+
proto: Union[str, Proto, "OnnxruntimeEvaluator"],
4547
session_options: Optional[onnxruntime.SessionOptions] = None,
4648
providers: Optional[Union[str, List[str]]] = None,
4749
nvtx: bool = False,
@@ -53,7 +55,7 @@ def __init__(
5355
disable_aot_function_inlining: Optional[bool] = None,
5456
use_training_api: Optional[bool] = None,
5557
verbose: int = 0,
56-
local_functions: Optional[Dict[Tuple[str, str], FunctionProto]] = None,
58+
local_functions: Optional[Dict[Tuple[str, str], Proto]] = None,
5759
ir_version: int = 10,
5860
opsets: Optional[Union[int, Dict[str, int]]] = None,
5961
):
@@ -222,7 +224,7 @@ def run(
222224

223225
def _make_model_proto(
224226
self, nodes: Sequence[NodeProto], vinputs: ValueInfoProto, voutputs: ValueInfoProto
225-
):
227+
) -> ModelProto:
226228
onx = oh.make_model(
227229
oh.make_graph(nodes, "-", vinputs, voutputs),
228230
ir_version=getattr(self.proto, "ir_version", self.ir_version),
@@ -279,7 +281,7 @@ def _get_sess(
279281

280282
def _get_sess_if(
281283
self, node: NodeProto, branch: str, inputs: List[Any], context: Dict[str, Any]
282-
) -> Tuple[ModelProto, onnxruntime.InferenceSession]:
284+
) -> Tuple[ModelProto, "OnnxruntimeEvaluator"]:
283285
unique_names = set()
284286
vinputs = []
285287
for i, it in zip(node.input, inputs):
@@ -314,7 +316,7 @@ def _get_sess_if(
314316

315317
def _get_sess_local(
316318
self, node: NodeProto, inputs: List[Any]
317-
) -> Tuple[ModelProto, onnxruntime.InferenceSession]:
319+
) -> Tuple[ModelProto, "OnnxruntimeEvaluator"]:
318320
onx = self.local_functions[node.domain, node.op_type]
319321
sess = OnnxruntimeEvaluator(
320322
onx,

0 commit comments

Comments
 (0)