Skip to content

Commit 8b22076

Browse files
committed
extend code coverage
1 parent 69acc8d commit 8b22076

File tree

3 files changed

+124
-6
lines changed

3 files changed

+124
-6
lines changed

_unittests/ut_xrun_doc/test_helpers.py

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
import unittest
22
import numpy as np
3+
import onnx
4+
import onnx.helper as oh
5+
import torch
36
from onnx_diagnostic.ext_test_case import ExtTestCase, skipif_ci_windows
4-
from onnx_diagnostic.helpers import string_type, string_sig
7+
from onnx_diagnostic.helpers import string_type, string_sig, pretty_onnx, get_onnx_signature
8+
9+
TFLOAT = onnx.TensorProto.FLOAT
510

611

712
class TestHelpers(ExtTestCase):
@@ -19,8 +24,6 @@ def test_string_dict(self):
1924
self.assertEqual(s, "dict(a:A1r1,b:dict(r:float),c:{int})")
2025

2126
def test_string_type_array(self):
22-
import torch
23-
2427
a = np.array([1], dtype=np.float32)
2528
t = torch.tensor([1])
2629
obj = {"a": a, "b": t}
@@ -30,22 +33,64 @@ def test_string_type_array(self):
3033
self.assertEqual(s, "dict(a:A1s1,b:T7s1)")
3134

3235
def test_string_sig_f(self):
33-
3436
def f(a, b=3, c=4, e=5):
3537
pass
3638

3739
ssig = string_sig(f, {"a": 1, "c": 8, "b": 3})
3840
self.assertEqual(ssig, "f(a=1, c=8)")
3941

4042
def test_string_sig_cls(self):
41-
4243
class A:
4344
def __init__(self, a, b=3, c=4, e=5):
4445
self.a, self.b, self.c, self.e = a, b, c, e
4546

4647
ssig = string_sig(A(1, c=8))
4748
self.assertEqual(ssig, "A(a=1, c=8)")
4849

50+
def test_pretty_onnx(self):
51+
proto = oh.make_model(
52+
oh.make_graph(
53+
[
54+
oh.make_node("Sigmoid", ["Y"], ["sy"]),
55+
oh.make_node("Mul", ["Y", "sy"], ["ysy"]),
56+
oh.make_node("Mul", ["X", "ysy"], ["final"]),
57+
],
58+
"nd",
59+
[
60+
oh.make_tensor_value_info("X", TFLOAT, [1, "b", "c"]),
61+
oh.make_tensor_value_info("Y", TFLOAT, ["a", "b", "c"]),
62+
],
63+
[oh.make_tensor_value_info("final", TFLOAT, ["a", "b", "c"])],
64+
),
65+
opset_imports=[oh.make_opsetid("", 18)],
66+
ir_version=9,
67+
)
68+
pretty_onnx(proto, shape_inference=True)
69+
pretty_onnx(proto.graph.input[0])
70+
pretty_onnx(proto.graph)
71+
pretty_onnx(proto.graph.node[0])
72+
73+
def test_get_onnx_signature(self):
74+
proto = oh.make_model(
75+
oh.make_graph(
76+
[
77+
oh.make_node("Sigmoid", ["Y"], ["sy"]),
78+
oh.make_node("Mul", ["Y", "sy"], ["ysy"]),
79+
oh.make_node("Mul", ["X", "ysy"], ["final"]),
80+
],
81+
"nd",
82+
[
83+
oh.make_tensor_value_info("X", TFLOAT, [1, "b", "c"]),
84+
oh.make_tensor_value_info("Y", TFLOAT, ["a", "b", "c"]),
85+
],
86+
[oh.make_tensor_value_info("final", TFLOAT, ["a", "b", "c"])],
87+
),
88+
opset_imports=[oh.make_opsetid("", 18)],
89+
ir_version=9,
90+
)
91+
sig = get_onnx_signature(proto)
92+
self.assertEqual(sig, (("X", 1, (1, "b", "c")), ("Y", 1, ("a", "b", "c"))))
93+
4994

