Skip to content

Commit 74f18cc

Browse files
committed
fix MiniOnnxBuilder
1 parent f738021 commit 74f18cc

File tree

2 files changed

+66
-2
lines changed

2 files changed

+66
-2
lines changed

_unittests/ut_helpers/test_mini_onnx_builder.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import unittest
22
import numpy as np
3+
import ml_dtypes
34
import onnx
45
import torch
56
from onnx_diagnostic.ext_test_case import ExtTestCase
@@ -59,7 +60,7 @@ def test_mini_onnx_builder_sequence_ort_randomize(self):
5960
self.assertEqual((2,), got[1].shape)
6061
self.assertEqual(np.float32, got[1].dtype)
6162

62-
def test_mini_onnx_builder(self):
63+
def test_mini_onnx_builder1(self):
6364
data = [
6465
(
6566
np.array([1, 2], dtype=np.int64),
@@ -153,6 +154,61 @@ def test_mini_onnx_builder(self):
153154
restored = create_input_tensors_from_onnx_model(model)
154155
self.assertEqualAny(inputs, restored)
155156

157+
def test_mini_onnx_builder2(self):
158+
data = [
159+
[],
160+
{},
161+
tuple(),
162+
dict(one=np.array([1, 2], dtype=np.int64)),
163+
[np.array([1, 2], dtype=np.int64)],
164+
(np.array([1, 2], dtype=np.int64),),
165+
[np.array([1, 2], dtype=np.int64), 1],
166+
dict(one=np.array([1, 2], dtype=np.int64), two=1),
167+
(np.array([1, 2], dtype=np.int64), 1),
168+
]
169+
170+
for inputs in data:
171+
with self.subTest(types=string_type(inputs)):
172+
model = create_onnx_model_from_input_tensors(inputs)
173+
restored = create_input_tensors_from_onnx_model(model)
174+
self.assertEqualAny(inputs, restored)
175+
176+
def test_mini_onnx_builder_type1(self):
177+
data = [
178+
np.array([1], dtype=dtype)
179+
for dtype in [
180+
np.float32,
181+
np.double,
182+
np.float16,
183+
np.int32,
184+
np.int64,
185+
np.int16,
186+
np.int8,
187+
np.uint8,
188+
np.uint16,
189+
np.uint32,
190+
np.uint64,
191+
np.bool_,
192+
]
193+
]
194+
195+
for inputs in data:
196+
with self.subTest(types=string_type(inputs)):
197+
model = create_onnx_model_from_input_tensors(inputs)
198+
restored = create_input_tensors_from_onnx_model(model)
199+
self.assertEqualAny(inputs, restored)
200+
201+
def test_mini_onnx_builder_type2(self):
202+
data = [
203+
np.array([1], dtype=dtype) for dtype in [ml_dtypes.bfloat16, ml_dtypes.float8_e3m4]
204+
]
205+
206+
for inputs in data:
207+
with self.subTest(types=string_type(inputs)):
208+
model = create_onnx_model_from_input_tensors(inputs)
209+
restored = create_input_tensors_from_onnx_model(model)
210+
self.assertEqualAny(inputs, restored)
211+
156212
def test_mini_onnx_builder_transformers(self):
157213
cache = make_dynamic_cache([(torch.ones((3, 3)), torch.ones((3, 3)) * 2)])
158214
dc = CacheKeyValue(cache)

onnx_diagnostic/helpers/mini_onnx_builder.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -563,6 +563,12 @@ def _make(ty: type, res: Any) -> Any:
563563
return d
564564
return ty(res)
565565

566+
if end and len(res) == 1:
567+
if res[0] is None:
568+
return next_pos, ty()
569+
if isinstance(res[0], tuple) and len(res[0]) == 2 and res[0] == ("dict.", None):
570+
return next_pos, ty()
571+
return next_pos, ty(res)
566572
return next_pos, (
567573
ty() if len(res) == 1 and res[0] in (("dict.", None), None) else _make(ty, res)
568574
)
@@ -639,6 +645,8 @@ def create_input_tensors_from_onnx_model(
639645
return float(output[0])
640646
if name == "tensor":
641647
return torch.from_numpy(output).to(device)
642-
raise AssertionError(f"Unexpected name {name!r} in {names}")
648+
assert name.startswith(
649+
("list_", "list.", "dict.", "tuple_", "tuple.")
650+
), f"Unexpected name {name!r} in {names}"
643651

644652
return _unflatten(sep, names, got, device=device)[1]

0 commit comments

Comments
 (0)