Skip to content

Commit 4078f6a

Browse files
committed
backend
1 parent e2239a8 commit 4078f6a

File tree

4 files changed

+34
-89
lines changed

4 files changed

+34
-89
lines changed

.github/workflows/ci.yml

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,16 @@ jobs:
5959
UNITTEST_GOING=1 pytest --durations=10 _unittests --ignore _unittests/ut_reference/test_backend_extended_reference_evaluator.py --ignore _unittests/ut_reference/test_backend_onnxruntime_evaluator.py
6060
export PYTHONPATH=
6161
62-
- name: run backend tests
62+
- name: run backend tests python
6363
run: |
6464
pip install pytest
6565
export PYTHONPATH=.
66-
UNITTEST_GOING=1 pytest --durations=10 _unittests/ut_reference/test_backend_extended_reference_evaluator.py _unittests/ut_reference/test_backend_onnxruntime_evaluator.py
66+
UNITTEST_GOING=1 pytest --durations=10 _unittests/ut_reference/test_backend_extended_reference_evaluator.py
67+
export PYTHONPATH=
68+
69+
- name: run backend tests onnxruntime
70+
run: |
71+
pip install pytest
72+
export PYTHONPATH=.
73+
UNITTEST_GOING=1 pytest --durations=10 _unittests/ut_reference/test_backend_onnxruntime_evaluator.py
6774
export PYTHONPATH=

_doc/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@
177177
"GraphModule": "https://pytorch.org/docs/stable/fx.html#torch.fx.GraphModule",
178178
"HuggingFace": "https://huggingface.co/docs/hub/en/index",
179179
"Linux": "https://www.linux.org/",
180+
"ml_dtypes": "https://github.com/jax-ml/ml_dtypes",
180181
"monai": "https://monai.io/",
181182
"numpy": "https://numpy.org/",
182183
"onnx": "https://onnx.ai/onnx/",

_unittests/ut_reference/test_backend_extended_reference_evaluator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,13 @@ def run(self, inputs, **kwargs):
4343

4444
class ExtendedReferenceEvaluatorBackend(onnx.backend.base.Backend):
4545
@classmethod
46-
def is_opset_supported(cls, model): # pylint: disable=unused-argument
47-
return True, ""
46+
def is_compatible(cls, model) -> bool:
47+
return True
4848

4949
@classmethod
5050
def supports_device(cls, device: str) -> bool:
5151
d = Device(device)
52-
return d.type == DeviceType.CPU # type: ignore[no-any-return]
52+
return d.type == DeviceType.CPU
5353

5454
@classmethod
5555
def create_inference_session(cls, model):

_unittests/ut_reference/test_backend_onnxruntime_evaluator.py

Lines changed: 21 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from onnx.defs import onnx_opset_version
1313
from onnx_diagnostic.reference import OnnxruntimeEvaluator
1414

15+
ORT_OPSET = max(21, onnx_opset_version() - 2)
16+
1517

1618
class OnnxruntimeEvaluatorBackendRep(onnx.backend.base.BackendRep):
1719
def __init__(self, session):
@@ -43,13 +45,13 @@ def run(self, inputs, **kwargs):
4345

4446
class OnnxruntimeEvaluatorBackend(onnx.backend.base.Backend):
4547
@classmethod
46-
def is_opset_supported(cls, model): # pylint: disable=unused-argument
47-
return True, ""
48+
def is_compatible(cls, model) -> bool:
49+
return all(not (d.domain == "" and d.version > ORT_OPSET) for d in model.opset_import)
4850

4951
@classmethod
5052
def supports_device(cls, device: str) -> bool:
5153
d = Device(device)
52-
return d.type == DeviceType.CPU # type: ignore[no-any-return]
54+
return d.type == DeviceType.CPU
5355

5456
@classmethod
5557
def create_inference_session(cls, model):
@@ -119,79 +121,19 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
119121
"|test_scan_sum)"
120122
)
121123

