"""
Runs the ONNX backend test suite against onnxruntime.
"""

import unittest
import warnings
from typing import Any

import numpy
import onnx.backend.base
import onnx.backend.test
from onnx import ModelProto
from onnx.backend.base import Device, DeviceType
from onnx.defs import onnx_opset_version

import onnxruntime

# onnxruntime typically trails the latest onnx opset, so cap the opset
# exercised by the test suite.
ORT_OPSET = max(23, onnx_opset_version() - 2)


class OnnxruntimeBackendRep(onnx.backend.base.BackendRep):
    def __init__(self, session):
        self._session = session

    def run(self, inputs, **kwargs):
        input_metas = self._session.get_inputs()
        input_names = [meta.name for meta in input_metas]
        if isinstance(inputs, numpy.ndarray):
            inputs = [inputs]
        if isinstance(inputs, list):
            if len(inputs) == len(input_names):
                feeds = dict(zip(input_names, inputs))
            else:
                # Fewer arrays than graph inputs: assign them, in order, to
                # the graph inputs whose static shapes match.
                feeds = {}
                pos_inputs = 0
                for meta in input_metas:
                    # Symbolic (named) dimensions map to 0, mirroring the
                    # dim_value convention for unknown dimensions.
                    shape = tuple(d if isinstance(d, int) else 0 for d in meta.shape)
                    if shape == inputs[pos_inputs].shape:
                        feeds[meta.name] = inputs[pos_inputs]
                        pos_inputs += 1
                        if pos_inputs >= len(inputs):
                            break
        elif isinstance(inputs, dict):
            feeds = inputs
        else:
            raise TypeError(f"Unexpected input type {type(inputs)!r}.")
        return self._session.run(None, feeds)

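
# --- Illustrative usage sketches (hedged): not part of the test suite. ---
# The two helpers below are hypothetical additions showing how the classes in
# this file are used; they assume a CPU build of onnxruntime is installed.
def _make_example_add_model() -> ModelProto:
    """Build a tiny float Add model for the usage sketches (hypothetical helper)."""
    from onnx import TensorProto, helper

    x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2])
    y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2])
    z = helper.make_tensor_value_info("Z", TensorProto.FLOAT, [2])
    node = helper.make_node("Add", ["X", "Y"], ["Z"])
    graph = helper.make_graph([node], "example_add", [x, y], [z])
    return helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])


def _example_backend_rep_usage():
    """Sketch: run positional list inputs through OnnxruntimeBackendRep."""
    session = onnxruntime.InferenceSession(
        _make_example_add_model().SerializeToString(),
        providers=["CPUExecutionProvider"],
    )
    rep = OnnxruntimeBackendRep(session)
    a = numpy.array([1.0, 2.0], dtype=numpy.float32)
    b = numpy.array([3.0, 4.0], dtype=numpy.float32)
    # A list with one array per graph input is zipped with the input names.
    (out,) = rep.run([a, b])
    return out  # -> array([4., 6.], dtype=float32)
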

class OnnxruntimeBackend(onnx.backend.base.Backend):
    @classmethod
    def is_compatible(cls, model) -> bool:
        return all(not (d.domain == "" and d.version > ORT_OPSET) for d in model.opset_import)

    @classmethod
    def supports_device(cls, device: str) -> bool:
        d = Device(device)
        if d.type == DeviceType.CPU:
            return True
        if d.type == DeviceType.GPU:
            try:
                import torch
            except ImportError:
                return False
            return torch.cuda.is_available()
        return False

    @classmethod
    def create_inference_session(cls, model, device):
        d = Device(device)
        if d.type == DeviceType.GPU:
            providers = ["CUDAExecutionProvider"]
        elif d.type == DeviceType.CPU:
            providers = ["CPUExecutionProvider"]
        else:
            raise ValueError(f"Unrecognized device {device!r} or {d!r}")
        if isinstance(model, ModelProto):
            model = model.SerializeToString()
        # InferenceSession accepts a file path (str) or serialized bytes.
        return onnxruntime.InferenceSession(model, providers=providers)

    @classmethod
    def prepare(cls, model: Any, device: str = "CPU", **kwargs: Any) -> OnnxruntimeBackendRep:
        if isinstance(model, onnxruntime.InferenceSession):
            return OnnxruntimeBackendRep(model)
        if isinstance(model, (str, bytes, ModelProto)):
            inf = cls.create_inference_session(model, device)
            return cls.prepare(inf, device, **kwargs)
        raise TypeError(f"Unexpected type {type(model)} for model.")

    @classmethod
    def run_model(cls, model, inputs, device="CPU", **kwargs):
        rep = cls.prepare(model, device, **kwargs)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return rep.run(inputs, **kwargs)

    @classmethod
    def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
        raise NotImplementedError("Unable to run the model node by node.")

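
# Hedged sketch: the Backend entry points accept a ModelProto, serialized
# bytes, or a file path. `_example_run_model` is illustrative only and reuses
# the hypothetical `_make_example_add_model` helper defined above.
def _example_run_model():
    a = numpy.array([1.0, 2.0], dtype=numpy.float32)
    b = numpy.array([3.0, 4.0], dtype=numpy.float32)
    # run_model is prepare() followed by BackendRep.run() in a single call.
    (out,) = OnnxruntimeBackend.run_model(_make_example_add_model(), [a, b], device="CPU")
    return out  # -> array([4., 6.], dtype=float32)
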

# Looser absolute tolerances for numerically sensitive operators; rtol is
# disabled (numpy.inf) for the Fourier-transform tests.
dft_atol = 1e-3
stft_atol = 1e-4
ql_atol = 1e-5
backend_test = onnx.backend.test.BackendTest(
    OnnxruntimeBackend,
    __name__,
    test_kwargs={
        "test_dft": {"atol": dft_atol, "rtol": numpy.inf},
        "test_dft_axis": {"atol": dft_atol, "rtol": numpy.inf},
        "test_dft_axis_opset19": {"atol": dft_atol, "rtol": numpy.inf},
        "test_dft_inverse": {"atol": dft_atol, "rtol": numpy.inf},
        "test_dft_inverse_opset19": {"atol": dft_atol, "rtol": numpy.inf},
        "test_dft_opset19": {"atol": dft_atol, "rtol": numpy.inf},
        "test_stft": {"atol": stft_atol, "rtol": numpy.inf},
        "test_stft_with_window": {"atol": stft_atol, "rtol": numpy.inf},
        "test_qlinearmatmul_2D_int8_float32": {"atol": ql_atol},
        "test_qlinearmatmul_3D_int8_float32": {"atol": ql_atol},
    },
)

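# Hedged note: BackendTest also exposes include(); to debug a single case
# locally, one could narrow the run before the globals().update(...) export
# below, e.g. backend_test.include("test_dft_cpu"). Like exclude(), it takes
# a regular expression matched against generated test names.
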
# The following model tests (large convolutional networks) are too slow.
backend_test.exclude(
    "(test_bvlc_alexnet"
    "|test_densenet121"
    "|test_inception_v1"
    "|test_inception_v2"
    "|test_resnet50"
    "|test_shufflenet"
    "|test_squeezenet"
    "|test_vgg19"
    "|test_zfnet512)"
)

# The following tests cannot pass because they generate random numbers.
backend_test.exclude("(test_bernoulli|test_PoissonNLLLLoss)")

# The following tests are not supported.
backend_test.exclude("test_gradient")

# Optimizer operators (Adagrad, Adam) and uint8 Add are excluded as well.
backend_test.exclude("(test_adagrad|test_adam|test_add_uint8)")


# Import all test cases at global scope to make them visible to python.unittest.
globals().update(backend_test.test_cases)

if __name__ == "__main__":
    res = unittest.main(verbosity=2, exit=False)
    tests_run = res.result.testsRun
    errors = len(res.result.errors)
    skipped = len(res.result.skipped)
    unexpected_successes = len(res.result.unexpectedSuccesses)
    expected_failures = len(res.result.expectedFailures)
    print("---------------------------------")
    print(
        f"tests_run={tests_run} errors={errors} skipped={skipped} "
        f"unexpected_successes={unexpected_successes} "
        f"expected_failures={expected_failures}"
    )
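
# Hedged usage note: the exported cases can also be driven by pytest, which
# discovers unittest-style tests, e.g.:
#
#   pytest path/to/this_file.py -k dft -v
#
# where the path is whatever this module is saved as.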