Skip to content

Commit 2ae830c

Browse files
authored
Add script to run backend test with onnxruntime (#177)
* Add script to run backend test with onnxruntime * delay import * fix dynamic shapes * fix wrong context * dupates * fix ut * fix issues
1 parent f7dd78e commit 2ae830c

File tree

13 files changed

+420
-32
lines changed

13 files changed

+420
-32
lines changed

CHANGELOGS.rst

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
Change Logs
22
===========
33

4+
0.7.4
5+
+++++
6+
7+
* :pr:`174`: changes for the next version of onnx, fixes all_dynamic_shape_from_inputs
8+
49
0.7.3
510
+++++
611

712
* :pr:`173`: fixes function to_any for BaseModelOutput
813

9-
1014
0.7.2
1115
+++++
1216

_doc/index.rst

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ onnx-diagnostic: investigate onnx models
2121
The main feature is about `patches <https://github.com/sdpython/onnx-diagnostic/tree/main/onnx_diagnostic/torch_export_patches>`_:
2222
it helps exporting **pytorch models into ONNX**, mostly designed for LLMs using dynamic caches.
2323
Sources available at `github/onnx-diagnostic <https://github.com/sdpython/onnx-diagnostic/>`_.
24-
Patches can be enabled as follows:
24+
Patches can be enabled as follows with function
25+
:func:`onnx_diagnostic.torch_export_patches.torch_export_patches`:
2526

2627
.. code-block:: python
2728
@@ -31,7 +32,8 @@ Patches can be enabled as follows:
3132
ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
3233
# ...
3334
34-
Dynamic shapes are difficult to guess for caches, one function
35+
Dynamic shapes are difficult to guess for caches, function
36+
:func:`onnx_diagnostic.export.shape_helper.all_dynamic_shape_from_inputs`
3537
returns a structure defining all dimensions as dynamic.
3638
You then need to remove those which are not dynamic in your model.
3739

@@ -237,7 +239,7 @@ The function replaces dynamic dimensions defined as strings by
237239
Older versions
238240
==============
239241

240-
* `0.7.3 <../v0.7.3/index.html>`_
242+
* `0.7.4 <../v0.7.4/index.html>`_
241243
* `0.6.3 <../v0.6.3/index.html>`_
242244
* `0.5.0 <../v0.5.0/index.html>`_
243245
* `0.4.4 <../v0.4.4/index.html>`_

_doc/recipes/plot_dynamic_shapes_json.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,21 @@ def flatten_unflatten_like_dynamic_shapes(obj):
7575
value = flatten_unflatten_like_dynamic_shapes(value)
7676
subtrees.append(value)
7777
start = end
78-
if spec.type is dict or spec.context:
78+
if spec.type is dict:
79+
# This is a dictionary.
7980
return dict(zip(spec.context, subtrees))
8081
if spec.type is tuple:
8182
return tuple(subtrees)
82-
return subtrees
83+
if spec.type is list:
84+
return list(subtrees)
85+
if spec.context:
86+
# This is a custom class with attributes.
87+
# It is returned as a list.
88+
return list(subtrees)
89+
raise ValueError(
90+
f"Unable to interpret spec type {spec.type} "
91+
f"(type is {type(spec.type)}, context is {spec.context})."
92+
)
8393

8494

8595
def _align(inputs, ds):
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
"""
2+
This file runs through the backend test and evaluates onnxruntime.
3+
"""
4+
5+
import unittest
6+
import warnings
7+
from typing import Any
8+
import numpy
9+
import onnx.backend.base
10+
import onnx.backend.test
11+
import onnx.shape_inference
12+
import onnx.version_converter
13+
from onnx import ModelProto
14+
from onnx.backend.base import Device, DeviceType
15+
from onnx.defs import onnx_opset_version
16+
import onnxruntime
17+
18+
# Highest opset of the main domain ("") accepted by OnnxruntimeBackend.is_compatible:
# never below 23, otherwise two versions behind the value returned by
# onnx_opset_version().
ORT_OPSET = max(23, onnx_opset_version() - 2)
19+
20+
21+
class OnnxruntimeBackendRep(onnx.backend.base.BackendRep):
    """
    Backend representation wrapping an :class:`onnxruntime.InferenceSession`.

    :param session: session used to run the model
    """

    def __init__(self, session):
        self._session = session

    def run(self, inputs, **kwargs):
        """
        Runs the wrapped session.

        :param inputs: a single array, a list of arrays or
            a dictionary ``{input name: array}``
        :param kwargs: ignored
        :return: outputs returned by the session
        :raises TypeError: if *inputs* is not an array, a list or a dictionary
        """
        if isinstance(inputs, numpy.ndarray):
            inputs = [inputs]
        if isinstance(inputs, list):
            # InferenceSession describes its inputs through get_inputs(),
            # it has no attribute input_names or input_types.
            model_inputs = self._session.get_inputs()
            if len(inputs) == len(model_inputs):
                feeds = {i.name: value for i, value in zip(model_inputs, inputs)}
            else:
                # The number of values differs from the number of declared inputs:
                # values are consumed in order and given to the inputs
                # whose declared shape is exactly the shape of the value.
                feeds = {}
                pos_inputs = 0
                for model_input in model_inputs:
                    if pos_inputs >= len(inputs):
                        break
                    if tuple(model_input.shape or ()) == inputs[pos_inputs].shape:
                        feeds[model_input.name] = inputs[pos_inputs]
                        pos_inputs += 1
        elif isinstance(inputs, dict):
            feeds = inputs
        else:
            raise TypeError(f"Unexpected input type {type(inputs)!r}.")
        # None means all outputs are returned.
        outs = self._session.run(None, feeds)
        return outs
47+
48+
49+
class OnnxruntimeBackend(onnx.backend.base.Backend):
    """
    ONNX backend running the models with :epkg:`onnxruntime`.
    """

    @classmethod
    def is_compatible(cls, model) -> bool:
        """
        Tells if the model can be run: no opset of the main domain
        is above ``ORT_OPSET``.
        """
        return all(not (d.domain == "" and d.version > ORT_OPSET) for d in model.opset_import)

    @classmethod
    def supports_device(cls, device: str) -> bool:
        """
        Tells if *device* (``'CPU'``, ``'CUDA'``, ``'CUDA:1'``, ...) is supported.
        CUDA is reported as supported only if torch sees a CUDA device.
        """
        # Device is not comparable to DeviceType, the type is stored in attribute type.
        d = Device(device)
        if d.type == DeviceType.CPU:
            return True
        if d.type == DeviceType.CUDA:
            # delayed import, torch is only needed to check CUDA availability
            import torch

            return torch.cuda.is_available()
        return False

    @classmethod
    def create_inference_session(cls, model, device):
        """
        Creates an :class:`onnxruntime.InferenceSession`.

        :param model: a ModelProto, a filename or a serialized model
        :param device: device name, None means CPU
        :return: the session
        :raises ValueError: if the device is neither CPU nor CUDA
        """
        d = Device(device or "CPU")
        if d.type == DeviceType.CUDA:
            providers = ["CUDAExecutionProvider"]
        elif d.type == DeviceType.CPU:
            providers = ["CPUExecutionProvider"]
        else:
            raise ValueError(f"Unrecognized device {device!r} or {d!r}")
        # A filename (str) or a serialized model (bytes) is given as is to onnxruntime,
        # only a ModelProto needs to be serialized.
        content = model.SerializeToString() if isinstance(model, ModelProto) else model
        return onnxruntime.InferenceSession(content, providers=providers)

    @classmethod
    def prepare(cls, model: Any, device: str = "CPU", **kwargs: Any) -> OnnxruntimeBackendRep:
        """
        Wraps a model or an existing session into a :class:`OnnxruntimeBackendRep`.

        :param model: an InferenceSession, a ModelProto, a filename or a serialized model
        :param device: device name, None is interpreted as CPU
        :param kwargs: forwarded to the recursive call
        :return: the backend representation
        :raises TypeError: if the type of *model* is not supported
        """
        if isinstance(model, onnxruntime.InferenceSession):
            return OnnxruntimeBackendRep(model)
        if isinstance(model, (str, bytes, ModelProto)):
            inf = cls.create_inference_session(model, device or "CPU")
            return cls.prepare(inf, device, **kwargs)
        raise TypeError(f"Unexpected type {type(model)} for model.")

    @classmethod
    def run_model(cls, model, inputs, device=None, **kwargs):
        """
        Prepares the model and runs it on *inputs*, warnings are ignored
        during the execution. ``device=None`` means CPU.
        """
        rep = cls.prepare(model, device or "CPU", **kwargs)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return rep.run(inputs, **kwargs)

    @classmethod
    def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
        """Not supported, always raises :class:`NotImplementedError`."""
        raise NotImplementedError("Unable to run the model node by node.")
95+
96+
97+
# Absolute tolerances used by the tests overriding the default precision.
dft_atol = 1e-3
stft_atol = 1e-4
ql_atol = 1e-5

backend_test = onnx.backend.test.BackendTest(
    OnnxruntimeBackend,
    __name__,
    test_kwargs={
        # DFT tests: only the absolute tolerance matters, rtol is disabled with inf
        **{
            name: {"atol": dft_atol, "rtol": numpy.inf}
            for name in (
                "test_dft",
                "test_dft_axis",
                "test_dft_axis_opset19",
                "test_dft_inverse",
                "test_dft_inverse_opset19",
                "test_dft_opset19",
            )
        },
        # STFT tests
        **{
            name: {"atol": stft_atol, "rtol": numpy.inf}
            for name in ("test_stft", "test_stft_with_window")
        },
        # QLinearMatMul tests
        **{
            name: {"atol": ql_atol}
            for name in (
                "test_qlinearmatmul_2D_int8_float32",
                "test_qlinearmatmul_3D_int8_float32",
            )
        },
    },
)

# The following tests are excluded because they are too slow.
backend_test.exclude(
    "("
    + "|".join(
        [
            "test_bvlc_alexnet",
            "test_densenet121",
            "test_inception_v1",
            "test_inception_v2",
            "test_resnet50",
            "test_shufflenet",
            "test_squeezenet",
            "test_vgg19",
            "test_zfnet512",
        ]
    )
    + ")"
)

# The following tests cannot pass because they generate random numbers.
backend_test.exclude("(test_bernoulli|test_PoissonNLLLLoss)")

# The following tests are not supported.
backend_test.exclude("test_gradient")

backend_test.exclude("(test_adagrad|test_adam|test_add_uint8)")


# All test cases are imported at global scope to make them visible to python.unittest.
globals().update(backend_test.test_cases)
141+
142+
if __name__ == "__main__":
    # exit=False lets the script print a summary once all the tests have run.
    res = unittest.main(verbosity=2, exit=False)
    tests_run = res.result.testsRun
    errors = len(res.result.errors)
    # Failed assertions are stored in failures, not in errors:
    # both are needed to know how many tests did not pass.
    failures = len(res.result.failures)
    skipped = len(res.result.skipped)
    unexpected_successes = len(res.result.unexpectedSuccesses)
    expected_failures = len(res.result.expectedFailures)
    print("---------------------------------")
    print(
        f"tests_run={tests_run} errors={errors} failures={failures} skipped={skipped} "
        f"unexpected_successes={unexpected_successes} "
        f"expected_failures={expected_failures}"
    )

_unittests/ut_export/test_shape_helper.py

Lines changed: 120 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,126 @@
55
all_dynamic_shape_from_inputs,
66
guess_dynamic_shapes_from_inputs,
77
)
8+
from onnx_diagnostic.helpers.cache_helper import (
9+
make_dynamic_cache,
10+
make_sliding_window_cache,
11+
make_encoder_decoder_cache,
12+
make_static_cache,
13+
)
814
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
15+
from onnx_diagnostic.torch_export_patches import torch_export_patches
916

