Commit 27a5d7b

code

1 parent 8b22076 commit 27a5d7b

3 files changed: +263 -3 lines changed

_unittests/ut_xrun_doc/test_helpers.py

Lines changed: 149 additions & 1 deletion
@@ -1,10 +1,32 @@
+import inspect
 import unittest
 import numpy as np
+import ml_dtypes
 import onnx
 import onnx.helper as oh
 import torch
 from onnx_diagnostic.ext_test_case import ExtTestCase, skipif_ci_windows
-from onnx_diagnostic.helpers import string_type, string_sig, pretty_onnx, get_onnx_signature
+from onnx_diagnostic.helpers import (
+    string_type,
+    string_sig,
+    pretty_onnx,
+    get_onnx_signature,
+    flatten_object,
+    max_diff,
+    type_info,
+    size_type,
+    onnx_dtype_name,
+    string_signature,
+    make_hash,
+    onnx_dtype_to_torch_dtype,
+    np_dtype_to_tensor_dtype,
+    torch_dtype_to_onnx_dtype,
+    from_array_extended,
+    convert_endian,
+    from_array_ml_dtypes,
+    dtype_to_tensor_dtype,
+    string_diff,
+)

 TFLOAT = onnx.TensorProto.FLOAT

@@ -91,6 +113,132 @@ def test_get_onnx_signature(self):
         sig = get_onnx_signature(proto)
         self.assertEqual(sig, (("X", 1, (1, "b", "c")), ("Y", 1, ("a", "b", "c"))))

+    def test_flatten(self):
+        inputs = (
+            torch.rand((3, 4), dtype=torch.float16),
+            [
+                torch.rand((5, 6), dtype=torch.float16),
+                torch.rand((5, 6, 7), dtype=torch.float16),
+            ],
+        )
+        flat = flatten_object(inputs)
+        diff = max_diff(inputs, flat, flatten=True)
+        self.assertEqual(diff["abs"], 0)
+        d = string_diff(diff)
+        print(d)
+
+    def test_type_info(self):
+        for tt in [
+            onnx.TensorProto.FLOAT,
+            onnx.TensorProto.FLOAT16,
+            onnx.TensorProto.DOUBLE,
+            onnx.TensorProto.BFLOAT16,
+            onnx.TensorProto.INT32,
+            onnx.TensorProto.INT64,
+        ]:
+            type_info(tt, "min")
+            type_info(tt, "max")
+
+    def test_size_type_onnx(self):
+        for i in range(1, 40):
+            with self.subTest(i=i):
+                try:
+                    name = onnx_dtype_name(i)
+                except ValueError:
+                    continue
+                if name not in {"STRING", "UINT4", "INT4", "FLOAT4E2M1"}:
+                    size_type(i)
+
+                if name not in {
+                    "STRING",
+                    "UINT4",
+                    "INT4",
+                    "FLOAT4E2M1",
+                    "FLOAT8E5M2FNUZ",
+                    "FLOAT8E5M2",
+                    "FLOAT8E4M3FN",
+                    "FLOAT8E4M3FNUZ",
+                }:
+                    onnx_dtype_to_torch_dtype(i)
+
+    def test_size_type_numpy(self):
+        for dt in {
+            np.float32,
+            np.float64,
+            np.float16,
+            np.int32,
+            np.int64,
+            np.int8,
+            np.int16,
+            np.uint8,
+            np.uint16,
+            np.uint32,
+            np.uint64,
+        }:
+            size_type(dt)
+            np_dtype_to_tensor_dtype(dt)
+
+    def test_from_array(self):
+        for dt in {
+            np.float32,
+            np.float64,
+            np.float16,
+            np.int32,
+            np.int64,
+            np.int8,
+            np.int16,
+            np.uint8,
+            np.uint16,
+            np.uint32,
+            np.uint64,
+        }:
+            t = np.random.rand(4, 3).astype(dt)
+            proto = from_array_extended(t)
+            self.assertIsInstance(proto, onnx.TensorProto)
+            convert_endian(proto)
+            dtype_to_tensor_dtype(dt)
+
+    def test_from_array_ml_dtypes(self):
+        for dt in {
+            ml_dtypes.bfloat16,
+        }:
+            t = np.random.rand(4, 3).astype(dt)
+            from_array_ml_dtypes(t)
+            from_array_extended(t)
+
+    def test_size_type_mldtypes(self):
+        for dt in {
+            ml_dtypes.bfloat16,
+        }:
+            size_type(dt)
+            np_dtype_to_tensor_dtype(dt)
+            dtype_to_tensor_dtype(dt)
+
+    def test_size_type_torch(self):
+        for dt in {
+            torch.float32,
+            torch.float64,
+            torch.float16,
+            torch.int32,
+            torch.int64,
+            torch.int8,
+            torch.int16,
+            torch.uint8,
+            torch.uint16,
+            torch.uint32,
+            torch.uint64,
+        }:
+            size_type(dt)
+            torch_dtype_to_onnx_dtype(dt)
+            dtype_to_tensor_dtype(dt)
+
+    def test_string_signature(self):
+        sig = string_signature(inspect.signature(string_signature))
+        self.assertIn("sig: typing.Any", sig)
+
+    def test_make_hash(self):
+        self.assertIsInstance(make_hash([]), str)
+

 if __name__ == "__main__":
     unittest.main(verbosity=2)
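For context, the flatten/diff helpers exercised by test_flatten compose as follows. This is a minimal sketch inferred from the test alone: max_diff is assumed to return a dict of error statistics (only the "abs" key is confirmed by the test) and string_diff a one-line rendering of it.

import torch
from onnx_diagnostic.helpers import flatten_object, max_diff, string_diff

nested = (torch.ones((2, 3)), [torch.zeros(4), torch.ones((5, 6))])
flat = flatten_object(nested)                # nested structure -> flat sequence of tensors
diff = max_diff(nested, flat, flatten=True)  # flatten both sides, then compare element-wise
assert diff["abs"] == 0                      # identical contents, so zero max absolute error
print(string_diff(diff))                     # compact human-readable summary of the diff dict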
Lines changed: 52 additions & 0 deletions (new file)
@@ -0,0 +1,52 @@
+import unittest
+import numpy as np
+import onnx
+import onnx.helper as oh
+import onnx.numpy_helper as onh
+from onnx_diagnostic.ext_test_case import ExtTestCase
+from onnx_diagnostic.torch_test_helper import dummy_llm, check_model_ort
+
+TFLOAT = onnx.TensorProto.FLOAT
+
+
+class TestOrtSession(ExtTestCase):
+
+    def test_dummy_llm(self):
+        for cls_name in ["AttentionBlock", "MultiAttentionBlock", "DecoderLayer"]:
+            model, inputs = dummy_llm(cls_name)
+            model(*inputs)
+
+    def test_check_model_ort(self):
+        model = oh.make_model(
+            oh.make_graph(
+                [
+                    oh.make_node("Unsqueeze", ["X", "zero"], ["xu1"]),
+                    oh.make_node("Unsqueeze", ["xu1", "un"], ["xu2"]),
+                    oh.make_node("Reshape", ["xu2", "shape1"], ["xm1"]),
+                    oh.make_node("Reshape", ["Y", "shape2"], ["xm2c"]),
+                    oh.make_node("Cast", ["xm2c"], ["xm2"], to=1),
+                    oh.make_node("MatMul", ["xm1", "xm2"], ["xm"]),
+                    oh.make_node("Reshape", ["xm", "shape3"], ["Z"]),
+                ],
+                "dummy",
+                [oh.make_tensor_value_info("X", TFLOAT, [320, 1280])],
+                [oh.make_tensor_value_info("Z", TFLOAT, [3, 5, 320, 640])],
+                [
+                    onh.from_array(
+                        np.random.rand(3, 5, 1280, 640).astype(np.float32), name="Y"
+                    ),
+                    onh.from_array(np.array([0], dtype=np.int64), name="zero"),
+                    onh.from_array(np.array([1], dtype=np.int64), name="un"),
+                    onh.from_array(np.array([1, 320, 1280], dtype=np.int64), name="shape1"),
+                    onh.from_array(np.array([15, 1280, 640], dtype=np.int64), name="shape2"),
+                    onh.from_array(np.array([3, 5, 320, 640], dtype=np.int64), name="shape3"),
+                ],
+            ),
+            opset_imports=[oh.make_opsetid("", 18)],
+            ir_version=9,
+        )
+        check_model_ort(model)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
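The "dummy" graph above is small enough to verify by hand. A NumPy sketch of what it computes, with shapes taken from the initializers (an illustration, not part of the commit):

import numpy as np

X = np.random.rand(320, 1280).astype(np.float32)
Y = np.random.rand(3, 5, 1280, 640).astype(np.float32)

xu2 = X[None, None]                      # Unsqueeze(zero), Unsqueeze(un) -> (1, 1, 320, 1280)
xm1 = xu2.reshape(1, 320, 1280)          # Reshape(shape1)
xm2 = Y.reshape(15, 1280, 640)           # Reshape(shape2); Cast(to=1) keeps float32 unchanged
Z = (xm1 @ xm2).reshape(3, 5, 320, 640)  # MatMul broadcasts the batch dim -> (15, 320, 640)
assert Z.shape == (3, 5, 320, 640)       # matches the declared output "Z"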

onnx_diagnostic/helpers.py

Lines changed: 62 additions & 2 deletions
@@ -44,7 +44,15 @@ def size_type(dtype: Any) -> int:
         TensorProto.UINT16,
     }:
         return 2
-    if dtype in {TensorProto.INT8, TensorProto.UINT8, TensorProto.BOOL}:
+    if dtype in {
+        TensorProto.INT8,
+        TensorProto.UINT8,
+        TensorProto.BOOL,
+        TensorProto.FLOAT8E4M3FN,
+        TensorProto.FLOAT8E4M3FNUZ,
+        TensorProto.FLOAT8E5M2,
+        TensorProto.FLOAT8E5M2FNUZ,
+    }:
         return 1
     if dtype in {TensorProto.COMPLEX128}:
         return 16
@@ -56,6 +64,12 @@ def size_type(dtype: Any) -> int:
         return 4
     if dtype == np.float16 or dtype == np.int16:
         return 2
+    if dtype == np.int16 or dtype == np.uint16:
+        return 2
+    if dtype == np.int32 or dtype == np.uint32:
+        return 4
+    if dtype == np.int64 or dtype == np.uint64:
+        return 8
     if dtype == np.int8 or dtype == np.uint8:
         return 1
     if hasattr(np, "uint64"):
@@ -85,6 +99,10 @@ def size_type(dtype: Any) -> int:
         return 4
     if dtype in {torch.uint16}:
         return 2
+    import ml_dtypes
+
+    if dtype == ml_dtypes.bfloat16:
+        return 2
     raise AssertionError(f"Unexpected dtype={dtype}")

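With these branches, size_type now covers the ONNX float8 codes, the remaining NumPy integer widths, and ml_dtypes.bfloat16. The expected values follow directly from the branches above:

from onnx import TensorProto
import numpy as np
import torch
import ml_dtypes
from onnx_diagnostic.helpers import size_type

assert size_type(TensorProto.FLOAT8E4M3FN) == 1  # new float8 branch
assert size_type(np.uint32) == 4                 # new NumPy branch
assert size_type(np.uint64) == 8                 # new NumPy branch
assert size_type(torch.uint16) == 2              # existing torch branch
assert size_type(ml_dtypes.bfloat16) == 2        # new ml_dtypes branch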
@@ -783,7 +801,9 @@ def onnx_dtype_to_torch_dtype(itype: int) -> "torch.dtype":  # noqa: F821
         return torch.complex64
     if itype == TensorProto.COMPLEX128:
         return torch.complex128
-    raise NotImplementedError(f"Unable to convert onnx type {itype} to torch.type.")
+    raise NotImplementedError(
+        f"Unable to convert onnx type {onnx_dtype_name(itype)} to torch.type."
+    )


 def torch_dtype_to_onnx_dtype(to: "torch.dtype") -> int:  # noqa: F821
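The reworded exception reports the ONNX type name instead of its integer code. Illustrative only, assuming onnx_dtype_name(TensorProto.STRING) returns "STRING" and that STRING still has no torch equivalent:

onnx_dtype_to_torch_dtype(onnx.TensorProto.STRING)  # TensorProto.STRING == 8
# before: NotImplementedError: Unable to convert onnx type 8 to torch.type.
# after:  NotImplementedError: Unable to convert onnx type STRING to torch.type.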
@@ -807,10 +827,22 @@ def torch_dtype_to_onnx_dtype(to: "torch.dtype") -> int:  # noqa: F821
         return TensorProto.INT64
     if to == torch.int32:
         return TensorProto.INT32
+    if to == torch.uint64:
+        return TensorProto.UINT64
+    if to == torch.uint32:
+        return TensorProto.UINT32
     if to == torch.bool:
         return TensorProto.BOOL
     if to == torch.SymInt:
         return TensorProto.INT64
+    if to == torch.int16:
+        return TensorProto.INT16
+    if to == torch.uint16:
+        return TensorProto.UINT16
+    if to == torch.int8:
+        return TensorProto.INT8
+    if to == torch.uint8:
+        return TensorProto.UINT8
     if to == torch.SymFloat:
         return TensorProto.FLOAT
     if to == torch.complex64:
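The added branches fill in the small integer types on the torch-to-ONNX side; a sketch consistent with test_size_type_torch above, where the round-trip assertion assumes onnx_dtype_to_torch_dtype maps UINT16 back to torch.uint16:

from onnx import TensorProto
import torch
from onnx_diagnostic.helpers import onnx_dtype_to_torch_dtype, torch_dtype_to_onnx_dtype

assert torch_dtype_to_onnx_dtype(torch.uint16) == TensorProto.UINT16  # new branch
assert torch_dtype_to_onnx_dtype(torch.int8) == TensorProto.INT8      # new branch
assert onnx_dtype_to_torch_dtype(TensorProto.UINT16) == torch.uint16  # assumed round-trip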
@@ -859,6 +891,34 @@ def np_dtype_to_tensor_dtype(dt: np.dtype) -> int:  # noqa: F821
         return TensorProto.FLOAT8E5M2
     if dt == ml_dtypes.float8_e5m2fnuz:
         return TensorProto.FLOAT8E5M2FNUZ
+    if dt == np.float32:
+        return TensorProto.FLOAT
+    if dt == np.float16:
+        return TensorProto.FLOAT16
+    if dt == np.float64:
+        return TensorProto.DOUBLE
+    if dt == np.int64:
+        return TensorProto.INT64
+    if dt == np.uint64:
+        return TensorProto.UINT64
+    if dt == np.int16:
+        return TensorProto.INT16
+    if dt == np.uint16:
+        return TensorProto.UINT16
+    if dt == np.int32:
+        return TensorProto.INT32
+    if dt == np.int8:
+        return TensorProto.INT8
+    if dt == np.uint8:
+        return TensorProto.UINT8
+    if dt == np.uint32:
+        return TensorProto.UINT32
+    if dt == np.bool_:
+        return TensorProto.BOOL
+    if dt == np.complex64:
+        return TensorProto.COMPLEX64
+    if dt == np.complex128:
+        return TensorProto.COMPLEX128
     raise ValueError(f"Unable to convert type {dt}")

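These fallbacks let np_dtype_to_tensor_dtype accept plain NumPy scalar types directly, not only the ml_dtypes ones handled just above; the expected codes follow line by line from the new branches:

import numpy as np
from onnx import TensorProto
from onnx_diagnostic.helpers import np_dtype_to_tensor_dtype

assert np_dtype_to_tensor_dtype(np.float32) == TensorProto.FLOAT
assert np_dtype_to_tensor_dtype(np.uint16) == TensorProto.UINT16
assert np_dtype_to_tensor_dtype(np.complex128) == TensorProto.COMPLEX128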