|
12 | 12 | from onnx.defs import onnx_opset_version |
13 | 13 | from onnx_diagnostic.reference import OnnxruntimeEvaluator |
14 | 14 |
|
| 15 | +ORT_OPSET = max(21, onnx_opset_version() - 2) |
| 16 | + |
15 | 17 |
|
16 | 18 | class OnnxruntimeEvaluatorBackendRep(onnx.backend.base.BackendRep): |
17 | 19 | def __init__(self, session): |
@@ -43,13 +45,13 @@ def run(self, inputs, **kwargs): |
43 | 45 |
|
44 | 46 | class OnnxruntimeEvaluatorBackend(onnx.backend.base.Backend): |
45 | 47 | @classmethod |
46 | | - def is_opset_supported(cls, model): # pylint: disable=unused-argument |
47 | | - return True, "" |
| 48 | + def is_compatible(cls, model) -> bool: |
| 49 | + return all(not (d.domain == "" and d.version > ORT_OPSET) for d in model.opset_import) |
48 | 50 |
|
49 | 51 | @classmethod |
50 | 52 | def supports_device(cls, device: str) -> bool: |
51 | 53 | d = Device(device) |
52 | | - return d.type == DeviceType.CPU # type: ignore[no-any-return] |
| 54 | + return d.type == DeviceType.CPU |
53 | 55 |
|
54 | 56 | @classmethod |
55 | 57 | def create_inference_session(cls, model): |
@@ -119,79 +121,19 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs): |
119 | 121 | "|test_scan_sum)" |
120 | 122 | ) |
121 | 123 |
|
122 | | -# The following tests fail due to discrepancies (small but still higher than 1e-7). |
123 | | -backend_test.exclude("test_adam_multiple") # 1e-2 |
124 | | - |
125 | | - |
126 | | -if onnx_opset_version() < 19: |
127 | | - backend_test.exclude( |
128 | | - "(test_argm[ai][nx]_default_axis_example" |
129 | | - "|test_argm[ai][nx]_default_axis_random" |
130 | | - "|test_argm[ai][nx]_keepdims_example" |
131 | | - "|test_argm[ai][nx]_keepdims_random" |
132 | | - "|test_argm[ai][nx]_negative_axis_keepdims_example" |
133 | | - "|test_argm[ai][nx]_negative_axis_keepdims_random" |
134 | | - "|test_argm[ai][nx]_no_keepdims_example" |
135 | | - "|test_argm[ai][nx]_no_keepdims_random" |
136 | | - "|test_col2im_pads" |
137 | | - "|test_gru_batchwise" |
138 | | - "|test_gru_defaults" |
139 | | - "|test_gru_seq_length" |
140 | | - "|test_gru_with_initial_bias" |
141 | | - "|test_layer_normalization_2d_axis1_expanded" |
142 | | - "|test_layer_normalization_2d_axis_negative_1_expanded" |
143 | | - "|test_layer_normalization_3d_axis1_epsilon_expanded" |
144 | | - "|test_layer_normalization_3d_axis2_epsilon_expanded" |
145 | | - "|test_layer_normalization_3d_axis_negative_1_epsilon_expanded" |
146 | | - "|test_layer_normalization_3d_axis_negative_2_epsilon_expanded" |
147 | | - "|test_layer_normalization_4d_axis1_expanded" |
148 | | - "|test_layer_normalization_4d_axis2_expanded" |
149 | | - "|test_layer_normalization_4d_axis3_expanded" |
150 | | - "|test_layer_normalization_4d_axis_negative_1_expanded" |
151 | | - "|test_layer_normalization_4d_axis_negative_2_expanded" |
152 | | - "|test_layer_normalization_4d_axis_negative_3_expanded" |
153 | | - "|test_layer_normalization_default_axis_expanded" |
154 | | - "|test_logsoftmax_large_number_expanded" |
155 | | - "|test_lstm_batchwise" |
156 | | - "|test_lstm_defaults" |
157 | | - "|test_lstm_with_initial_bias" |
158 | | - "|test_lstm_with_peepholes" |
159 | | - "|test_mvn" |
160 | | - "|test_mvn_expanded" |
161 | | - "|test_softmax_large_number_expanded" |
162 | | - "|test_operator_reduced_mean" |
163 | | - "|test_operator_reduced_mean_keepdim)" |
164 | | - ) |
165 | | - |
166 | 124 | if onnx_opset_version() < 21: |
167 | 125 | backend_test.exclude( |
168 | 126 | "(test_averagepool_2d_dilations" |
169 | 127 | "|test_if*" |
170 | 128 | "|test_loop*" |
171 | 129 | "|test_scan*" |
172 | 130 | "|test_sequence_map*" |
173 | | - ")" |
| 131 | + "|test_cast_FLOAT_to_STRING|" |
| 132 | + "test_castlike_FLOAT_to_STRING|test_strnorm|" |
| 133 | + "test_center_crop_pad_crop_axes_hwc_expanded|" |
| 134 | + "test_lppool_2d_dilations|test_eyelike_without_dtype)" |
174 | 135 | ) |
175 | 136 |
|
176 | | -# The following tests are using types not supported by NumPy. |
177 | | -# They could be if method to_array is extended to support custom |
178 | | -# types the same as the reference implementation does |
179 | | -# (see onnx.reference.op_run.to_array_extended). |
180 | | -backend_test.exclude( |
181 | | - "(test_cast_FLOAT_to_BFLOAT16" |
182 | | - "|test_cast_BFLOAT16_to_FLOAT" |
183 | | - "|test_cast_BFLOAT16_to_FLOAT" |
184 | | - "|test_castlike_BFLOAT16_to_FLOAT" |
185 | | - "|test_castlike_FLOAT_to_BFLOAT16" |
186 | | - "|test_castlike_FLOAT_to_BFLOAT16_expanded" |
187 | | - "|test_cast_no_saturate_" |
188 | | - "|_to_FLOAT8" |
189 | | - "|_FLOAT8" |
190 | | - "|test_quantizelinear_e4m3fn" |
191 | | - "|test_quantizelinear_e5m2" |
192 | | - ")" |
193 | | -) |
194 | | - |
195 | 137 | # Disable test about float 8 |
196 | 138 | backend_test.exclude( |
197 | 139 | "(test_castlike_BFLOAT16*" |
@@ -220,23 +162,18 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs): |
220 | 162 | "|test_castlike_UINT4_to_*)" |
221 | 163 | ) |
222 | 164 |
|
223 | | -backend_test.exclude("(test_regex_full_match*)") |
224 | | - |
225 | | -backend_test.exclude("(test_scatter_with_axis*|test_scatter_without_axis*)") |
226 | | - |
227 | | -if onnx_opset_version() < 21: |
228 | | - # The following tests fail due to a bug in the backend test comparison. |
229 | | - backend_test.exclude( |
230 | | - "(test_cast_FLOAT_to_STRING|test_castlike_FLOAT_to_STRING|test_strnorm)" |
231 | | - ) |
232 | | - |
233 | | - # The following tests fail due to a shape mismatch. |
234 | | - backend_test.exclude( |
235 | | - "(test_center_crop_pad_crop_axes_hwc_expanded|test_lppool_2d_dilations)" |
236 | | - ) |
237 | | - |
238 | | - # The following tests fail due to a type mismatch. |
239 | | - backend_test.exclude("(test_eyelike_without_dtype)") |
| 165 | +backend_test.exclude( |
| 166 | + "(test_regex_full_match*|" |
| 167 | + "test_adagrad*|" |
| 168 | + "test_adam|" |
| 169 | + "test_add_uint8|" |
| 170 | + "test_ai_onnx_ml_label_encoder_string*|" |
| 171 | + "test_ai_onnx_ml_label_encoder_tensor_mapping*|" |
| 172 | + "test_ai_onnx_ml_label_encoder_tensor_value_only_mapping*|" |
| 173 | + "test_bitshift_left_uint16*|" |
| 174 | + "test_scatter_with_axis*|" |
| 175 | + "test_scatter_without_axis*)" |
| 176 | +) |
240 | 177 |
|
241 | 178 |
|
242 | 179 | # import all test cases at global scope to make them visible to python.unittest |
|
0 commit comments