Skip to content

Commit f1ab122

Browse files
committed
Add script to run backend test with onnxruntime
1 parent f7dd78e commit f1ab122

File tree

2 files changed

+171
-4
lines changed

2 files changed

+171
-4
lines changed
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
"""
2+
This file runs through the backend test and evaluates onnxruntime.
3+
"""
4+
5+
import unittest
6+
import warnings
7+
from typing import Any
8+
import numpy
9+
import onnx.backend.base
10+
import onnx.backend.test
11+
import onnx.shape_inference
12+
import onnx.version_converter
13+
from onnx import ModelProto
14+
from onnx.backend.base import Device, DeviceType
15+
from onnx.defs import onnx_opset_version
16+
import onnxruntime
17+
18+
ORT_OPSET = max(23, onnx_opset_version() - 2)
19+
20+
21+
class OnnxruntimeBackendRep(onnx.backend.base.BackendRep):
22+
def __init__(self, session):
23+
self._session = session
24+
25+
def run(self, inputs, **kwargs):
26+
if isinstance(inputs, numpy.ndarray):
27+
inputs = [inputs]
28+
if isinstance(inputs, list):
29+
if len(inputs) == len(self._session.input_names):
30+
feeds = dict(zip(self._session.input_names, inputs))
31+
else:
32+
feeds = {}
33+
pos_inputs = 0
34+
for inp, tshape in zip(self._session.input_names, self._session.input_types):
35+
shape = tuple(d.dim_value for d in tshape.tensor_type.shape.dim)
36+
if shape == inputs[pos_inputs].shape:
37+
feeds[inp] = inputs[pos_inputs]
38+
pos_inputs += 1
39+
if pos_inputs >= len(inputs):
40+
break
41+
elif isinstance(inputs, dict):
42+
feeds = inputs
43+
else:
44+
raise TypeError(f"Unexpected input type {type(inputs)!r}.")
45+
outs = self._session.run(None, feeds)
46+
return outs
47+
48+
49+
class OnnxruntimeBackend(onnx.backend.base.Backend):
50+
@classmethod
51+
def is_compatible(cls, model) -> bool:
52+
return all(not (d.domain == "" and d.version > ORT_OPSET) for d in model.opset_import)
53+
54+
@classmethod
55+
def supports_device(cls, device: str) -> bool:
56+
d = Device(device)
57+
if d == DeviceType.CPU:
58+
return True
59+
if d == DeviceType.GPU:
60+
import torch
61+
62+
return torch.cuda.is_available()
63+
return False
64+
65+
@classmethod
66+
def create_inference_session(cls, model, device):
67+
d = Device(device)
68+
if d == DeviceType.GPU:
69+
providers = ["CUDAExecutionProvider"]
70+
elif d == DeviceType.CPU:
71+
providers = ["CPUExecutionProvider"]
72+
else:
73+
raise ValueError(f"Unrecognized device {device!r} or {d!r}")
74+
return onnxruntime.InferenceSession(model.SerializeToString(), providers=providers)
75+
76+
@classmethod
77+
def prepare(cls, model: Any, device: str = "CPU", **kwargs: Any) -> OnnxruntimeBackendRep:
78+
if isinstance(model, onnxruntime.InferenceSession):
79+
return OnnxruntimeBackendRep(model)
80+
if isinstance(model, (str, bytes, ModelProto)):
81+
inf = cls.create_inference_session(model, device)
82+
return cls.prepare(inf, device, **kwargs)
83+
raise TypeError(f"Unexpected type {type(model)} for model.")
84+
85+
@classmethod
86+
def run_model(cls, model, inputs, device=None, **kwargs):
87+
rep = cls.prepare(model, device, **kwargs)
88+
with warnings.catch_warnings():
89+
warnings.simplefilter("ignore")
90+
return rep.run(inputs, **kwargs)
91+
92+
@classmethod
93+
def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
94+
raise NotImplementedError("Unable to run the model node by node.")
95+
96+
97+
dft_atol = 1e-3
98+
stft_atol = 1e-4
99+
ql_atol = 1e-5
100+
backend_test = onnx.backend.test.BackendTest(
101+
OnnxruntimeBackend,
102+
__name__,
103+
test_kwargs={
104+
"test_dft": {"atol": dft_atol, "rtol": numpy.inf},
105+
"test_dft_axis": {"atol": dft_atol, "rtol": numpy.inf},
106+
"test_dft_axis_opset19": {"atol": dft_atol, "rtol": numpy.inf},
107+
"test_dft_inverse": {"atol": dft_atol, "rtol": numpy.inf},
108+
"test_dft_inverse_opset19": {"atol": dft_atol, "rtol": numpy.inf},
109+
"test_dft_opset19": {"atol": dft_atol, "rtol": numpy.inf},
110+
"test_stft": {"atol": stft_atol, "rtol": numpy.inf},
111+
"test_stft_with_window": {"atol": stft_atol, "rtol": numpy.inf},
112+
"test_qlinearmatmul_2D_int8_float32": {"atol": ql_atol},
113+
"test_qlinearmatmul_3D_int8_float32": {"atol": ql_atol},
114+
},
115+
)
116+
117+
# The following tests are too slow with the reference implementation (Conv).
118+
backend_test.exclude(
119+
"(test_bvlc_alexnet"
120+
"|test_densenet121"
121+
"|test_inception_v1"
122+
"|test_inception_v2"
123+
"|test_resnet50"
124+
"|test_shufflenet"
125+
"|test_squeezenet"
126+
"|test_vgg19"
127+
"|test_zfnet512)"
128+
)
129+
130+
# The following tests cannot pass because they consists in generating random number.
131+
backend_test.exclude("(test_bernoulli|test_PoissonNLLLLoss)")
132+
133+
# The following tests are not supported.
134+
backend_test.exclude("test_gradient")
135+
136+
backend_test.exclude("(test_adagrad|test_adam|test_add_uint8)")
137+
138+
139+
# import all test cases at global scope to make them visible to python.unittest
140+
globals().update(backend_test.test_cases)
141+
142+
if __name__ == "__main__":
143+
res = unittest.main(verbosity=2, exit=False)
144+
tests_run = res.result.testsRun
145+
errors = len(res.result.errors)
146+
skipped = len(res.result.skipped)
147+
unexpected_successes = len(res.result.unexpectedSuccesses)
148+
expected_failures = len(res.result.expectedFailures)
149+
print("---------------------------------")
150+
print(
151+
f"tests_run={tests_run} errors={errors} skipped={skipped} "
152+
f"unexpected_successes={unexpected_successes} "
153+
f"expected_failures={expected_failures}"
154+
)

