Skip to content

Commit c2e416a

Browse files
committed
fix a couple of unittests
1 parent 6599db9 commit c2e416a

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

_unittests/ut_xrun_doc/test_check_ort_float16.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
ExtTestCase,
1111
ignore_warnings,
1212
requires_cuda,
13+
requires_onnxruntime,
1314
)
1415

1516

@@ -130,6 +131,7 @@ def common_scatter(self, opset, providers, dtype, reduction, expected_names):
130131

131132
@requires_cuda()
132133
@ignore_warnings(DeprecationWarning)
134+
@requires_onnxruntime("1.23")
133135
def test_scatterels_cuda(self):
134136
default_value = [
135137
"Cast",
@@ -143,6 +145,10 @@ def test_scatterels_cuda(self):
143145
(np.float16, "none"): default_value,
144146
(np.float32, "add"): default_value,
145147
(np.float16, "add"): default_value,
148+
(np.float32, "min"): default_value,
149+
(np.float16, "min"): default_value,
150+
(np.float32, "max"): default_value,
151+
(np.float16, "max"): default_value,
146152
}
147153
for opset, dtype, reduction in itertools.product(
148154
[16, 18], [np.float32, np.float16], ["none", "add", "min", "max"]
@@ -185,14 +191,14 @@ def test_scatternd_cuda(self):
185191
)
186192

187193
@ignore_warnings(DeprecationWarning)
194+
@requires_onnxruntime("1.23")
188195
def test_scatterels_cpu(self):
189196
default_value = [
190197
"Cast",
191198
"ScatterElements",
192199
"Sub",
193200
]
194201
default_value_16 = [
195-
"Cast",
196202
"Cast",
197203
"ScatterElements",
198204
"Cast",
@@ -218,14 +224,14 @@ def test_scatterels_cpu(self):
218224
)
219225

220226
@ignore_warnings(DeprecationWarning)
227+
@requires_onnxruntime("1.23")
221228
def test_scatternd_cpu(self):
222229
default_value = [
223230
"Cast",
224231
"ScatterND",
225232
"Sub",
226233
]
227234
default_value_16 = [
228-
"Cast",
229235
"Cast",
230236
"ScatterND",
231237
"Cast",

0 commit comments

Comments
 (0)