Skip to content

Commit cbc7b0f

Browse files
committed
fix
1 parent 4a481cf commit cbc7b0f

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

_unittests/ut_reference/test_backend_onnxruntime_evaluator.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import unittest
22
import warnings
33
from typing import Any
4+
import packaging.version as pv
45
import numpy
56
import onnx.backend.base
67
import onnx.backend.test
@@ -9,6 +10,7 @@
910
from onnx import ModelProto
1011
from onnx.backend.base import Device, DeviceType
1112
from onnx.defs import onnx_opset_version
13+
import onnxruntime
1214
from onnx_diagnostic.reference import OnnxruntimeEvaluator
1315

1416
ORT_OPSET = max(21, onnx_opset_version() - 2)
@@ -95,10 +97,12 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
9597
dft_atol = 1e-3
9698
stft_atol = 1e-4
9799
ql_atol = 1e-5
100+
fp16_atol = 1e-3
98101
backend_test = onnx.backend.test.BackendTest(
99102
OnnxruntimeEvaluatorBackend,
100103
__name__,
101104
test_kwargs={
105+
"test_attention_4d_fp16": {"atol": fp16_atol},
102106
"test_dft": {"atol": dft_atol, "rtol": numpy.inf},
103107
"test_dft_axis": {"atol": dft_atol, "rtol": numpy.inf},
104108
"test_dft_axis_opset19": {"atol": dft_atol, "rtol": numpy.inf},
@@ -287,6 +291,9 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
287291
)
288292
backend_test.exclude(f"({exc})")
289293

294+
if pv.Version(onnxruntime.__version__) <= pv.Version("1.24"):
295+
backend_test.exclude("(test_attention_4d_with|test_attention_4d_gqa)")
296+
290297
# import all test cases at global scope to make them visible to python.unittest
291298
globals().update(backend_test.test_cases)
292299

0 commit comments

Comments
 (0)