1017

1118
class TestShapeHelper(ExtTestCase):
19+
20+
@requires_transformers("4.52")
@requires_torch("2.7.99")
def test_all_dynamic_shape_from_cache(self):
    """Checks all_dynamic_shape_from_inputs on a cache given alone as input."""
    # a cache built from a single pair of 2x2 tensors
    cache = make_dynamic_cache([(torch.ones((2, 2)), (torch.ones((2, 2)) * 2))])
    ds = all_dynamic_shape_from_inputs(cache)
    # The result is expected to be two lists of one dictionary each,
    # every dimension is named d_<tensor index>_<axis>.
    self.assertEqual([[{0: "d_0_0", 1: "d_0_1"}], [{0: "d_1_0", 1: "d_1_1"}]], ds)
26+
27+
@requires_torch("2.7.99")
def test_all_dynamic_shape_all_transformers_cache(self):
    """
    Checks all_dynamic_shape_from_inputs on four caches built with
    make_dynamic_cache, make_encoder_decoder_cache,
    make_sliding_window_cache and make_static_cache.
    """

    def named_dims(first, count, rank):
        # Expected dynamic shapes for `count` consecutive tensors of rank `rank`,
        # the first tensor being numbered `first`: {axis: "d_<tensor>_<axis>"}.
        return [
            {axis: f"d_{index}_{axis}" for axis in range(rank)}
            for index in range(first, first + count)
        ]

    def random_layers(shape, count=3):
        # `count` pairs of random tensors sharing the same shape
        return [(torch.rand(shape), torch.rand(shape)) for _ in range(count)]

    dynamic_cache = make_dynamic_cache([(torch.ones((2, 2)), (torch.ones((2, 2)) * 2))])
    encoder_decoder_cache = make_encoder_decoder_cache(
        make_dynamic_cache(random_layers((4, 4, 4))),
        make_dynamic_cache(random_layers((5, 5, 5))),
    )
    sliding_window_cache = make_sliding_window_cache(random_layers((4, 5, 6, 7)))
    static_cache = make_static_cache(random_layers((4, 5, 6, 7)), max_cache_len=15)

    # every item is a pair (cache, expected dynamic shapes)
    caches = [
        (dynamic_cache, [named_dims(0, 1, 2), named_dims(1, 1, 2)]),
        (
            encoder_decoder_cache,
            [
                [named_dims(0, 3, 3), named_dims(3, 3, 3)],
                [named_dims(6, 3, 3), named_dims(9, 3, 3)],
            ],
        ),
        (sliding_window_cache, [named_dims(0, 3, 4), named_dims(3, 3, 4)]),
        (static_cache, [named_dims(0, 3, 4), named_dims(3, 3, 4)]),
    ]
    with torch_export_patches(patch_transformers=True):
        for cache, expected in caches:
            with self.subTest(cache_name=type(cache).__name__):
                got = all_dynamic_shape_from_inputs(cache)
                self.assertEqual(expected, got)
127+
12128
@requires_transformers("4.52")
13129
@requires_torch("2.7.99")
14130
def test_all_dynamic_shape_from_inputs(self):
@@ -37,10 +153,10 @@ def test_all_dynamic_shape_from_inputs_dynamic_cache(self):
37153
"input_ids": {0: "d_0_0", 1: "d_0_1"},
38154
"attention_mask": {0: "d_1_0", 1: "d_1_1"},
39155
"position_ids": {0: "d_2_0", 1: "d_2_1"},
40-
"past_key_values": {
41-
"key_cache": [{0: "d_3_0", 1: "d_3_1", 2: "d_3_2", 3: "d_3_3"}],
42-
"value_cache": [{0: "d_4_0", 1: "d_4_1", 2: "d_4_2", 3: "d_4_3"}],
43-
},
156+
"past_key_values": [
157+
[{0: "d_3_0", 1: "d_3_1", 2: "d_3_2", 3: "d_3_3"}],
158+
[{0: "d_4_0", 1: "d_4_1", 2: "d_4_2", 3: "d_4_3"}],
159+
],
44160
},
45161
ds,
46162
)

0 commit comments

Comments
 (0)