Skip to content

Commit 14d9086

Browse files
committed
Add ExtendedReferenceEvaluator
1 parent ee1c214 commit 14d9086

33 files changed

+1706
-0
lines changed
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import unittest
2+
import numpy as np
3+
from onnx import TensorProto
4+
from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info
5+
from experimental_experiment.ext_test_case import ExtTestCase
6+
from experimental_experiment.reference import (
7+
to_array_extended,
8+
from_array_extended,
9+
ExtendedReferenceEvaluator,
10+
)
11+
12+
13+
class TestArrayTensor(ExtTestCase):
14+
def test_from_array(self):
15+
for dt in (np.float32, np.float16, np.uint16, np.uint8):
16+
with self.subTest(dtype=dt):
17+
a = np.array([0, 1, 2], dtype=dt)
18+
t = from_array_extended(a, "a")
19+
b = to_array_extended(t)
20+
self.assertEqualArray(a, b)
21+
t2 = from_array_extended(b, "a")
22+
self.assertEqual(t.SerializeToString(), t2.SerializeToString())
23+
24+
def test_from_array_f8(self):
25+
def make_model_f8(fr, to):
26+
model = make_model(
27+
make_graph(
28+
[make_node("Cast", ["X"], ["Y"], to=to)],
29+
"cast",
30+
[make_tensor_value_info("X", fr, None)],
31+
[make_tensor_value_info("Y", to, None)],
32+
)
33+
)
34+
return model
35+
36+
for dt in (np.float32, np.float16, np.uint16, np.uint8):
37+
with self.subTest(dtype=dt):
38+
a = np.array([0, 1, 2], dtype=dt)
39+
b = from_array_extended(a, "a")
40+
for to in [
41+
TensorProto.FLOAT8E4M3FN,
42+
TensorProto.FLOAT8E4M3FNUZ,
43+
TensorProto.FLOAT8E5M2,
44+
TensorProto.FLOAT8E5M2FNUZ,
45+
TensorProto.BFLOAT16,
46+
]:
47+
with self.subTest(fr=b.data_type, to=to):
48+
model = make_model_f8(b.data_type, to)
49+
ref = ExtendedReferenceEvaluator(model)
50+
got = ref.run(None, {"X": a})[0]
51+
back = from_array_extended(got, "a")
52+
self.assertEqual(to, back.data_type)
53+
54+
55+
if __name__ == "__main__":
56+
unittest.main(verbosity=2)
Lines changed: 257 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,257 @@
1+
import sys
2+
import unittest
3+
import warnings
4+
from typing import Any
5+
import numpy
6+
import onnx.backend.base
7+
import onnx.backend.test
8+
import onnx.shape_inference
9+
import onnx.version_converter
10+
from onnx import ModelProto
11+
from onnx.backend.base import Device, DeviceType
12+
from onnx.defs import onnx_opset_version
13+
from experimental_experiment.reference import ExtendedReferenceEvaluator
14+
15+
16+
class ExtendedReferenceEvaluatorBackendRep(onnx.backend.base.BackendRep):
17+
def __init__(self, session):
18+
self._session = session
19+
20+
def run(self, inputs, **kwargs):
21+
if isinstance(inputs, numpy.ndarray):
22+
inputs = [inputs]
23+
if isinstance(inputs, list):
24+
if len(inputs) == len(self._session.input_names):
25+
feeds = dict(zip(self._session.input_names, inputs))
26+
else:
27+
feeds = {}
28+
pos_inputs = 0
29+
for inp, tshape in zip(self._session.input_names, self._session.input_types):
30+
shape = tuple(d.dim_value for d in tshape.tensor_type.shape.dim)
31+
if shape == inputs[pos_inputs].shape:
32+
feeds[inp] = inputs[pos_inputs]
33+
pos_inputs += 1
34+
if pos_inputs >= len(inputs):
35+
break
36+
elif isinstance(inputs, dict):
37+
feeds = inputs
38+
else:
39+
raise TypeError(f"Unexpected input type {type(inputs)!r}.")
40+
outs = self._session.run(None, feeds)
41+
return outs
42+
43+
44+
class ExtendedReferenceEvaluatorBackend(onnx.backend.base.Backend):
45+
@classmethod
46+
def is_opset_supported(cls, model): # pylint: disable=unused-argument
47+
return True, ""
48+
49+
@classmethod
50+
def supports_device(cls, device: str) -> bool:
51+
d = Device(device)
52+
return d.type == DeviceType.CPU # type: ignore[no-any-return]
53+
54+
@classmethod
55+
def create_inference_session(cls, model):
56+
return ExtendedReferenceEvaluator(model)
57+
58+
@classmethod
59+
def prepare(
60+
cls, model: Any, device: str = "CPU", **kwargs: Any
61+
) -> ExtendedReferenceEvaluatorBackendRep:
62+
if isinstance(model, ExtendedReferenceEvaluator):
63+
return ExtendedReferenceEvaluatorBackendRep(model)
64+
if isinstance(model, (str, bytes, ModelProto)):
65+
inf = cls.create_inference_session(model)
66+
return cls.prepare(inf, device, **kwargs)
67+
raise TypeError(f"Unexpected type {type(model)} for model.")
68+
69+
@classmethod
70+
def run_model(cls, model, inputs, device=None, **kwargs):
71+
rep = cls.prepare(model, device, **kwargs)
72+
with warnings.catch_warnings():
73+
warnings.simplefilter("ignore")
74+
return rep.run(inputs, **kwargs)
75+
76+
@classmethod
77+
def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
78+
raise NotImplementedError("Unable to run the model node by node.")
79+
80+
81+
dft_atol = 1e-3 if sys.platform != "linux" else 1e-5
82+
backend_test = onnx.backend.test.BackendTest(
83+
ExtendedReferenceEvaluatorBackend,
84+
__name__,
85+
test_kwargs={
86+
"test_dft": {"atol": dft_atol},
87+
"test_dft_axis": {"atol": dft_atol},
88+
"test_dft_axis_opset19": {"atol": dft_atol},
89+
"test_dft_inverse": {"atol": dft_atol},
90+
"test_dft_inverse_opset19": {"atol": dft_atol},
91+
"test_dft_opset19": {"atol": dft_atol},
92+
},
93+
)
94+
95+
96+
# The following tests are too slow with the reference implementation (Conv).
97+
backend_test.exclude(
98+
"(test_bvlc_alexnet"
99+
"|test_densenet121"
100+
"|test_inception_v1"
101+
"|test_inception_v2"
102+
"|test_resnet50"
103+
"|test_shufflenet"
104+
"|test_squeezenet"
105+
"|test_vgg19"
106+
"|test_zfnet512)"
107+
)
108+
109+
# The following tests cannot pass because they consists in generating random number.
110+
backend_test.exclude("(test_bernoulli)")
111+
112+
# The following tests are not supported.
113+
backend_test.exclude(
114+
"(test_gradient"
115+
"|test_if_opt"
116+
"|test_loop16_seq_none"
117+
"|test_range_float_type_positive_delta_expanded"
118+
"|test_range_int32_type_negative_delta_expanded"
119+
"|test_scan_sum)"
120+
)
121+
122+
# The following tests fail due to discrepancies (small but still higher than 1e-7).
123+
backend_test.exclude("test_adam_multiple") # 1e-2
124+
125+
126+
if onnx_opset_version() < 19:
127+
backend_test.exclude(
128+
"(test_argm[ai][nx]_default_axis_example"
129+
"|test_argm[ai][nx]_default_axis_random"
130+
"|test_argm[ai][nx]_keepdims_example"
131+
"|test_argm[ai][nx]_keepdims_random"
132+
"|test_argm[ai][nx]_negative_axis_keepdims_example"
133+
"|test_argm[ai][nx]_negative_axis_keepdims_random"
134+
"|test_argm[ai][nx]_no_keepdims_example"
135+
"|test_argm[ai][nx]_no_keepdims_random"
136+
"|test_col2im_pads"
137+
"|test_gru_batchwise"
138+
"|test_gru_defaults"
139+
"|test_gru_seq_length"
140+
"|test_gru_with_initial_bias"
141+
"|test_layer_normalization_2d_axis1_expanded"
142+
"|test_layer_normalization_2d_axis_negative_1_expanded"
143+
"|test_layer_normalization_3d_axis1_epsilon_expanded"
144+
"|test_layer_normalization_3d_axis2_epsilon_expanded"
145+
"|test_layer_normalization_3d_axis_negative_1_epsilon_expanded"
146+
"|test_layer_normalization_3d_axis_negative_2_epsilon_expanded"
147+
"|test_layer_normalization_4d_axis1_expanded"
148+
"|test_layer_normalization_4d_axis2_expanded"
149+
"|test_layer_normalization_4d_axis3_expanded"
150+
"|test_layer_normalization_4d_axis_negative_1_expanded"
151+
"|test_layer_normalization_4d_axis_negative_2_expanded"
152+
"|test_layer_normalization_4d_axis_negative_3_expanded"
153+
"|test_layer_normalization_default_axis_expanded"
154+
"|test_logsoftmax_large_number_expanded"
155+
"|test_lstm_batchwise"
156+
"|test_lstm_defaults"
157+
"|test_lstm_with_initial_bias"
158+
"|test_lstm_with_peepholes"
159+
"|test_mvn"
160+
"|test_mvn_expanded"
161+
"|test_softmax_large_number_expanded"
162+
"|test_operator_reduced_mean"
163+
"|test_operator_reduced_mean_keepdim)"
164+
)
165+
166+
if onnx_opset_version() < 21:
167+
backend_test.exclude(
168+
"(test_averagepool_2d_dilations"
169+
"|test_if*"
170+
"|test_loop*"
171+
"|test_scan*"
172+
"|test_sequence_map*"
173+
")"
174+
)
175+
176+
# The following tests are using types not supported by NumPy.
177+
# They could be if method to_array is extended to support custom
178+
# types the same as the reference implementation does
179+
# (see onnx.reference.op_run.to_array_extended).
180+
backend_test.exclude(
181+
"(test_cast_FLOAT_to_BFLOAT16"
182+
"|test_cast_BFLOAT16_to_FLOAT"
183+
"|test_cast_BFLOAT16_to_FLOAT"
184+
"|test_castlike_BFLOAT16_to_FLOAT"
185+
"|test_castlike_FLOAT_to_BFLOAT16"
186+
"|test_castlike_FLOAT_to_BFLOAT16_expanded"
187+
"|test_cast_no_saturate_"
188+
"|_to_FLOAT8"
189+
"|_FLOAT8"
190+
"|test_quantizelinear_e4m3fn"
191+
"|test_quantizelinear_e5m2"
192+
")"
193+
)
194+
195+
# Disable test about float 8
196+
backend_test.exclude(
197+
"(test_castlike_BFLOAT16*"
198+
"|test_cast_BFLOAT16*"
199+
"|test_cast_no_saturate*"
200+
"|test_cast_FLOAT_to_FLOAT8*"
201+
"|test_cast_FLOAT16_to_FLOAT8*"
202+
"|test_cast_FLOAT8_to_*"
203+
"|test_castlike_BFLOAT16*"
204+
"|test_castlike_no_saturate*"
205+
"|test_castlike_FLOAT_to_FLOAT8*"
206+
"|test_castlike_FLOAT16_to_FLOAT8*"
207+
"|test_castlike_FLOAT8_to_*"
208+
"|test_quantizelinear_e*)"
209+
)
210+
211+
# Disable test about INT 4
212+
backend_test.exclude(
213+
"(test_cast_FLOAT_to_INT4*"
214+
"|test_cast_FLOAT16_to_INT4*"
215+
"|test_cast_INT4_to_*"
216+
"|test_castlike_INT4_to_*"
217+
"|test_cast_FLOAT_to_UINT4*"
218+
"|test_cast_FLOAT16_to_UINT4*"
219+
"|test_cast_UINT4_to_*"
220+
"|test_castlike_UINT4_to_*)"
221+
)
222+
223+
backend_test.exclude("(test_regex_full_match*)")
224+
225+
backend_test.exclude("(test_scatter_with_axis*|test_scatter_without_axis*)")
226+
227+
if onnx_opset_version() < 21:
228+
# The following tests fail due to a bug in the backend test comparison.
229+
backend_test.exclude(
230+
"(test_cast_FLOAT_to_STRING|test_castlike_FLOAT_to_STRING|test_strnorm)"
231+
)
232+
233+
# The following tests fail due to a shape mismatch.
234+
backend_test.exclude(
235+
"(test_center_crop_pad_crop_axes_hwc_expanded|test_lppool_2d_dilations)"
236+
)
237+
238+
# The following tests fail due to a type mismatch.
239+
backend_test.exclude("(test_eyelike_without_dtype)")
240+
241+
242+
# import all test cases at global scope to make them visible to python.unittest
243+
globals().update(backend_test.test_cases)
244+
245+
if __name__ == "__main__":
246+
res = unittest.main(verbosity=2, exit=False)
247+
tests_run = res.result.testsRun
248+
errors = len(res.result.errors)
249+
skipped = len(res.result.skipped)
250+
unexpected_successes = len(res.result.unexpectedSuccesses)
251+
expected_failures = len(res.result.expectedFailures)
252+
print("---------------------------------")
253+
print(
254+
f"tests_run={tests_run} errors={errors} skipped={skipped} "
255+
f"unexpected_successes={unexpected_successes} "
256+
f"expected_failures={expected_failures}"
257+
)

0 commit comments

Comments
 (0)