6 changes: 5 additions & 1 deletion CHANGELOGS.rst
@@ -1,12 +1,16 @@
Change Logs
===========

0.7.4
+++++

* :pr:`174`: changes for the next version of onnx, fixes all_dynamic_shape_from_inputs

0.7.3
+++++

* :pr:`173`: fixes function to_any for BaseModelOutput


0.7.2
+++++

8 changes: 5 additions & 3 deletions _doc/index.rst
@@ -21,7 +21,8 @@ onnx-diagnostic: investigate onnx models
The main feature is about `patches <https://github.com/sdpython/onnx-diagnostic/tree/main/onnx_diagnostic/torch_export_patches>`_:
it helps export **pytorch models into ONNX**, mostly designed for LLMs using dynamic caches.
Sources available at `github/onnx-diagnostic <https://github.com/sdpython/onnx-diagnostic/>`_.
-Patches can be enabled as follows:
+Patches can be enabled as follows with the function
+:func:`onnx_diagnostic.torch_export_patches.torch_export_patches`:

.. code-block:: python

@@ -31,7 +32,8 @@ Patches can be enabled as follows:
ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
# ...
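The snippet above is truncated by the hunk boundary. A minimal self-contained sketch of the pattern it abbreviates; the Linear model, its inputs, and the "batch" dimension below are placeholders, and only the context-manager usage comes from this PR's tests:

    import torch
    from onnx_diagnostic.torch_export_patches import torch_export_patches

    # Placeholder model and inputs; any torch.nn.Module works here.
    model = torch.nn.Linear(4, 4)
    args = (torch.randn(2, 4),)
    dynamic_shapes = ({0: torch.export.Dim("batch")},)

    # The patches are only active inside the context manager.
    with torch_export_patches(patch_transformers=True):
        ep = torch.export.export(model, args, dynamic_shapes=dynamic_shapes)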

-Dynamic shapes are difficult to guess for caches, one function
+Dynamic shapes are difficult to guess for caches; the function
+:func:`onnx_diagnostic.export.shape_helper.all_dynamic_shape_from_inputs`
 returns a structure defining all dimensions as dynamic.
 You then need to remove those which are not dynamic in your model.
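A small sketch of what the helper returns; the naming pattern ``d_<input>_<axis>`` is taken from the unit tests further down in this diff, and the single ``input_ids`` tensor is an illustration only:

    import torch
    from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs

    inputs = {"input_ids": torch.ones((2, 3), dtype=torch.int64)}
    ds = all_dynamic_shape_from_inputs(inputs)
    # Every axis of every tensor is marked dynamic with a generated name,
    # e.g. {"input_ids": {0: "d_0_0", 1: "d_0_1"}}; remove the static ones.
    print(ds)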

@@ -237,7 +239,7 @@ The function replaces dynamic dimensions defined as strings by
Older versions
==============

-* `0.7.3 <../v0.7.3/index.html>`_
+* `0.7.4 <../v0.7.4/index.html>`_
* `0.6.3 <../v0.6.3/index.html>`_
* `0.5.0 <../v0.5.0/index.html>`_
* `0.4.4 <../v0.4.4/index.html>`_
14 changes: 12 additions & 2 deletions _doc/recipes/plot_dynamic_shapes_json.py
@@ -75,11 +75,21 @@ def flatten_unflatten_like_dynamic_shapes(obj):
         value = flatten_unflatten_like_dynamic_shapes(value)
         subtrees.append(value)
         start = end
-    if spec.type is dict or spec.context:
+    if spec.type is dict:
         # This is a dictionary.
         return dict(zip(spec.context, subtrees))
     if spec.type is tuple:
         return tuple(subtrees)
-    return subtrees
+    if spec.type is list:
+        return list(subtrees)
+    if spec.context:
+        # This is a custom class with attributes.
+        # It is returned as a list.
+        return list(subtrees)
+    raise ValueError(
+        f"Unable to interpret spec type {spec.type} "
+        f"(type is {type(spec.type)}, context is {spec.context})."
+    )


def _align(inputs, ds):
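The branches above dispatch on ``spec.type`` and ``spec.context``, which come from ``torch.utils._pytree``. A small illustration of what those attributes contain for a plain dict; the values shown in comments are what the public pytree API returns:

    import torch
    from torch.utils._pytree import tree_flatten

    flat, spec = tree_flatten({"a": torch.ones(2), "b": [torch.ones(3)]})
    print(spec.type)     # <class 'dict'>
    print(spec.context)  # ['a', 'b'] -- the dictionary keys
    print(len(flat))     # 2 -- the tensor leaves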
154 changes: 154 additions & 0 deletions _scripts/test_backend_onnxruntime.py
@@ -0,0 +1,154 @@
"""
This file runs the ONNX backend tests and evaluates onnxruntime.
"""

import unittest
import warnings
from typing import Any
import numpy
import onnx.backend.base
import onnx.backend.test
import onnx.shape_inference
import onnx.version_converter
from onnx import ModelProto
from onnx.backend.base import Device, DeviceType
from onnx.defs import onnx_opset_version
import onnxruntime

# Cap the tested opset: onnxruntime usually lags the latest onnx opsets.
ORT_OPSET = max(23, onnx_opset_version() - 2)


class OnnxruntimeBackendRep(onnx.backend.base.BackendRep):
    def __init__(self, session):
        self._session = session

    def run(self, inputs, **kwargs):
        if isinstance(inputs, numpy.ndarray):
            inputs = [inputs]
        if isinstance(inputs, list):
            # onnxruntime exposes the model inputs through get_inputs().
            sess_inputs = self._session.get_inputs()
            input_names = [i.name for i in sess_inputs]
            if len(inputs) == len(input_names):
                feeds = dict(zip(input_names, inputs))
            else:
                # There are more session inputs than provided tensors:
                # align the positional inputs with the session inputs by shape.
                feeds = {}
                pos_inputs = 0
                for inp in sess_inputs:
                    shape = tuple(inp.shape)
                    if shape == inputs[pos_inputs].shape:
                        feeds[inp.name] = inputs[pos_inputs]
                        pos_inputs += 1
                        if pos_inputs >= len(inputs):
                            break
        elif isinstance(inputs, dict):
            feeds = inputs
        else:
            raise TypeError(f"Unexpected input type {type(inputs)!r}.")
        outs = self._session.run(None, feeds)
        return outs


class OnnxruntimeBackend(onnx.backend.base.Backend):
    @classmethod
    def is_compatible(cls, model) -> bool:
        return all(
            not (d.domain == "" and d.version > ORT_OPSET) for d in model.opset_import
        )

    @classmethod
    def supports_device(cls, device: str) -> bool:
        # Device(...) wraps the device string; its kind is a DeviceType.
        d = Device(device)
        if d.type == DeviceType.CPU:
            return True
        if d.type == DeviceType.CUDA:
            import torch

            return torch.cuda.is_available()
        return False

    @classmethod
    def create_inference_session(cls, model, device):
        d = Device(device)
        if d.type == DeviceType.CUDA:
            providers = ["CUDAExecutionProvider"]
        elif d.type == DeviceType.CPU:
            providers = ["CPUExecutionProvider"]
        else:
            raise ValueError(f"Unrecognized device {device!r} or {d!r}")
        # InferenceSession accepts a path, bytes, or a serialized ModelProto.
        content = model.SerializeToString() if isinstance(model, ModelProto) else model
        return onnxruntime.InferenceSession(content, providers=providers)

    @classmethod
    def prepare(cls, model: Any, device: str = "CPU", **kwargs: Any) -> OnnxruntimeBackendRep:
        if isinstance(model, onnxruntime.InferenceSession):
            return OnnxruntimeBackendRep(model)
        if isinstance(model, (str, bytes, ModelProto)):
            inf = cls.create_inference_session(model, device)
            return cls.prepare(inf, device, **kwargs)
        raise TypeError(f"Unexpected type {type(model)} for model.")

    @classmethod
    def run_model(cls, model, inputs, device=None, **kwargs):
        rep = cls.prepare(model, device, **kwargs)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return rep.run(inputs, **kwargs)

    @classmethod
    def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
        raise NotImplementedError("Unable to run the model node by node.")


dft_atol = 1e-3
stft_atol = 1e-4
ql_atol = 1e-5
backend_test = onnx.backend.test.BackendTest(
    OnnxruntimeBackend,
    __name__,
    test_kwargs={
        "test_dft": {"atol": dft_atol, "rtol": numpy.inf},
        "test_dft_axis": {"atol": dft_atol, "rtol": numpy.inf},
        "test_dft_axis_opset19": {"atol": dft_atol, "rtol": numpy.inf},
        "test_dft_inverse": {"atol": dft_atol, "rtol": numpy.inf},
        "test_dft_inverse_opset19": {"atol": dft_atol, "rtol": numpy.inf},
        "test_dft_opset19": {"atol": dft_atol, "rtol": numpy.inf},
        "test_stft": {"atol": stft_atol, "rtol": numpy.inf},
        "test_stft_with_window": {"atol": stft_atol, "rtol": numpy.inf},
        "test_qlinearmatmul_2D_int8_float32": {"atol": ql_atol},
        "test_qlinearmatmul_3D_int8_float32": {"atol": ql_atol},
    },
)

# The following model tests (Conv-heavy) are too slow and are excluded.
backend_test.exclude(
    "(test_bvlc_alexnet"
    "|test_densenet121"
    "|test_inception_v1"
    "|test_inception_v2"
    "|test_resnet50"
    "|test_shufflenet"
    "|test_squeezenet"
    "|test_vgg19"
    "|test_zfnet512)"
)

# The following tests cannot pass because they consist of generating random numbers.
backend_test.exclude("(test_bernoulli|test_PoissonNLLLLoss)")

# The following tests are not supported.
backend_test.exclude("test_gradient")

backend_test.exclude("(test_adagrad|test_adam|test_add_uint8)")


# import all test cases at global scope to make them visible to python.unittest
globals().update(backend_test.test_cases)

if __name__ == "__main__":
    res = unittest.main(verbosity=2, exit=False)
    tests_run = res.result.testsRun
    errors = len(res.result.errors)
    skipped = len(res.result.skipped)
    unexpected_successes = len(res.result.unexpectedSuccesses)
    expected_failures = len(res.result.expectedFailures)
    print("---------------------------------")
    print(
        f"tests_run={tests_run} errors={errors} skipped={skipped} "
        f"unexpected_successes={unexpected_successes} "
        f"expected_failures={expected_failures}"
    )
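Not part of the script itself, but a minimal way to sanity-check the backend class by hand; the tiny Add model below is an illustration only:

    import numpy as np
    from onnx import TensorProto, helper

    node = helper.make_node("Add", ["X", "Y"], ["Z"])
    graph = helper.make_graph(
        [node],
        "add_graph",
        [
            helper.make_tensor_value_info("X", TensorProto.FLOAT, [2]),
            helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2]),
        ],
        [helper.make_tensor_value_info("Z", TensorProto.FLOAT, [2])],
    )
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])
    x = np.array([1.0, 2.0], dtype=np.float32)
    y = np.array([3.0, 4.0], dtype=np.float32)
    print(OnnxruntimeBackend.run_model(model, [x, y], device="CPU"))  # [array([4., 6.], ...)]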
124 changes: 120 additions & 4 deletions _unittests/ut_export/test_shape_helper.py
@@ -5,10 +5,126 @@
    all_dynamic_shape_from_inputs,
    guess_dynamic_shapes_from_inputs,
)
from onnx_diagnostic.helpers.cache_helper import (
    make_dynamic_cache,
    make_sliding_window_cache,
    make_encoder_decoder_cache,
    make_static_cache,
)
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
from onnx_diagnostic.torch_export_patches import torch_export_patches


class TestShapeHelper(ExtTestCase):

    @requires_transformers("4.52")
    @requires_torch("2.7.99")
    def test_all_dynamic_shape_from_cache(self):
        cache = make_dynamic_cache([(torch.ones((2, 2)), (torch.ones((2, 2)) * 2))])
        ds = all_dynamic_shape_from_inputs(cache)
        self.assertEqual([[{0: "d_0_0", 1: "d_0_1"}], [{0: "d_1_0", 1: "d_1_1"}]], ds)

    @requires_torch("2.7.99")
    def test_all_dynamic_shape_all_transformers_cache(self):
        caches = [
            (
                make_dynamic_cache([(torch.ones((2, 2)), (torch.ones((2, 2)) * 2))]),
                [[{0: "d_0_0", 1: "d_0_1"}], [{0: "d_1_0", 1: "d_1_1"}]],
            ),
            (
                make_encoder_decoder_cache(
                    make_dynamic_cache(
                        [
                            (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
                            (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
                            (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
                        ]
                    ),
                    make_dynamic_cache(
                        [
                            (torch.rand((5, 5, 5)), torch.rand((5, 5, 5))),
                            (torch.rand((5, 5, 5)), torch.rand((5, 5, 5))),
                            (torch.rand((5, 5, 5)), torch.rand((5, 5, 5))),
                        ]
                    ),
                ),
                [
                    [
                        [
                            {0: "d_0_0", 1: "d_0_1", 2: "d_0_2"},
                            {0: "d_1_0", 1: "d_1_1", 2: "d_1_2"},
                            {0: "d_2_0", 1: "d_2_1", 2: "d_2_2"},
                        ],
                        [
                            {0: "d_3_0", 1: "d_3_1", 2: "d_3_2"},
                            {0: "d_4_0", 1: "d_4_1", 2: "d_4_2"},
                            {0: "d_5_0", 1: "d_5_1", 2: "d_5_2"},
                        ],
                    ],
                    [
                        [
                            {0: "d_6_0", 1: "d_6_1", 2: "d_6_2"},
                            {0: "d_7_0", 1: "d_7_1", 2: "d_7_2"},
                            {0: "d_8_0", 1: "d_8_1", 2: "d_8_2"},
                        ],
                        [
                            {0: "d_9_0", 1: "d_9_1", 2: "d_9_2"},
                            {0: "d_10_0", 1: "d_10_1", 2: "d_10_2"},
                            {0: "d_11_0", 1: "d_11_1", 2: "d_11_2"},
                        ],
                    ],
                ],
            ),
            (
                make_sliding_window_cache(
                    [
                        (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
                        (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
                        (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
                    ]
                ),
                [
                    [
                        {0: "d_0_0", 1: "d_0_1", 2: "d_0_2", 3: "d_0_3"},
                        {0: "d_1_0", 1: "d_1_1", 2: "d_1_2", 3: "d_1_3"},
                        {0: "d_2_0", 1: "d_2_1", 2: "d_2_2", 3: "d_2_3"},
                    ],
                    [
                        {0: "d_3_0", 1: "d_3_1", 2: "d_3_2", 3: "d_3_3"},
                        {0: "d_4_0", 1: "d_4_1", 2: "d_4_2", 3: "d_4_3"},
                        {0: "d_5_0", 1: "d_5_1", 2: "d_5_2", 3: "d_5_3"},
                    ],
                ],
            ),
            (
                make_static_cache(
                    [
                        (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
                        (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
                        (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
                    ],
                    max_cache_len=15,
                ),
                [
                    [
                        {0: "d_0_0", 1: "d_0_1", 2: "d_0_2", 3: "d_0_3"},
                        {0: "d_1_0", 1: "d_1_1", 2: "d_1_2", 3: "d_1_3"},
                        {0: "d_2_0", 1: "d_2_1", 2: "d_2_2", 3: "d_2_3"},
                    ],
                    [
                        {0: "d_3_0", 1: "d_3_1", 2: "d_3_2", 3: "d_3_3"},
                        {0: "d_4_0", 1: "d_4_1", 2: "d_4_2", 3: "d_4_3"},
                        {0: "d_5_0", 1: "d_5_1", 2: "d_5_2", 3: "d_5_3"},
                    ],
                ],
            ),
        ]
        with torch_export_patches(patch_transformers=True):
            for cache, exds in caches:
                with self.subTest(cache_name=cache.__class__.__name__):
                    ds = all_dynamic_shape_from_inputs(cache)
                    self.assertEqual(exds, ds)

    @requires_transformers("4.52")
    @requires_torch("2.7.99")
    def test_all_dynamic_shape_from_inputs(self):
@@ -37,10 +153,10 @@ def test_all_dynamic_shape_from_inputs_dynamic_cache(self):
"input_ids": {0: "d_0_0", 1: "d_0_1"},
"attention_mask": {0: "d_1_0", 1: "d_1_1"},
"position_ids": {0: "d_2_0", 1: "d_2_1"},
"past_key_values": {
"key_cache": [{0: "d_3_0", 1: "d_3_1", 2: "d_3_2", 3: "d_3_3"}],
"value_cache": [{0: "d_4_0", 1: "d_4_1", 2: "d_4_2", 3: "d_4_3"}],
},
"past_key_values": [
[{0: "d_3_0", 1: "d_3_1", 2: "d_3_2", 3: "d_3_3"}],
[{0: "d_4_0", 1: "d_4_1", 2: "d_4_2", 3: "d_4_3"}],
],
},
ds,
)