5095
if __name__ == "__main__":
5196
unittest.main(verbosity=2)
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import unittest
2+
import numpy as np
3+
import onnx.helper as oh
4+
import onnx.numpy_helper as onh
5+
from onnx import TensorProto
6+
from onnx.checker import check_model
7+
from onnx_diagnostic.ext_test_case import ExtTestCase
8+
from onnx_diagnostic.onnx_tools import onnx_lighten, onnx_unlighten, onnx_find
9+
from onnx_diagnostic.torch_test_helper import check_model_ort
10+
11+
TFLOAT = TensorProto.FLOAT
12+
13+
14+
class TestOnnxTools(ExtTestCase):
15+
16+
def _get_model(self):
17+
model = oh.make_model(
18+
oh.make_graph(
19+
[
20+
oh.make_node("Unsqueeze", ["X", "zero"], ["xu1"]),
21+
oh.make_node("Unsqueeze", ["xu1", "un"], ["xu2"]),
22+
oh.make_node("Reshape", ["xu2", "shape1"], ["xm1"]),
23+
oh.make_node("Reshape", ["Y", "shape2"], ["xm2c"]),
24+
oh.make_node("Cast", ["xm2c"], ["xm2"], to=1),
25+
oh.make_node("MatMul", ["xm1", "xm2"], ["xm"]),
26+
oh.make_node("Reshape", ["xm", "shape3"], ["Z"]),
27+
],
28+
"dummy",
29+
[oh.make_tensor_value_info("X", TFLOAT, [320, 1280])],
30+
[oh.make_tensor_value_info("Z", TFLOAT, [3, 5, 320, 640])],
31+
[
32+
onh.from_array(
33+
np.random.rand(3, 5, 1280, 640).astype(np.float32), name="Y"
34+
),
35+
onh.from_array(np.array([0], dtype=np.int64), name="zero"),
36+
onh.from_array(np.array([1], dtype=np.int64), name="un"),
37+
onh.from_array(np.array([1, 320, 1280], dtype=np.int64), name="shape1"),
38+
onh.from_array(np.array([15, 1280, 640], dtype=np.int64), name="shape2"),
39+
onh.from_array(np.array([3, 5, 320, 640], dtype=np.int64), name="shape3"),
40+
],
41+
),
42+
opset_imports=[oh.make_opsetid("", 18)],
43+
ir_version=9,
44+
)
45+
return model
46+
47+
def test_un_lighten_model(self):
48+
model = self._get_model()
49+
check_model(model)
50+
size1 = len(model.SerializeToString())
51+
(onx, stats), out, _ = self.capture(lambda: onnx_lighten(model, verbose=1))
52+
self.assertIsInstance(stats, dict)
53+
self.assertEqual(len(stats), 1)
54+
self.assertIsInstance(stats["Y"], dict)
55+
self.assertIn("remove initializer", out)
56+
# check_model(onx)
57+
new_model = onnx_unlighten(onx, stats)
58+
check_model(new_model)
59+
size2 = len(new_model.SerializeToString())
60+
self.assertEqual(size1, size2)
61+
check_model_ort(model)
62+
63+
def test_onnx_find(self):
64+
model = self._get_model()
65+
res = onnx_find(model, watch={"xm2"})
66+
self.assertEqual(len(res), 2)
67+
self.assertIn("xm2", res[0].output)
68+
self.assertIn("xm2", res[1].input)
69+
70+
71+
if __name__ == "__main__":
72+
unittest.main(verbosity=2)

onnx_diagnostic/helpers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
77
import numpy as np
88
import numpy.typing as npt
9+
import onnx
910
from onnx import (
1011
AttributeProto,
1112
FunctionProto,
@@ -528,7 +529,7 @@ def pretty_onnx(
528529
assert onx is not None, "onx cannot be None"
529530

530531
if shape_inference:
531-
onx = onx.shape_inference.infer_shapes(onx)
532+
onx = onnx.shape_inference.infer_shapes(onx)
532533

533534
if isinstance(onx, ValueInfoProto):
534535
name = onx.name

0 commit comments

Comments
 (0)