70 changes: 68 additions & 2 deletions _unittests/ut_helpers/test_mini_onnx_builder.py
@@ -1,5 +1,6 @@
import unittest
import numpy as np
import ml_dtypes
import onnx
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase
@@ -59,7 +60,7 @@ def test_mini_onnx_builder_sequence_ort_randomize(self):
self.assertEqual((2,), got[1].shape)
self.assertEqual(np.float32, got[1].dtype)

def test_mini_onnx_builder(self):
def test_mini_onnx_builder1(self):
data = [
(
np.array([1, 2], dtype=np.int64),
@@ -153,6 +154,69 @@ def test_mini_onnx_builder(self):
restored = create_input_tensors_from_onnx_model(model)
self.assertEqualAny(inputs, restored)

def test_mini_onnx_builder2(self):
data = [
[],
{},
tuple(),
dict(one=np.array([1, 2], dtype=np.int64)),
[np.array([1, 2], dtype=np.int64)],
(np.array([1, 2], dtype=np.int64),),
[np.array([1, 2], dtype=np.int64), 1],
dict(one=np.array([1, 2], dtype=np.int64), two=1),
(np.array([1, 2], dtype=np.int64), 1),
]

for inputs in data:
with self.subTest(types=string_type(inputs)):
model = create_onnx_model_from_input_tensors(inputs)
restored = create_input_tensors_from_onnx_model(model)
self.assertEqualAny(inputs, restored)

def test_mini_onnx_builder_type1(self):
data = [
np.array([1], dtype=dtype)
for dtype in [
np.float32,
np.double,
np.float16,
np.int32,
np.int64,
np.int16,
np.int8,
np.uint8,
np.uint16,
np.uint32,
np.uint64,
np.bool_,
]
]

for inputs in data:
with self.subTest(types=string_type(inputs)):
model = create_onnx_model_from_input_tensors(inputs)
restored = create_input_tensors_from_onnx_model(model)
self.assertEqualAny(inputs, restored)

def test_mini_onnx_builder_type2(self):
data = [
np.array([1], dtype=dtype)
for dtype in [
ml_dtypes.bfloat16,
ml_dtypes.float8_e4m3fn,
ml_dtypes.float8_e4m3fnuz,
ml_dtypes.float8_e5m2,
ml_dtypes.float8_e5m2fnuz,
ml_dtypes.int4,
]
]

for inputs in data:
with self.subTest(types=string_type(inputs)):
model = create_onnx_model_from_input_tensors(inputs)
restored = create_input_tensors_from_onnx_model(model)
self.assertEqualAny(inputs, restored)

def test_mini_onnx_builder_transformers(self):
cache = make_dynamic_cache([(torch.ones((3, 3)), torch.ones((3, 3)) * 2)])
dc = CacheKeyValue(cache)
@@ -181,7 +245,7 @@ def test_mini_onnx_builder_transformers_sep(self):
restored = create_input_tensors_from_onnx_model(model, sep="#")
self.assertEqualAny(inputs, restored)

def test_specific_data(self):
def test_mini_onnx_bulder_specific_data(self):
data = {
("amain", 0, "I"): (
(
@@ -221,6 +285,8 @@ def test_specific_data(self):
names,
)
restored = create_input_tensors_from_onnx_model(model)
self.assertEqual(len(data), len(restored))
self.assertEqual(list(data), list(restored))
self.assertEqualAny(data, restored)


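For context, the new tests above exercise a round-trip: nested Python containers of tensors (including empty containers and ml_dtypes arrays) are serialized into a single ONNX model and then rebuilt from it. Below is a minimal sketch of that round-trip, assuming both helpers live in onnx_diagnostic.helpers.mini_onnx_builder as the second file in this diff suggests; it is an illustration, not part of the change.

import numpy as np
import ml_dtypes
from onnx_diagnostic.helpers.mini_onnx_builder import (
    create_onnx_model_from_input_tensors,
    create_input_tensors_from_onnx_model,
)

# A small nested structure mixing int64 and bfloat16 arrays.
inputs = dict(
    ids=np.array([1, 2], dtype=np.int64),
    scale=np.array([1], dtype=ml_dtypes.bfloat16),
)
model = create_onnx_model_from_input_tensors(inputs)    # ONNX model carrying the tensors
restored = create_input_tensors_from_onnx_model(model)  # same nested structure rebuilt
assert list(restored) == list(inputs)                   # key order preserved, as the new test checks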
10 changes: 9 additions & 1 deletion onnx_diagnostic/helpers/mini_onnx_builder.py
@@ -563,6 +563,12 @@ def _make(ty: type, res: Any) -> Any:
return d
return ty(res)

if end and len(res) == 1:
if res[0] is None:
return next_pos, ty()
if isinstance(res[0], tuple) and len(res[0]) == 2 and res[0] == ("dict.", None):
return next_pos, ty()
return next_pos, _make(ty, res)
return next_pos, (
ty() if len(res) == 1 and res[0] in (("dict.", None), None) else _make(ty, res)
)
@@ -639,6 +645,8 @@ def create_input_tensors_from_onnx_model(
return float(output[0])
if name == "tensor":
return torch.from_numpy(output).to(device)
raise AssertionError(f"Unexpected name {name!r} in {names}")
assert name.startswith(
("list_", "list.", "dict.", "tuple_", "tuple.")
), f"Unexpected name {name!r} in {names}"

return _unflatten(sep, names, got, device=device)[1]
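The first hunk above deals with empty containers during unflattening: an empty list, tuple or dict is represented by a single placeholder (None or the pair ("dict.", None)), and when the decoder reaches the end with only such a placeholder left, it instantiates an empty container instead of wrapping the placeholder. A simplified, self-contained sketch of that convention follows; _rebuild_container is a hypothetical name, and the real code also threads next_pos through and delegates to _make.

def _rebuild_container(ty, res):
    # One lone placeholder means the serialized container was empty.
    if len(res) == 1 and (res[0] is None or res[0] == ("dict.", None)):
        return ty()
    # Otherwise rebuild the populated container (the real code delegates to
    # _make, which also knows how to rebuild dicts from key/value pairs).
    return ty(res)

assert _rebuild_container(list, [None]) == []
assert _rebuild_container(dict, [("dict.", None)]) == {}
assert _rebuild_container(list, [1, 2]) == [1, 2]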
27 changes: 12 additions & 15 deletions onnx_diagnostic/helpers/onnx_helper.py
@@ -671,21 +671,18 @@ def np_dtype_to_tensor_dtype(dt: np.dtype) -> int:  # noqa: F821
try:
return oh.np_dtype_to_tensor_dtype(dt)
except ValueError:
try:
import ml_dtypes
except ImportError:
ml_dtypes = None # type: ignore
if ml_dtypes is not None:
if dt == ml_dtypes.bfloat16:
return TensorProto.BFLOAT16
if dt == ml_dtypes.float8_e4m3fn:
return TensorProto.FLOAT8E4M3FN
if dt == ml_dtypes.float8_e4m3fnuz:
return TensorProto.FLOAT8E4M3FNUZ
if dt == ml_dtypes.float8_e5m2:
return TensorProto.FLOAT8E5M2
if dt == ml_dtypes.float8_e5m2fnuz:
return TensorProto.FLOAT8E5M2FNUZ
import ml_dtypes

if dt == ml_dtypes.bfloat16:
return TensorProto.BFLOAT16
if dt == ml_dtypes.float8_e4m3fn:
return TensorProto.FLOAT8E4M3FN
if dt == ml_dtypes.float8_e4m3fnuz:
return TensorProto.FLOAT8E4M3FNUZ
if dt == ml_dtypes.float8_e5m2:
return TensorProto.FLOAT8E5M2
if dt == ml_dtypes.float8_e5m2fnuz:
return TensorProto.FLOAT8E5M2FNUZ
if dt == np.float32:
return TensorProto.FLOAT
if dt == np.float16:
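The rewrite above makes ml_dtypes a hard requirement once onnx.helper cannot map the dtype itself, instead of silently skipping the bfloat16/float8 branch when the package is missing. A hedged usage sketch of the resulting mapping, assuming onnx and ml_dtypes are installed:

import numpy as np
import ml_dtypes
from onnx import TensorProto
from onnx_diagnostic.helpers.onnx_helper import np_dtype_to_tensor_dtype

# Standard dtypes still resolve through onnx.helper.
assert np_dtype_to_tensor_dtype(np.dtype("float32")) == TensorProto.FLOAT
# Extended dtypes resolve through ml_dtypes.
assert np_dtype_to_tensor_dtype(np.dtype(ml_dtypes.bfloat16)) == TensorProto.BFLOAT16
assert np_dtype_to_tensor_dtype(np.dtype(ml_dtypes.float8_e5m2)) == TensorProto.FLOAT8E5M2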