_unittests/ut_reference/test_backend_onnxruntime_evaluator.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,24 @@ def is_compatible(cls, model) -> bool:
5050
@classmethod
5151
def supports_device(cls, device: str) -> bool:
5252
d = Device(device)
53-
return d.type == DeviceType.CPU
53+
if d == DeviceType.CPU:
54+
return True
55+
if d == DeviceType.GPU:
56+
import torch
57+
58+
return torch.cuda.is_available()
59+
return False
5460

5561
@classmethod
56-
def create_inference_session(cls, model):
57-
return OnnxruntimeEvaluator(model)
62+
def create_inference_session(cls, model, device):
63+
d = Device(device)
64+
if d == DeviceType.GPU:
65+
providers = ["CUDAExecutionProvider"]
66+
elif d == DeviceType.CPU:
67+
providers = ["CPUExecutionProvider"]
68+
else:
69+
raise ValueError(f"Unrecognized device {device!r} or {d!r}")
70+
return OnnxruntimeEvaluator(model, providers=providers)
5871

5972
@classmethod
6073
def prepare(
@@ -63,7 +76,7 @@ def prepare(
6376
if isinstance(model, OnnxruntimeEvaluator):
6477
return OnnxruntimeEvaluatorBackendRep(model)
6578
if isinstance(model, (str, bytes, ModelProto)):
66-
inf = cls.create_inference_session(model)
79+
inf = cls.create_inference_session(model, device)
6780
return cls.prepare(inf, device, **kwargs)
6881
raise TypeError(f"Unexpected type {type(model)} for model.")
6982

0 commit comments

Comments
 (0)