122-
# The following tests fail due to discrepancies (small but still higher than 1e-7).
123-
backend_test.exclude("test_adam_multiple") # 1e-2
124-
125-
126-
if onnx_opset_version() < 19:
127-
backend_test.exclude(
128-
"(test_argm[ai][nx]_default_axis_example"
129-
"|test_argm[ai][nx]_default_axis_random"
130-
"|test_argm[ai][nx]_keepdims_example"
131-
"|test_argm[ai][nx]_keepdims_random"
132-
"|test_argm[ai][nx]_negative_axis_keepdims_example"
133-
"|test_argm[ai][nx]_negative_axis_keepdims_random"
134-
"|test_argm[ai][nx]_no_keepdims_example"
135-
"|test_argm[ai][nx]_no_keepdims_random"
136-
"|test_col2im_pads"
137-
"|test_gru_batchwise"
138-
"|test_gru_defaults"
139-
"|test_gru_seq_length"
140-
"|test_gru_with_initial_bias"
141-
"|test_layer_normalization_2d_axis1_expanded"
142-
"|test_layer_normalization_2d_axis_negative_1_expanded"
143-
"|test_layer_normalization_3d_axis1_epsilon_expanded"
144-
"|test_layer_normalization_3d_axis2_epsilon_expanded"
145-
"|test_layer_normalization_3d_axis_negative_1_epsilon_expanded"
146-
"|test_layer_normalization_3d_axis_negative_2_epsilon_expanded"
147-
"|test_layer_normalization_4d_axis1_expanded"
148-
"|test_layer_normalization_4d_axis2_expanded"
149-
"|test_layer_normalization_4d_axis3_expanded"
150-
"|test_layer_normalization_4d_axis_negative_1_expanded"
151-
"|test_layer_normalization_4d_axis_negative_2_expanded"
152-
"|test_layer_normalization_4d_axis_negative_3_expanded"
153-
"|test_layer_normalization_default_axis_expanded"
154-
"|test_logsoftmax_large_number_expanded"
155-
"|test_lstm_batchwise"
156-
"|test_lstm_defaults"
157-
"|test_lstm_with_initial_bias"
158-
"|test_lstm_with_peepholes"
159-
"|test_mvn"
160-
"|test_mvn_expanded"
161-
"|test_softmax_large_number_expanded"
162-
"|test_operator_reduced_mean"
163-
"|test_operator_reduced_mean_keepdim)"
164-
)
165-
166124
if onnx_opset_version() < 21:
167125
backend_test.exclude(
168126
"(test_averagepool_2d_dilations"
169127
"|test_if*"
170128
"|test_loop*"
171129
"|test_scan*"
172130
"|test_sequence_map*"
173-
")"
131+
"|test_cast_FLOAT_to_STRING|"
132+
"test_castlike_FLOAT_to_STRING|test_strnorm|"
133+
"test_center_crop_pad_crop_axes_hwc_expanded|"
134+
"test_lppool_2d_dilations|test_eyelike_without_dtype)"
174135
)
175136

176-
# The following tests are using types not supported by NumPy.
177-
# They could be if method to_array is extended to support custom
178-
# types the same as the reference implementation does
179-
# (see onnx.reference.op_run.to_array_extended).
180-
backend_test.exclude(
181-
"(test_cast_FLOAT_to_BFLOAT16"
182-
"|test_cast_BFLOAT16_to_FLOAT"
183-
"|test_cast_BFLOAT16_to_FLOAT"
184-
"|test_castlike_BFLOAT16_to_FLOAT"
185-
"|test_castlike_FLOAT_to_BFLOAT16"
186-
"|test_castlike_FLOAT_to_BFLOAT16_expanded"
187-
"|test_cast_no_saturate_"
188-
"|_to_FLOAT8"
189-
"|_FLOAT8"
190-
"|test_quantizelinear_e4m3fn"
191-
"|test_quantizelinear_e5m2"
192-
")"
193-
)
194-
195137
# Disable test about float 8
196138
backend_test.exclude(
197139
"(test_castlike_BFLOAT16*"
@@ -220,23 +162,18 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
220162
"|test_castlike_UINT4_to_*)"
221163
)
222164

223-
backend_test.exclude("(test_regex_full_match*)")
224-
225-
backend_test.exclude("(test_scatter_with_axis*|test_scatter_without_axis*)")
226-
227-
if onnx_opset_version() < 21:
228-
# The following tests fail due to a bug in the backend test comparison.
229-
backend_test.exclude(
230-
"(test_cast_FLOAT_to_STRING|test_castlike_FLOAT_to_STRING|test_strnorm)"
231-
)
232-
233-
# The following tests fail due to a shape mismatch.
234-
backend_test.exclude(
235-
"(test_center_crop_pad_crop_axes_hwc_expanded|test_lppool_2d_dilations)"
236-
)
237-
238-
# The following tests fail due to a type mismatch.
239-
backend_test.exclude("(test_eyelike_without_dtype)")
165+
backend_test.exclude(
166+
"(test_regex_full_match*|"
167+
"test_adagrad*|"
168+
"test_adam|"
169+
"test_add_uint8|"
170+
"test_ai_onnx_ml_label_encoder_string*|"
171+
"test_ai_onnx_ml_label_encoder_tensor_mapping*|"
172+
"test_ai_onnx_ml_label_encoder_tensor_value_only_mapping*|"
173+
"test_bitshift_left_uint16*|"
174+
"test_scatter_with_axis*|"
175+
"test_scatter_without_axis*)"
176+
)
240177

241178

242179
# import all test cases at global scope to make them visible to python.unittest

0 commit comments

Comments
 (0)