|
1 | 1 | import unittest |
2 | 2 | import warnings |
3 | 3 | from typing import Any |
| 4 | +import packaging.version as pv |
4 | 5 | import numpy |
5 | 6 | import onnx.backend.base |
6 | 7 | import onnx.backend.test |
|
9 | 10 | from onnx import ModelProto |
10 | 11 | from onnx.backend.base import Device, DeviceType |
11 | 12 | from onnx.defs import onnx_opset_version |
| 13 | +import onnxruntime |
12 | 14 | from onnx_diagnostic.reference import OnnxruntimeEvaluator |
13 | 15 |
|
14 | 16 | ORT_OPSET = max(21, onnx_opset_version() - 2) |
@@ -95,10 +97,12 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs): |
95 | 97 | dft_atol = 1e-3 |
96 | 98 | stft_atol = 1e-4 |
97 | 99 | ql_atol = 1e-5 |
| 100 | +fp16_atol = 1e-3 |
98 | 101 | backend_test = onnx.backend.test.BackendTest( |
99 | 102 | OnnxruntimeEvaluatorBackend, |
100 | 103 | __name__, |
101 | 104 | test_kwargs={ |
| 105 | + "test_attention_4d_fp16": {"atol": fp16_atol}, |
102 | 106 | "test_dft": {"atol": dft_atol, "rtol": numpy.inf}, |
103 | 107 | "test_dft_axis": {"atol": dft_atol, "rtol": numpy.inf}, |
104 | 108 | "test_dft_axis_opset19": {"atol": dft_atol, "rtol": numpy.inf}, |
@@ -287,6 +291,9 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs): |
287 | 291 | ) |
288 | 292 | backend_test.exclude(f"({exc})") |
289 | 293 |
|
| 294 | +if pv.Version(onnxruntime.__version__) <= pv.Version("1.24"): |
| 295 | + backend_test.exclude("(test_attention_4d_with|test_attention_4d_gqa)") |
| 296 | + |
290 | 297 | # import all test cases at global scope to make them visible to python.unittest |
291 | 298 | globals().update(backend_test.test_cases) |
292 | 299 |
|
|
0 commit comments