diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst
index ee206cfb..c3778de4 100644
--- a/CHANGELOGS.rst
+++ b/CHANGELOGS.rst
@@ -1,12 +1,16 @@
 Change Logs
 ===========
 
+0.7.4
++++++
+
+* :pr:`174`: changes for the next version of onnx, fixes all_dynamic_shape_from_inputs
+
 0.7.3
 +++++
 
 * :pr:`173`: fixes function to_any for BaseModelOutput
 
-
 0.7.2
 +++++
 
diff --git a/_doc/index.rst b/_doc/index.rst
index 8cdfd4e6..13ff9641 100644
--- a/_doc/index.rst
+++ b/_doc/index.rst
@@ -21,7 +21,8 @@ onnx-diagnostic: investigate onnx models
 The main feature is about `patches `_: it helps exporting
 **pytorch models into ONNX**, mostly designed for LLMs using dynamic caches.
 Sources available at `github/onnx-diagnostic `_.
 
-Patches can be enabled as follows:
+Patches can be enabled as follows with the function
+:func:`onnx_diagnostic.torch_export_patches.torch_export_patches`:
 
 .. code-block:: python
@@ -31,7 +32,8 @@ Patches can be enabled as follows:
 
     ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
     # ...
 
-Dynamic shapes are difficult to guess for caches, one function
+Dynamic shapes are difficult to guess for caches; the function
+:func:`onnx_diagnostic.export.shape_helper.all_dynamic_shape_from_inputs`
 returns a structure defining all dimensions as dynamic.
 You then need to remove those which are not dynamic in your model.
@@ -237,7 +239,7 @@ The function replaces dynamic dimensions defined as strings by
 Older versions
 ==============
 
-* `0.7.3 <../v0.7.3/index.html>`_
+* `0.7.4 <../v0.7.4/index.html>`_
 * `0.6.3 <../v0.6.3/index.html>`_
 * `0.5.0 <../v0.5.0/index.html>`_
 * `0.4.4 <../v0.4.4/index.html>`_
diff --git a/_doc/recipes/plot_dynamic_shapes_json.py b/_doc/recipes/plot_dynamic_shapes_json.py
index e995d8ca..1d00df2e 100644
--- a/_doc/recipes/plot_dynamic_shapes_json.py
+++ b/_doc/recipes/plot_dynamic_shapes_json.py
@@ -75,11 +75,21 @@ def flatten_unflatten_like_dynamic_shapes(obj):
             value = flatten_unflatten_like_dynamic_shapes(value)
             subtrees.append(value)
             start = end
-    if spec.type is dict or spec.context:
+    if spec.type is dict:
+        # This is a dictionary.
         return dict(zip(spec.context, subtrees))
     if spec.type is tuple:
         return tuple(subtrees)
-    return subtrees
+    if spec.type is list:
+        return list(subtrees)
+    if spec.context:
+        # This is a custom class with attributes.
+        # It is returned as a list.
+        return list(subtrees)
+    raise ValueError(
+        f"Unable to interpret spec type {spec.type} "
+        f"(type is {type(spec.type)}, context is {spec.context})."
+    )
 
 
 def _align(inputs, ds):
diff --git a/_scripts/test_backend_onnxruntime.py b/_scripts/test_backend_onnxruntime.py
new file mode 100644
index 00000000..48bb1777
--- /dev/null
+++ b/_scripts/test_backend_onnxruntime.py
@@ -0,0 +1,154 @@
+"""
+This file runs through the onnx backend tests and evaluates onnxruntime.
+""" + +import unittest +import warnings +from typing import Any +import numpy +import onnx.backend.base +import onnx.backend.test +import onnx.shape_inference +import onnx.version_converter +from onnx import ModelProto +from onnx.backend.base import Device, DeviceType +from onnx.defs import onnx_opset_version +import onnxruntime + +ORT_OPSET = max(23, onnx_opset_version() - 2) + + +class OnnxruntimeBackendRep(onnx.backend.base.BackendRep): + def __init__(self, session): + self._session = session + + def run(self, inputs, **kwargs): + if isinstance(inputs, numpy.ndarray): + inputs = [inputs] + if isinstance(inputs, list): + if len(inputs) == len(self._session.input_names): + feeds = dict(zip(self._session.input_names, inputs)) + else: + feeds = {} + pos_inputs = 0 + for inp, tshape in zip(self._session.input_names, self._session.input_types): + shape = tuple(d.dim_value for d in tshape.tensor_type.shape.dim) + if shape == inputs[pos_inputs].shape: + feeds[inp] = inputs[pos_inputs] + pos_inputs += 1 + if pos_inputs >= len(inputs): + break + elif isinstance(inputs, dict): + feeds = inputs + else: + raise TypeError(f"Unexpected input type {type(inputs)!r}.") + outs = self._session.run(None, feeds) + return outs + + +class OnnxruntimeBackend(onnx.backend.base.Backend): + @classmethod + def is_compatible(cls, model) -> bool: + return all(not (d.domain == "" and d.version > ORT_OPSET) for d in model.opset_import) + + @classmethod + def supports_device(cls, device: str) -> bool: + d = Device(device) + if d == DeviceType.CPU: + return True + if d == DeviceType.CUDA: + import torch + + return torch.cuda.is_available() + return False + + @classmethod + def create_inference_session(cls, model, device): + d = Device(device) + if d == DeviceType.CUDA: + providers = ["CUDAExecutionProvider"] + elif d == DeviceType.CPU: + providers = ["CPUExecutionProvider"] + else: + raise ValueError(f"Unrecognized device {device!r} or {d!r}") + return onnxruntime.InferenceSession(model.SerializeToString(), providers=providers) + + @classmethod + def prepare(cls, model: Any, device: str = "CPU", **kwargs: Any) -> OnnxruntimeBackendRep: + if isinstance(model, onnxruntime.InferenceSession): + return OnnxruntimeBackendRep(model) + if isinstance(model, (str, bytes, ModelProto)): + inf = cls.create_inference_session(model, device) + return cls.prepare(inf, device, **kwargs) + raise TypeError(f"Unexpected type {type(model)} for model.") + + @classmethod + def run_model(cls, model, inputs, device=None, **kwargs): + rep = cls.prepare(model, device, **kwargs) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + return rep.run(inputs, **kwargs) + + @classmethod + def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs): + raise NotImplementedError("Unable to run the model node by node.") + + +dft_atol = 1e-3 +stft_atol = 1e-4 +ql_atol = 1e-5 +backend_test = onnx.backend.test.BackendTest( + OnnxruntimeBackend, + __name__, + test_kwargs={ + "test_dft": {"atol": dft_atol, "rtol": numpy.inf}, + "test_dft_axis": {"atol": dft_atol, "rtol": numpy.inf}, + "test_dft_axis_opset19": {"atol": dft_atol, "rtol": numpy.inf}, + "test_dft_inverse": {"atol": dft_atol, "rtol": numpy.inf}, + "test_dft_inverse_opset19": {"atol": dft_atol, "rtol": numpy.inf}, + "test_dft_opset19": {"atol": dft_atol, "rtol": numpy.inf}, + "test_stft": {"atol": stft_atol, "rtol": numpy.inf}, + "test_stft_with_window": {"atol": stft_atol, "rtol": numpy.inf}, + "test_qlinearmatmul_2D_int8_float32": {"atol": ql_atol}, + 
"test_qlinearmatmul_3D_int8_float32": {"atol": ql_atol}, + }, +) + +# The following tests are too slow with the reference implementation (Conv). +backend_test.exclude( + "(test_bvlc_alexnet" + "|test_densenet121" + "|test_inception_v1" + "|test_inception_v2" + "|test_resnet50" + "|test_shufflenet" + "|test_squeezenet" + "|test_vgg19" + "|test_zfnet512)" +) + +# The following tests cannot pass because they consists in generating random number. +backend_test.exclude("(test_bernoulli|test_PoissonNLLLLoss)") + +# The following tests are not supported. +backend_test.exclude("test_gradient") + +backend_test.exclude("(test_adagrad|test_adam|test_add_uint8)") + + +# import all test cases at global scope to make them visible to python.unittest +globals().update(backend_test.test_cases) + +if __name__ == "__main__": + res = unittest.main(verbosity=2, exit=False) + tests_run = res.result.testsRun + errors = len(res.result.errors) + skipped = len(res.result.skipped) + unexpected_successes = len(res.result.unexpectedSuccesses) + expected_failures = len(res.result.expectedFailures) + print("---------------------------------") + print( + f"tests_run={tests_run} errors={errors} skipped={skipped} " + f"unexpected_successes={unexpected_successes} " + f"expected_failures={expected_failures}" + ) diff --git a/_unittests/ut_export/test_shape_helper.py b/_unittests/ut_export/test_shape_helper.py index 30c9581c..2917fa7b 100644 --- a/_unittests/ut_export/test_shape_helper.py +++ b/_unittests/ut_export/test_shape_helper.py @@ -5,10 +5,126 @@ all_dynamic_shape_from_inputs, guess_dynamic_shapes_from_inputs, ) +from onnx_diagnostic.helpers.cache_helper import ( + make_dynamic_cache, + make_sliding_window_cache, + make_encoder_decoder_cache, + make_static_cache, +) from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs +from onnx_diagnostic.torch_export_patches import torch_export_patches class TestShapeHelper(ExtTestCase): + + @requires_transformers("4.52") + @requires_torch("2.7.99") + def test_all_dynamic_shape_from_cache(self): + cache = make_dynamic_cache([(torch.ones((2, 2)), (torch.ones((2, 2)) * 2))]) + ds = all_dynamic_shape_from_inputs(cache) + self.assertEqual([[{0: "d_0_0", 1: "d_0_1"}], [{0: "d_1_0", 1: "d_1_1"}]], ds) + + @requires_torch("2.7.99") + def test_all_dynamic_shape_all_transformers_cache(self): + caches = [ + ( + make_dynamic_cache([(torch.ones((2, 2)), (torch.ones((2, 2)) * 2))]), + [[{0: "d_0_0", 1: "d_0_1"}], [{0: "d_1_0", 1: "d_1_1"}]], + ), + ( + make_encoder_decoder_cache( + make_dynamic_cache( + [ + (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))), + (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))), + (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))), + ] + ), + make_dynamic_cache( + [ + (torch.rand((5, 5, 5)), torch.rand((5, 5, 5))), + (torch.rand((5, 5, 5)), torch.rand((5, 5, 5))), + (torch.rand((5, 5, 5)), torch.rand((5, 5, 5))), + ] + ), + ), + [ + [ + [ + {0: "d_0_0", 1: "d_0_1", 2: "d_0_2"}, + {0: "d_1_0", 1: "d_1_1", 2: "d_1_2"}, + {0: "d_2_0", 1: "d_2_1", 2: "d_2_2"}, + ], + [ + {0: "d_3_0", 1: "d_3_1", 2: "d_3_2"}, + {0: "d_4_0", 1: "d_4_1", 2: "d_4_2"}, + {0: "d_5_0", 1: "d_5_1", 2: "d_5_2"}, + ], + ], + [ + [ + {0: "d_6_0", 1: "d_6_1", 2: "d_6_2"}, + {0: "d_7_0", 1: "d_7_1", 2: "d_7_2"}, + {0: "d_8_0", 1: "d_8_1", 2: "d_8_2"}, + ], + [ + {0: "d_9_0", 1: "d_9_1", 2: "d_9_2"}, + {0: "d_10_0", 1: "d_10_1", 2: "d_10_2"}, + {0: "d_11_0", 1: "d_11_1", 2: "d_11_2"}, + ], + ], + ], + ), + ( + make_sliding_window_cache( + [ + (torch.rand((4, 5, 6, 7)), torch.rand((4, 
+                        (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                        (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                    ]
+                ),
+                [
+                    [
+                        {0: "d_0_0", 1: "d_0_1", 2: "d_0_2", 3: "d_0_3"},
+                        {0: "d_1_0", 1: "d_1_1", 2: "d_1_2", 3: "d_1_3"},
+                        {0: "d_2_0", 1: "d_2_1", 2: "d_2_2", 3: "d_2_3"},
+                    ],
+                    [
+                        {0: "d_3_0", 1: "d_3_1", 2: "d_3_2", 3: "d_3_3"},
+                        {0: "d_4_0", 1: "d_4_1", 2: "d_4_2", 3: "d_4_3"},
+                        {0: "d_5_0", 1: "d_5_1", 2: "d_5_2", 3: "d_5_3"},
+                    ],
+                ],
+            ),
+            (
+                make_static_cache(
+                    [
+                        (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                        (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                        (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                    ],
+                    max_cache_len=15,
+                ),
+                [
+                    [
+                        {0: "d_0_0", 1: "d_0_1", 2: "d_0_2", 3: "d_0_3"},
+                        {0: "d_1_0", 1: "d_1_1", 2: "d_1_2", 3: "d_1_3"},
+                        {0: "d_2_0", 1: "d_2_1", 2: "d_2_2", 3: "d_2_3"},
+                    ],
+                    [
+                        {0: "d_3_0", 1: "d_3_1", 2: "d_3_2", 3: "d_3_3"},
+                        {0: "d_4_0", 1: "d_4_1", 2: "d_4_2", 3: "d_4_3"},
+                        {0: "d_5_0", 1: "d_5_1", 2: "d_5_2", 3: "d_5_3"},
+                    ],
+                ],
+            ),
+        ]
+        with torch_export_patches(patch_transformers=True):
+            for cache, exds in caches:
+                with self.subTest(cache_name=cache.__class__.__name__):
+                    ds = all_dynamic_shape_from_inputs(cache)
+                    self.assertEqual(exds, ds)
+
     @requires_transformers("4.52")
     @requires_torch("2.7.99")
     def test_all_dynamic_shape_from_inputs(self):
@@ -37,10 +153,10 @@ def test_all_dynamic_shape_from_inputs_dynamic_cache(self):
                 "input_ids": {0: "d_0_0", 1: "d_0_1"},
                 "attention_mask": {0: "d_1_0", 1: "d_1_1"},
                 "position_ids": {0: "d_2_0", 1: "d_2_1"},
-                "past_key_values": {
-                    "key_cache": [{0: "d_3_0", 1: "d_3_1", 2: "d_3_2", 3: "d_3_3"}],
-                    "value_cache": [{0: "d_4_0", 1: "d_4_1", 2: "d_4_2", 3: "d_4_3"}],
-                },
+                "past_key_values": [
+                    [{0: "d_3_0", 1: "d_3_1", 2: "d_3_2", 3: "d_3_3"}],
+                    [{0: "d_4_0", 1: "d_4_1", 2: "d_4_2", 3: "d_4_3"}],
+                ],
             },
             ds,
         )
diff --git a/_unittests/ut_reference/test_backend_onnxruntime_evaluator.py b/_unittests/ut_reference/test_backend_onnxruntime_evaluator.py
index 0fcef585..9f596c4f 100644
--- a/_unittests/ut_reference/test_backend_onnxruntime_evaluator.py
+++ b/_unittests/ut_reference/test_backend_onnxruntime_evaluator.py
@@ -50,11 +50,24 @@ def is_compatible(cls, model) -> bool:
     @classmethod
     def supports_device(cls, device: str) -> bool:
         d = Device(device)
-        return d.type == DeviceType.CPU
+        if d.type == DeviceType.CPU:
+            return True
+        if d.type == DeviceType.CUDA:
+            import torch
+
+            return torch.cuda.is_available()
+        return False
 
     @classmethod
-    def create_inference_session(cls, model):
-        return OnnxruntimeEvaluator(model)
+    def create_inference_session(cls, model, device):
+        d = Device(device)
+        if d.type == DeviceType.CUDA:
+            providers = ["CUDAExecutionProvider"]
+        elif d.type == DeviceType.CPU:
+            providers = ["CPUExecutionProvider"]
+        else:
+            raise ValueError(f"Unrecognized device {device!r} or {d!r}")
+        return OnnxruntimeEvaluator(model, providers=providers)
 
     @classmethod
     def prepare(
@@ -63,7 +76,7 @@ def prepare(
         if isinstance(model, OnnxruntimeEvaluator):
             return OnnxruntimeEvaluatorBackendRep(model)
         if isinstance(model, (str, bytes, ModelProto)):
-            inf = cls.create_inference_session(model)
+            inf = cls.create_inference_session(model, device)
             return cls.prepare(inf, device, **kwargs)
         raise TypeError(f"Unexpected type {type(model)} for model.")
 
diff --git a/_unittests/ut_torch_export_patches/test_patch_serialization_diffusers.py b/_unittests/ut_torch_export_patches/test_patch_serialization_diffusers.py
index 7a2475da..67db9825 100644
--- a/_unittests/ut_torch_export_patches/test_patch_serialization_diffusers.py
+++ b/_unittests/ut_torch_export_patches/test_patch_serialization_diffusers.py
@@ -49,8 +49,8 @@ def test_unet_2d_condition_output(self):
         # flatten_unflatten
         flat, _spec = torch.utils._pytree.tree_flatten(bo)
         unflat = flatten_unflatten_for_dynamic_shapes(bo, use_dict=True)
-        self.assertIsInstance(unflat, dict)
-        self.assertEqual(list(unflat), ["sample"])
+        self.assertIsInstance(unflat, list)
+        self.assertEqual("#1[T1r3]", self.string_type(unflat))
 
         # export
         class Model(torch.nn.Module):
diff --git a/_unittests/ut_torch_export_patches/test_patch_serialization_transformers.py b/_unittests/ut_torch_export_patches/test_patch_serialization_transformers.py
index f2432c99..e6860004 100644
--- a/_unittests/ut_torch_export_patches/test_patch_serialization_transformers.py
+++ b/_unittests/ut_torch_export_patches/test_patch_serialization_transformers.py
@@ -163,8 +163,8 @@ def test_base_model_output_unflatten_flatten(self):
         with torch_export_patches(patch_transformers=True):
             flat, _spec = torch.utils._pytree.tree_flatten(bo)
             unflat = flatten_unflatten_for_dynamic_shapes(bo, use_dict=True)
-            self.assertIsInstance(unflat, dict)
-            self.assertEqual(list(unflat), ["last_hidden_state"])
+            self.assertIsInstance(unflat, list)
+            self.assertEqual("#1[T1r3]", self.string_type(unflat))
 
     @ignore_warnings(UserWarning)
     def test_base_sliding_window_cache_unflatten_flatten(self):
@@ -260,8 +260,10 @@ def test_static_cache(self):
         # flatten_unflatten
         flat, _spec = torch.utils._pytree.tree_flatten(bo)
         unflat = flatten_unflatten_for_dynamic_shapes(bo, use_dict=True)
-        self.assertIsInstance(unflat, dict)
-        self.assertEqual(list(unflat), ["key_cache", "value_cache"])
+        self.assertIsInstance(unflat, list)
+        self.assertEqual(
+            "#2[#3[T1r4,T1r4,T1r4],#3[T1r4,T1r4,T1r4]]", self.string_type(unflat)
+        )
 
         # export
         class Model(torch.nn.Module):
diff --git a/onnx_diagnostic/__init__.py b/onnx_diagnostic/__init__.py
index afa1684b..c1b8822d 100644
--- a/onnx_diagnostic/__init__.py
+++ b/onnx_diagnostic/__init__.py
@@ -3,5 +3,5 @@
 Functions, classes to dig into a model when this one is right, slow, wrong...
 """
 
-__version__ = "0.7.3"
+__version__ = "0.7.4"
 __author__ = "Xavier Dupré"
diff --git a/onnx_diagnostic/export/shape_helper.py b/onnx_diagnostic/export/shape_helper.py
index 489e65f0..ad59d788 100644
--- a/onnx_diagnostic/export/shape_helper.py
+++ b/onnx_diagnostic/export/shape_helper.py
@@ -30,6 +30,77 @@ def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
         )
         ds = all_dynamic_shape_from_inputs(inputs)
         pprint.pprint(ds)
+
+    For this function to work, patches must be enabled if :epkg:`transformers`
+    does not implement the serialization functions.
+
+    .. runpython::
+        :showcode:
+
+        import pprint
+        import torch
+        from onnx_diagnostic.helpers.cache_helper import (
+            make_dynamic_cache,
+            make_encoder_decoder_cache,
+            make_mamba_cache,
+            make_sliding_window_cache,
+            make_static_cache,
+        )
+        from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
+        from onnx_diagnostic.torch_export_patches import torch_export_patches
+
+        caches = [
+            make_dynamic_cache(
+                [
+                    (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
+                    (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
+                    (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
+                ]
+            ),
+            make_encoder_decoder_cache(
+                make_dynamic_cache(
+                    [
+                        (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
+                        (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
+                        (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
+                    ]
+                ),
+                make_dynamic_cache(
+                    [
+                        (torch.rand((5, 5, 5)), torch.rand((5, 5, 5))),
+                        (torch.rand((5, 5, 5)), torch.rand((5, 5, 5))),
+                        (torch.rand((5, 5, 5)), torch.rand((5, 5, 5))),
+                    ]
+                ),
+            ),
+            make_sliding_window_cache(
+                [
+                    (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                    (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                    (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                ]
+            ),
+            make_static_cache(
+                [
+                    (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                    (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                    (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                ],
+                max_cache_len=15,
+            ),
+            make_mamba_cache(
+                [
+                    (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
+                    (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
+                    (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
+                ]
+            ),
+        ]
+
+        with torch_export_patches(patch_transformers=True):
+            for cache in caches:
+                print(f"-- {cache.__class__.__name__}")
+                pprint.pprint(all_dynamic_shape_from_inputs(cache))
     """
     if isinstance(dim_prefix, str):
         prefixes: Set[str] = set()
diff --git a/onnx_diagnostic/helpers/cache_helper.py b/onnx_diagnostic/helpers/cache_helper.py
index 99413691..820983a4 100644
--- a/onnx_diagnostic/helpers/cache_helper.py
+++ b/onnx_diagnostic/helpers/cache_helper.py
@@ -39,11 +39,21 @@ def flatten_unflatten_for_dynamic_shapes(
             subtrees.append(value)
             start = end
         if use_dict:
-            if spec.type is dict or spec.context:
+            if spec.type is dict:
                 # This is a dictionary.
                 return dict(zip(spec.context, subtrees))
             if spec.type is tuple:
                 return tuple(subtrees)
+            if spec.type is list:
+                return list(subtrees)
+            if spec.context:
+                # This is a custom class with attributes.
+                # It is returned as a list.
+                return list(subtrees)
+            raise ValueError(
+                f"Unable to interpret spec type {spec.type} "
+                f"(type is {type(spec.type)}, context is {spec.context})."
+            )
         # This is a list.
         return subtrees
diff --git a/onnx_diagnostic/reference/ops/op_cast_like.py b/onnx_diagnostic/reference/ops/op_cast_like.py
index 87acdce1..5b5d3d18 100644
--- a/onnx_diagnostic/reference/ops/op_cast_like.py
+++ b/onnx_diagnostic/reference/ops/op_cast_like.py
@@ -1,13 +1,17 @@
 from onnx.onnx_pb import TensorProto
 from onnx.reference.op_run import OpRun
-from onnx.reference.ops.op_cast import (
-    bfloat16,
-    cast_to,
-    float8e4m3fn,
-    float8e4m3fnuz,
-    float8e5m2,
-    float8e5m2fnuz,
-)
+
+try:
+    from onnx.reference.ops.op_cast import (
+        bfloat16,
+        cast_to,
+        float8e4m3fn,
+        float8e4m3fnuz,
+        float8e5m2,
+        float8e5m2fnuz,
+    )
+except ImportError:
+    # The next version of onnx does not expose these custom float types anymore.
+    from onnx.reference.ops.op_cast import cast_to
 
 from ...helpers.onnx_helper import np_dtype_to_tensor_dtype
diff --git a/onnx_diagnostic/torch_models/validate.py b/onnx_diagnostic/torch_models/validate.py
index 693de5c7..179b1f05 100644
--- a/onnx_diagnostic/torch_models/validate.py
+++ b/onnx_diagnostic/torch_models/validate.py
@@ -18,7 +18,6 @@
 from ..tasks import random_input_kwargs
 from ..torch_export_patches import torch_export_patches
 from ..torch_export_patches.patch_inputs import use_dyn_not_str
-from ..reference import TorchOnnxEvaluator
 from .hghub import get_untrained_model_with_inputs
 
 
@@ -1113,6 +1112,9 @@ def _mk(key):
             f"{providers}..., flavour={flavour!r}"
         )
 
+    if runtime != "onnxruntime":
+        from ..reference import TorchOnnxEvaluator
+
     cls_runtime = (
         (
             lambda model, providers: onnxruntime.InferenceSession(
@@ -1122,7 +1124,7 @@ def _mk(key):
         )
         if runtime == "onnxruntime"
        else (
-            lambda model, providers: TorchOnnxEvaluator(
+            lambda model, providers, _cls_=TorchOnnxEvaluator: _cls_(  # type: ignore[misc]
                 model, providers=providers, verbose=max(verbose - 1, 0)
             )
         )
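
Note: the two helpers advertised in `_doc/index.rst` are meant to be combined. The sketch below is illustrative, not code from this PR: the toy `Model` and its inputs are made up, and `use_dyn_not_str` is assumed to behave as the documentation describes (replacing string dimensions by `torch.export.Dim.DYNAMIC` before export).

```python
import torch
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


class Model(torch.nn.Module):  # hypothetical toy model
    def forward(self, x, y):
        return x + y


kwargs = {"x": torch.rand((4, 8)), "y": torch.rand((4, 8))}
# Every dimension is marked dynamic, e.g.
# {"x": {0: "d_0_0", 1: "d_0_1"}, "y": {0: "d_1_0", 1: "d_1_1"}};
# prune the entries which are static in your model.
ds = all_dynamic_shape_from_inputs(kwargs)

with torch_export_patches(patch_transformers=True):
    ep = torch.export.export(
        Model(), (), kwargs=kwargs, dynamic_shapes=use_dyn_not_str(ds)
    )
```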
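The new branches in `flatten_unflatten_like_dynamic_shapes` and `flatten_unflatten_for_dynamic_shapes` dispatch on the `TreeSpec` produced by `torch.utils._pytree.tree_flatten`. A small sketch of where `spec.type` and `spec.context` come from (this relies on torch's private pytree API, which may change between releases):

```python
import torch

inputs = {"a": torch.ones(2), "b": (torch.ones(3), [torch.ones(4)])}
flat, spec = torch.utils._pytree.tree_flatten(inputs)

print(spec.type)                    # <class 'dict'>
print(spec.context)                 # ['a', 'b'] -- the dictionary keys
print(spec.children_specs[1].type)  # <class 'tuple'> -- context is None here
```

For a pytree-registered class such as `DynamicCache` under the patches, `spec.type` is the class itself and `spec.context` is whatever the registered flatten function produced, which is why those subtrees are now returned as plain lists rather than dictionaries, matching the updated expectations in the serialization tests above.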
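The `_cls_=TorchOnnxEvaluator` default argument in `validate.py` deserves a word. Because `A if cond else B` only evaluates the chosen branch, the else-lambda is only created when `runtime != "onnxruntime"`, after the conditional import has run; the default argument then binds the class once at definition time instead of looking the name up at call time. A self-contained sketch of the same pattern, where `sqrt` stands in for the conditionally imported class:

```python
def make_runner(runtime: str):
    if runtime != "onnxruntime":
        from math import sqrt  # stands in for the conditional TorchOnnxEvaluator import
    return (
        (lambda x: x * x)
        if runtime == "onnxruntime"
        # The default argument captures sqrt when the lambda is created,
        # which only happens after the import above has run.
        else (lambda x, _f=sqrt: _f(x))
    )


print(make_runner("onnxruntime")(4.0))  # 16.0
print(make_runner("other")(16.0))       # 4.0
```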