From 74f18ccc0484124c7cae6a442ab4fe31c6a25ac5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Thu, 13 Nov 2025 00:02:31 +0100 Subject: [PATCH 1/3] fix MiniOnnxBuilder --- .../ut_helpers/test_mini_onnx_builder.py | 58 ++++++++++++++++++- onnx_diagnostic/helpers/mini_onnx_builder.py | 10 +++- 2 files changed, 66 insertions(+), 2 deletions(-) diff --git a/_unittests/ut_helpers/test_mini_onnx_builder.py b/_unittests/ut_helpers/test_mini_onnx_builder.py index a9843fb7..0f3fbe29 100644 --- a/_unittests/ut_helpers/test_mini_onnx_builder.py +++ b/_unittests/ut_helpers/test_mini_onnx_builder.py @@ -1,5 +1,6 @@ import unittest import numpy as np +import ml_dtypes import onnx import torch from onnx_diagnostic.ext_test_case import ExtTestCase @@ -59,7 +60,7 @@ def test_mini_onnx_builder_sequence_ort_randomize(self): self.assertEqual((2,), got[1].shape) self.assertEqual(np.float32, got[1].dtype) - def test_mini_onnx_builder(self): + def test_mini_onnx_builder1(self): data = [ ( np.array([1, 2], dtype=np.int64), @@ -153,6 +154,61 @@ def test_mini_onnx_builder(self): restored = create_input_tensors_from_onnx_model(model) self.assertEqualAny(inputs, restored) + def test_mini_onnx_builder2(self): + data = [ + [], + {}, + tuple(), + dict(one=np.array([1, 2], dtype=np.int64)), + [np.array([1, 2], dtype=np.int64)], + (np.array([1, 2], dtype=np.int64),), + [np.array([1, 2], dtype=np.int64), 1], + dict(one=np.array([1, 2], dtype=np.int64), two=1), + (np.array([1, 2], dtype=np.int64), 1), + ] + + for inputs in data: + with self.subTest(types=string_type(inputs)): + model = create_onnx_model_from_input_tensors(inputs) + restored = create_input_tensors_from_onnx_model(model) + self.assertEqualAny(inputs, restored) + + def test_mini_onnx_builder_type1(self): + data = [ + np.array([1], dtype=dtype) + for dtype in [ + np.float32, + np.double, + np.float16, + np.int32, + np.int64, + np.int16, + np.int8, + np.uint8, + np.uint16, + np.uint32, + np.uint64, + np.bool_, + ] + ] + + for inputs in data: + with self.subTest(types=string_type(inputs)): + model = create_onnx_model_from_input_tensors(inputs) + restored = create_input_tensors_from_onnx_model(model) + self.assertEqualAny(inputs, restored) + + def test_mini_onnx_builder_type2(self): + data = [ + np.array([1], dtype=dtype) for dtype in [ml_dtypes.bfloat16, ml_dtypes.float8_e3m4] + ] + + for inputs in data: + with self.subTest(types=string_type(inputs)): + model = create_onnx_model_from_input_tensors(inputs) + restored = create_input_tensors_from_onnx_model(model) + self.assertEqualAny(inputs, restored) + def test_mini_onnx_builder_transformers(self): cache = make_dynamic_cache([(torch.ones((3, 3)), torch.ones((3, 3)) * 2)]) dc = CacheKeyValue(cache) diff --git a/onnx_diagnostic/helpers/mini_onnx_builder.py b/onnx_diagnostic/helpers/mini_onnx_builder.py index c00baee9..95edf652 100644 --- a/onnx_diagnostic/helpers/mini_onnx_builder.py +++ b/onnx_diagnostic/helpers/mini_onnx_builder.py @@ -563,6 +563,12 @@ def _make(ty: type, res: Any) -> Any: return d return ty(res) + if end and len(res) == 1: + if res[0] is None: + return next_pos, ty() + if isinstance(res[0], tuple) and len(res[0]) == 2 and res[0] == ("dict.", None): + return next_pos, ty() + return next_pos, ty(res) return next_pos, ( ty() if len(res) == 1 and res[0] in (("dict.", None), None) else _make(ty, res) ) @@ -639,6 +645,8 @@ def create_input_tensors_from_onnx_model( return float(output[0]) if name == "tensor": return torch.from_numpy(output).to(device) - raise AssertionError(f"Unexpected name {name!r} in {names}") + assert name.startswith( + ("list_", "list.", "dict.", "tuple_", "tuple.") + ), f"Unexpected name {name!r} in {names}" return _unflatten(sep, names, got, device=device)[1] From 6ad93156b9f62d7513020ab3e1334c160278e2bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Thu, 13 Nov 2025 10:28:15 +0100 Subject: [PATCH 2/3] fix subtypes --- .../ut_helpers/test_mini_onnx_builder.py | 10 ++++++- onnx_diagnostic/helpers/onnx_helper.py | 27 +++++++++---------- 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/_unittests/ut_helpers/test_mini_onnx_builder.py b/_unittests/ut_helpers/test_mini_onnx_builder.py index 0f3fbe29..c8a9e27e 100644 --- a/_unittests/ut_helpers/test_mini_onnx_builder.py +++ b/_unittests/ut_helpers/test_mini_onnx_builder.py @@ -200,7 +200,15 @@ def test_mini_onnx_builder_type1(self): def test_mini_onnx_builder_type2(self): data = [ - np.array([1], dtype=dtype) for dtype in [ml_dtypes.bfloat16, ml_dtypes.float8_e3m4] + np.array([1], dtype=dtype) + for dtype in [ + ml_dtypes.bfloat16, + ml_dtypes.float8_e4m3fn, + ml_dtypes.float8_e4m3fnuz, + ml_dtypes.float8_e5m2, + ml_dtypes.float8_e5m2fnuz, + ml_dtypes.int4, + ] ] for inputs in data: diff --git a/onnx_diagnostic/helpers/onnx_helper.py b/onnx_diagnostic/helpers/onnx_helper.py index 52d2ecba..d0e0aaed 100644 --- a/onnx_diagnostic/helpers/onnx_helper.py +++ b/onnx_diagnostic/helpers/onnx_helper.py @@ -671,21 +671,18 @@ def np_dtype_to_tensor_dtype(dt: np.dtype) -> int: # noqa: F821 try: return oh.np_dtype_to_tensor_dtype(dt) except ValueError: - try: - import ml_dtypes - except ImportError: - ml_dtypes = None # type: ignore - if ml_dtypes is not None: - if dt == ml_dtypes.bfloat16: - return TensorProto.BFLOAT16 - if dt == ml_dtypes.float8_e4m3fn: - return TensorProto.FLOAT8E4M3FN - if dt == ml_dtypes.float8_e4m3fnuz: - return TensorProto.FLOAT8E4M3FNUZ - if dt == ml_dtypes.float8_e5m2: - return TensorProto.FLOAT8E5M2 - if dt == ml_dtypes.float8_e5m2fnuz: - return TensorProto.FLOAT8E5M2FNUZ + import ml_dtypes + + if dt == ml_dtypes.bfloat16: + return TensorProto.BFLOAT16 + if dt == ml_dtypes.float8_e4m3fn: + return TensorProto.FLOAT8E4M3FN + if dt == ml_dtypes.float8_e4m3fnuz: + return TensorProto.FLOAT8E4M3FNUZ + if dt == ml_dtypes.float8_e5m2: + return TensorProto.FLOAT8E5M2 + if dt == ml_dtypes.float8_e5m2fnuz: + return TensorProto.FLOAT8E5M2FNUZ if dt == np.float32: return TensorProto.FLOAT if dt == np.float16: From 5eadb4e34b2e3074e1b70d4f32115031ecb0adae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Thu, 13 Nov 2025 10:43:10 +0100 Subject: [PATCH 3/3] final fix --- _unittests/ut_helpers/test_mini_onnx_builder.py | 4 +++- onnx_diagnostic/helpers/mini_onnx_builder.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/_unittests/ut_helpers/test_mini_onnx_builder.py b/_unittests/ut_helpers/test_mini_onnx_builder.py index c8a9e27e..633e1f92 100644 --- a/_unittests/ut_helpers/test_mini_onnx_builder.py +++ b/_unittests/ut_helpers/test_mini_onnx_builder.py @@ -245,7 +245,7 @@ def test_mini_onnx_builder_transformers_sep(self): restored = create_input_tensors_from_onnx_model(model, sep="#") self.assertEqualAny(inputs, restored) - def test_specific_data(self): + def test_mini_onnx_bulder_specific_data(self): data = { ("amain", 0, "I"): ( ( @@ -285,6 +285,8 @@ def test_specific_data(self): names, ) restored = create_input_tensors_from_onnx_model(model) + self.assertEqual(len(data), len(restored)) + self.assertEqual(list(data), list(restored)) self.assertEqualAny(data, restored) diff --git a/onnx_diagnostic/helpers/mini_onnx_builder.py b/onnx_diagnostic/helpers/mini_onnx_builder.py index 95edf652..d9e526e0 100644 --- a/onnx_diagnostic/helpers/mini_onnx_builder.py +++ b/onnx_diagnostic/helpers/mini_onnx_builder.py @@ -568,7 +568,7 @@ def _make(ty: type, res: Any) -> Any: return next_pos, ty() if isinstance(res[0], tuple) and len(res[0]) == 2 and res[0] == ("dict.", None): return next_pos, ty() - return next_pos, ty(res) + return next_pos, _make(ty, res) return next_pos, ( ty() if len(res) == 1 and res[0] in (("dict.", None), None) else _make(ty, res) )