Skip to content

Commit f159302

Browse files
authored
fix MiniOnnxBuilder (#301)
* fix MiniOnnxBuilder * fix subtypes * final fix
1 parent f738021 commit f159302

File tree

3 files changed

+89
-18
lines changed

3 files changed

+89
-18
lines changed

_unittests/ut_helpers/test_mini_onnx_builder.py

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import unittest
22
import numpy as np
3+
import ml_dtypes
34
import onnx
45
import torch
56
from onnx_diagnostic.ext_test_case import ExtTestCase
@@ -59,7 +60,7 @@ def test_mini_onnx_builder_sequence_ort_randomize(self):
5960
self.assertEqual((2,), got[1].shape)
6061
self.assertEqual(np.float32, got[1].dtype)
6162

62-
def test_mini_onnx_builder(self):
63+
def test_mini_onnx_builder1(self):
6364
data = [
6465
(
6566
np.array([1, 2], dtype=np.int64),
@@ -153,6 +154,69 @@ def test_mini_onnx_builder(self):
153154
restored = create_input_tensors_from_onnx_model(model)
154155
self.assertEqualAny(inputs, restored)
155156

157+
def test_mini_onnx_builder2(self):
158+
data = [
159+
[],
160+
{},
161+
tuple(),
162+
dict(one=np.array([1, 2], dtype=np.int64)),
163+
[np.array([1, 2], dtype=np.int64)],
164+
(np.array([1, 2], dtype=np.int64),),
165+
[np.array([1, 2], dtype=np.int64), 1],
166+
dict(one=np.array([1, 2], dtype=np.int64), two=1),
167+
(np.array([1, 2], dtype=np.int64), 1),
168+
]
169+
170+
for inputs in data:
171+
with self.subTest(types=string_type(inputs)):
172+
model = create_onnx_model_from_input_tensors(inputs)
173+
restored = create_input_tensors_from_onnx_model(model)
174+
self.assertEqualAny(inputs, restored)
175+
176+
def test_mini_onnx_builder_type1(self):
177+
data = [
178+
np.array([1], dtype=dtype)
179+
for dtype in [
180+
np.float32,
181+
np.double,
182+
np.float16,
183+
np.int32,
184+
np.int64,
185+
np.int16,
186+
np.int8,
187+
np.uint8,
188+
np.uint16,
189+
np.uint32,
190+
np.uint64,
191+
np.bool_,
192+
]
193+
]
194+
195+
for inputs in data:
196+
with self.subTest(types=string_type(inputs)):
197+
model = create_onnx_model_from_input_tensors(inputs)
198+
restored = create_input_tensors_from_onnx_model(model)
199+
self.assertEqualAny(inputs, restored)
200+
201+
def test_mini_onnx_builder_type2(self):
202+
data = [
203+
np.array([1], dtype=dtype)
204+
for dtype in [
205+
ml_dtypes.bfloat16,
206+
ml_dtypes.float8_e4m3fn,
207+
ml_dtypes.float8_e4m3fnuz,
208+
ml_dtypes.float8_e5m2,
209+
ml_dtypes.float8_e5m2fnuz,
210+
ml_dtypes.int4,
211+
]
212+
]
213+
214+
for inputs in data:
215+
with self.subTest(types=string_type(inputs)):
216+
model = create_onnx_model_from_input_tensors(inputs)
217+
restored = create_input_tensors_from_onnx_model(model)
218+
self.assertEqualAny(inputs, restored)
219+
156220
def test_mini_onnx_builder_transformers(self):
157221
cache = make_dynamic_cache([(torch.ones((3, 3)), torch.ones((3, 3)) * 2)])
158222
dc = CacheKeyValue(cache)
@@ -181,7 +245,7 @@ def test_mini_onnx_builder_transformers_sep(self):
181245
restored = create_input_tensors_from_onnx_model(model, sep="#")
182246
self.assertEqualAny(inputs, restored)
183247

184-
def test_specific_data(self):
248+
def test_mini_onnx_bulder_specific_data(self):
185249
data = {
186250
("amain", 0, "I"): (
187251
(
@@ -221,6 +285,8 @@ def test_specific_data(self):
221285
names,
222286
)
223287
restored = create_input_tensors_from_onnx_model(model)
288+
self.assertEqual(len(data), len(restored))
289+
self.assertEqual(list(data), list(restored))
224290
self.assertEqualAny(data, restored)
225291

226292

onnx_diagnostic/helpers/mini_onnx_builder.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -563,6 +563,12 @@ def _make(ty: type, res: Any) -> Any:
563563
return d
564564
return ty(res)
565565

566+
if end and len(res) == 1:
567+
if res[0] is None:
568+
return next_pos, ty()
569+
if isinstance(res[0], tuple) and len(res[0]) == 2 and res[0] == ("dict.", None):
570+
return next_pos, ty()
571+
return next_pos, _make(ty, res)
566572
return next_pos, (
567573
ty() if len(res) == 1 and res[0] in (("dict.", None), None) else _make(ty, res)
568574
)
@@ -639,6 +645,8 @@ def create_input_tensors_from_onnx_model(
639645
return float(output[0])
640646
if name == "tensor":
641647
return torch.from_numpy(output).to(device)
642-
raise AssertionError(f"Unexpected name {name!r} in {names}")
648+
assert name.startswith(
649+
("list_", "list.", "dict.", "tuple_", "tuple.")
650+
), f"Unexpected name {name!r} in {names}"
643651

644652
return _unflatten(sep, names, got, device=device)[1]

onnx_diagnostic/helpers/onnx_helper.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -671,21 +671,18 @@ def np_dtype_to_tensor_dtype(dt: np.dtype) -> int: # noqa: F821
671671
try:
672672
return oh.np_dtype_to_tensor_dtype(dt)
673673
except ValueError:
674-
try:
675-
import ml_dtypes
676-
except ImportError:
677-
ml_dtypes = None # type: ignore
678-
if ml_dtypes is not None:
679-
if dt == ml_dtypes.bfloat16:
680-
return TensorProto.BFLOAT16
681-
if dt == ml_dtypes.float8_e4m3fn:
682-
return TensorProto.FLOAT8E4M3FN
683-
if dt == ml_dtypes.float8_e4m3fnuz:
684-
return TensorProto.FLOAT8E4M3FNUZ
685-
if dt == ml_dtypes.float8_e5m2:
686-
return TensorProto.FLOAT8E5M2
687-
if dt == ml_dtypes.float8_e5m2fnuz:
688-
return TensorProto.FLOAT8E5M2FNUZ
674+
import ml_dtypes
675+
676+
if dt == ml_dtypes.bfloat16:
677+
return TensorProto.BFLOAT16
678+
if dt == ml_dtypes.float8_e4m3fn:
679+
return TensorProto.FLOAT8E4M3FN
680+
if dt == ml_dtypes.float8_e4m3fnuz:
681+
return TensorProto.FLOAT8E4M3FNUZ
682+
if dt == ml_dtypes.float8_e5m2:
683+
return TensorProto.FLOAT8E5M2
684+
if dt == ml_dtypes.float8_e5m2fnuz:
685+
return TensorProto.FLOAT8E5M2FNUZ
689686
if dt == np.float32:
690687
return TensorProto.FLOAT
691688
if dt == np.float16:

0 commit comments

Comments
 (0)