Commit 7d5a876

better coverage
1 parent 332a3f5 commit 7d5a876

7 files changed: +195 additions, -34 deletions
Lines changed: 26 additions & 0 deletions

@@ -0,0 +1,26 @@
+import unittest
+from onnx_diagnostic.ext_test_case import ExtTestCase
+from experimental_experiment.args import get_parsed_args
+
+
+class TestHelpers(ExtTestCase):
+    def test_args(self):
+        args = get_parsed_args(
+            "plot_custom_backend_llama",
+            config=(
+                "medium",
+                "large or medium depending, large means closer to the real model",
+            ),
+            num_hidden_layers=(1, "number of hidden layers"),
+            with_mask=(0, "tries with a mask as a secondary input"),
+            optim=("", "Optimization to apply, empty string for all"),
+            description="doc",
+        )
+        self.assertEqual(args.config, "medium")
+        self.assertEqual(args.num_hidden_layers, 1)
+        self.assertEqual(args.with_mask, 0)
+        self.assertEqual(args.optim, "")
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
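The new test exercises get_parsed_args from experimental_experiment.args, which takes (default, help) pairs and returns a namespace with those defaults. A hedged sketch of driving it with explicit command-line arguments; the assumption that each keyword becomes a --flag read from sys.argv is inferred from the test, not confirmed by this commit:

    # Hedged sketch: assumes get_parsed_args builds an argparse parser from the
    # (default, help) pairs and reads sys.argv; flag names are an assumption.
    import sys
    from experimental_experiment.args import get_parsed_args

    sys.argv = ["plot_custom_backend_llama.py", "--config", "large", "--num_hidden_layers", "2"]
    args = get_parsed_args(
        "plot_custom_backend_llama",
        config=("medium", "large or medium depending, large means closer to the real model"),
        num_hidden_layers=(1, "number of hidden layers"),
        description="doc",
    )
    print(args.config, args.num_hidden_layers)  # expected: large 2, if the flag mapping holds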

_unittests/ut_xrun_doc/test_helpers.py

Lines changed: 85 additions & 0 deletions
@@ -26,6 +26,8 @@
     from_array_ml_dtypes,
     dtype_to_tensor_dtype,
     string_diff,
+    rename_dynamic_dimensions,
+    rename_dynamic_expression,
 )

 TFLOAT = onnx.TensorProto.FLOAT

@@ -241,6 +243,89 @@ def test_string_signature(self):
     def test_make_hash(self):
         self.assertIsInstance(make_hash([]), str)

+    def test_string_type_one(self):
+        self.assertEqual(string_type(None), "None")
+        self.assertEqual(string_type([4]), "#1[int]")
+        self.assertEqual(string_type((4, 5)), "(int,int)")
+        self.assertEqual(string_type([4] * 100), "#100[int,...]")
+        self.assertEqual(string_type((4,) * 100), "#100(int,...)")
+
+    def test_string_type_at(self):
+        self.assertEqual(string_type(None), "None")
+        a = np.array([4, 5], dtype=np.float32)
+        t = torch.tensor([4, 5], dtype=torch.float32)
+        self.assertEqual(string_type([a]), "#1[A1r1]")
+        self.assertEqual(string_type([t]), "#1[T1r1]")
+        self.assertEqual(string_type((a,)), "(A1r1,)")
+        self.assertEqual(string_type((t,)), "(T1r1,)")
+        self.assertEqual(string_type([a] * 100), "#100[A1r1,...]")
+        self.assertEqual(string_type([t] * 100), "#100[T1r1,...]")
+        self.assertEqual(string_type((a,) * 100), "#100(A1r1,...)")
+        self.assertEqual(string_type((t,) * 100), "#100(T1r1,...)")
+
+    def test_string_type_at_with_shape(self):
+        self.assertEqual(string_type(None), "None")
+        a = np.array([4, 5], dtype=np.float32)
+        t = torch.tensor([4, 5], dtype=torch.float32)
+        self.assertEqual(string_type([a], with_shape=True), "#1[A1s2]")
+        self.assertEqual(string_type([t], with_shape=True), "#1[T1s2]")
+        self.assertEqual(string_type((a,), with_shape=True), "(A1s2,)")
+        self.assertEqual(string_type((t,), with_shape=True), "(T1s2,)")
+        self.assertEqual(string_type([a] * 100, with_shape=True), "#100[A1s2,...]")
+        self.assertEqual(string_type([t] * 100, with_shape=True), "#100[T1s2,...]")
+        self.assertEqual(string_type((a,) * 100, with_shape=True), "#100(A1s2,...)")
+        self.assertEqual(string_type((t,) * 100, with_shape=True), "#100(T1s2,...)")
+
+    def test_string_type_at_with_shape_min_max(self):
+        self.assertEqual(string_type(None), "None")
+        a = np.array([4, 5], dtype=np.float32)
+        t = torch.tensor([4, 5], dtype=torch.float32)
+        self.assertEqual(
+            string_type([a], with_shape=True, with_min_max=True), "#1[A1s2[4.0,5.0:A4.5]]"
+        )
+        self.assertEqual(
+            string_type([t], with_shape=True, with_min_max=True), "#1[T1s2[4.0,5.0:A4.5]]"
+        )
+        self.assertEqual(
+            string_type((a,), with_shape=True, with_min_max=True), "(A1s2[4.0,5.0:A4.5],)"
+        )
+        self.assertEqual(
+            string_type((t,), with_shape=True, with_min_max=True), "(T1s2[4.0,5.0:A4.5],)"
+        )
+        self.assertEqual(
+            string_type([a] * 100, with_shape=True, with_min_max=True),
+            "#100[A1s2[4.0,5.0:A4.5],...]",
+        )
+        self.assertEqual(
+            string_type([t] * 100, with_shape=True, with_min_max=True),
+            "#100[T1s2[4.0,5.0:A4.5],...]",
+        )
+        self.assertEqual(
+            string_type((a,) * 100, with_shape=True, with_min_max=True),
+            "#100(A1s2[4.0,5.0:A4.5],...)",
+        )
+        self.assertEqual(
+            string_type((t,) * 100, with_shape=True, with_min_max=True),
+            "#100(T1s2[4.0,5.0:A4.5],...)",
+        )
+
+    def test_pretty_onnx_att(self):
+        node = oh.make_node("Cast", ["xm2c"], ["xm2"], to=1)
+        pretty_onnx(node.attribute[0])
+
+    def test_rename_dimension(self):
+        res = rename_dynamic_dimensions(
+            {"a": {"B", "C"}},
+            {
+                "B",
+            },
+        )
+        self.assertEqual(res, {"B": "B", "a": "B", "C": "B"})
+
+    def test_rename_dynamic_expression(self):
+        text = rename_dynamic_expression("a * 10 - a", {"a": "x"})
+        self.assertEqual(text, "x * 10 - x")
+

 if __name__ == "__main__":
     unittest.main(verbosity=2)
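The new assertions document string_type's compact encoding: #N appears to prefix the element count of a list or tuple, A and T appear to mark numpy arrays and torch tensors, r gives the rank, s the shape when with_shape=True, and [min,max:Aavg] is appended when with_min_max=True. A minimal sketch of calling it directly, assuming string_type is importable from onnx_diagnostic.helpers (the module this commit also edits); expected outputs are copied from the tests above:

    # Hedged sketch based only on the assertions in the diff above.
    import numpy as np
    import torch
    from onnx_diagnostic.helpers import string_type

    a = np.array([4, 5], dtype=np.float32)
    t = torch.tensor([4, 5], dtype=torch.float32)
    print(string_type([a] * 100, with_shape=True))                # "#100[A1s2,...]"
    print(string_type((t,), with_shape=True, with_min_max=True))  # "(T1s2[4.0,5.0:A4.5],)"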

_unittests/ut_xrun_doc/test_onnx_tools.py

Lines changed: 23 additions & 1 deletion
@@ -5,7 +5,12 @@
 from onnx import TensorProto
 from onnx.checker import check_model
 from onnx_diagnostic.ext_test_case import ExtTestCase
-from onnx_diagnostic.onnx_tools import onnx_lighten, onnx_unlighten, onnx_find
+from onnx_diagnostic.onnx_tools import (
+    onnx_lighten,
+    onnx_unlighten,
+    onnx_find,
+    _validate_function,
+)
 from onnx_diagnostic.torch_test_helper import check_model_ort

 TFLOAT = TensorProto.FLOAT

@@ -67,6 +72,23 @@ def test_onnx_find(self):
         self.assertIn("xm2", res[0].output)
         self.assertIn("xm2", res[1].input)

+    def test__validate_function(self):
+        new_domain = "custom"
+
+        linear_regression = oh.make_function(
+            new_domain,
+            "LinearRegression",
+            ["x", "a", "b"],
+            ["y"],
+            [
+                oh.make_node("MatMul", ["x", "a"], ["xa"]),
+                oh.make_node("Add", ["xa", "b"], ["y"]),
+            ],
+            [oh.make_opsetid("", 14)],
+            [],
+        )
+        _validate_function(linear_regression)
+

 if __name__ == "__main__":
     unittest.main(verbosity=2)

_unittests/ut_xrun_doc/test_torch_test_helper.py

Lines changed: 18 additions & 2 deletions
@@ -1,21 +1,32 @@
 import unittest
 import numpy as np
+import ml_dtypes
 import onnx
 import onnx.helper as oh
 import onnx.numpy_helper as onh
+import torch
 from onnx_diagnostic.ext_test_case import ExtTestCase
-from onnx_diagnostic.torch_test_helper import dummy_llm, check_model_ort
+from onnx_diagnostic.torch_test_helper import dummy_llm, check_model_ort, to_numpy

 TFLOAT = onnx.TensorProto.FLOAT


 class TestOrtSession(ExtTestCase):

     def test_dummy_llm(self):
-        for cls_name in ["AttentionBlock", "MultiAttentionBlock", "DecoderLayer"]:
+        for cls_name in ["AttentionBlock", "MultiAttentionBlock", "DecoderLayer", "LLM"]:
             model, inputs = dummy_llm(cls_name)
             model(*inputs)

+    def test_dummy_llm_ds(self):
+        for cls_name in ["AttentionBlock", "MultiAttentionBlock", "DecoderLayer", "LLM"]:
+            model, inputs, ds = dummy_llm(cls_name, dynamic_shapes=True)
+            model(*inputs)
+            self.assertIsInstance(ds, dict)
+
+    def test_dummy_llm_exc(self):
+        self.assertRaise(lambda: dummy_llm("LLLLLL"), NotImplementedError)
+
     def test_check_model_ort(self):
         model = oh.make_model(
             oh.make_graph(

@@ -47,6 +58,11 @@ def test_check_model_ort(self):
         )
         check_model_ort(model)

+    def test_to_numpy(self):
+        t = torch.tensor([0, 1], dtype=torch.bfloat16)
+        a = to_numpy(t)
+        self.assertEqual(a.dtype, ml_dtypes.bfloat16)
+

 if __name__ == "__main__":
     unittest.main(verbosity=2)
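The new tests show two things: dummy_llm accepts dynamic_shapes=True and then returns a third element with the dynamic-shape specification (a dict), and to_numpy converts a bfloat16 torch tensor into a numpy array carrying the ml_dtypes.bfloat16 dtype. A minimal sketch restricted to what those assertions establish; anything beyond that (such as the exact content of ds) is an assumption:

    # Hedged sketch; import paths follow the test imports in the diff above.
    import ml_dtypes
    import torch
    from onnx_diagnostic.torch_test_helper import dummy_llm, to_numpy

    model, inputs, ds = dummy_llm("DecoderLayer", dynamic_shapes=True)
    print(type(ds))  # dict, per the assertion above

    t = torch.tensor([0, 1], dtype=torch.bfloat16)
    assert to_numpy(t).dtype == ml_dtypes.bfloat16  # bfloat16 survives the conversion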

_unittests/ut_xrun_doc/test_unit_test.py

Lines changed: 32 additions & 0 deletions
@@ -1,3 +1,4 @@
+import math
 import os
 import unittest
 import pandas

@@ -6,6 +7,12 @@
     ExtTestCase,
     statistics_on_file,
     statistics_on_folder,
+    is_apple,
+    is_windows,
+    is_azure,
+    is_linux,
+    unit_test_going,
+    measure_time,
 )


@@ -52,6 +59,31 @@ def test_statistics_on_folders(self):
         self.assertEqual(len(gr.columns), 4)
         self.assertEqual(total.shape, (2,))

+    def test_is(self):
+        is_apple()
+        is_windows()
+        is_azure()
+        is_linux()
+        unit_test_going()
+
+    def test_measure_time(self):
+        res = measure_time(lambda: math.cos(0.5))
+        self.assertIsInstance(res, dict)
+        self.assertEqual(
+            set(res),
+            {
+                "min_exec",
+                "max_exec",
+                "average",
+                "warmup_time",
+                "context_size",
+                "deviation",
+                "repeat",
+                "ttime",
+                "number",
+            },
+        )
+

 if __name__ == "__main__":
     unittest.main(verbosity=2)
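measure_time, now covered above, returns a dictionary of timing statistics; the keys asserted by the test are min_exec, max_exec, average, deviation, warmup_time, context_size, repeat, ttime and number. A minimal usage sketch, assuming only what the test establishes (the meaning of each key is inferred from its name, not from this commit):

    # Hedged sketch; measure_time is imported from onnx_diagnostic.ext_test_case
    # exactly as in the test above.
    import math
    from onnx_diagnostic.ext_test_case import measure_time

    stats = measure_time(lambda: math.cos(0.5))
    print(stats["average"], stats["deviation"])  # assumed: mean and spread of the timed runs
    print(sorted(stats))                         # the nine keys listed above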

onnx_diagnostic/ext_test_case.py

Lines changed: 1 addition & 23 deletions
@@ -14,7 +14,7 @@
 from contextlib import redirect_stderr, redirect_stdout
 from io import StringIO
 from timeit import Timer
-from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 import numpy
 from numpy.testing import assert_allclose

@@ -38,28 +38,6 @@ def is_linux() -> bool:
     return sys.platform == "linux"


-def skipif_transformers(version_to_skip: Union[str, Set[str]], msg: str) -> Callable:
-    """Skips a unit test if transformers has a specific version."""
-    if isinstance(version_to_skip, str):
-        version_to_skip = {version_to_skip}
-    import transformers
-
-    if transformers.__version__ in version_to_skip:
-        msg = f"Unstable test. {msg}"
-        return unittest.skip(msg)
-    return lambda x: x
-
-
-def skipif_not_onnxrt(msg) -> Callable:
-    """Skips a unit test if it runs on :epkg:`azure pipeline` on :epkg:`Windows`."""
-    UNITTEST_ONNXRT = os.environ.get("UNITTEST_ONNXRT", "0")
-    value = int(UNITTEST_ONNXRT)
-    if not value:
-        msg = f"Set UNITTEST_ONNXRT=1 to run the unittest. {msg}"
-        return unittest.skip(msg)
-    return lambda x: x
-
-
 def skipif_ci_windows(msg) -> Callable:
     """Skips a unit test if it runs on :epkg:`azure pipeline` on :epkg:`Windows`."""
     if is_windows() and is_azure():

onnx_diagnostic/helpers.py

Lines changed: 10 additions & 8 deletions
@@ -66,13 +66,13 @@ def size_type(dtype: Any) -> int:
         return 4
     if dtype == np.float16 or dtype == np.int16:
         return 2
-    if dtype == np.int16 or dtype == np.uint16:
+    if dtype == np.int16:
         return 2
-    if dtype == np.int32 or dtype == np.uint32:
+    if dtype == np.int32:
         return 4
-    if dtype == np.int64 or dtype == np.uint64:
+    if dtype == np.int64:
         return 8
-    if dtype == np.int8 or dtype == np.uint8:
+    if dtype == np.int8:
         return 1
     if hasattr(np, "uint64"):
         # it fails on mac

@@ -82,6 +82,8 @@ def size_type(dtype: Any) -> int:
         return 4
     if dtype == np.uint16:
         return 2
+    if dtype == np.uint8:
+        return 1

     import torch

@@ -225,7 +227,7 @@ def string_type(
         if with_min_max and all(isinstance(_, (int, float, bool)) for _ in obj):
             mini, maxi, avg = min(obj), max(obj), sum(float(_) for _ in obj) / len(obj)
             return f"({tt},...)#{len(obj)}[{mini},{maxi}:A[{avg}]]"
-        return f"({tt},...)#{len(obj)}" if with_shape else f"({tt},...)"
+        return f"#{len(obj)}({tt},...)"
     if isinstance(obj, list):
         if len(obj) < limit:
             js = ",".join(

@@ -250,8 +252,8 @@ def string_type(
            )
         if with_min_max and all(isinstance(_, (int, float, bool)) for _ in obj):
             mini, maxi, avg = min(obj), max(obj), sum(float(_) for _ in obj) / len(obj)
-            return f"[{tt},...]#{len(obj)}[{mini},{maxi}:{avg}]"
-        return f"[{tt},...]#{len(obj)}" if with_shape else f"[{tt},...]"
+            return f"#{len(obj)}[{tt},...][{mini},{maxi}:{avg}]"
+        return f"#{len(obj)}[{tt},...]"
     if isinstance(obj, set):
         if len(obj) < 10:
             js = ",".join(

@@ -932,7 +934,7 @@ def rename_dynamic_dimensions(
     many names for dynamic dimensions. When building the onnx model,
     some of them are redundant and can be replaced by the name provided by the user.

-    :param constraints: exhaustive list of used name and all the values equal to it
+    :param constraints: exhaustive list of used names and all the values equal to it
     :param original: the names to use if possible
     :param ban_prefix: avoid any rewriting by a constant starting with this prefix
     :return: replacement dictionary
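The docstring fix above concerns rename_dynamic_dimensions: given a constraint map and a set of preferred names, it returns a replacement dictionary that collapses redundant dynamic-dimension names onto the preferred ones. A minimal sketch, reusing the exact inputs and expected outputs from test_rename_dimension and test_rename_dynamic_expression earlier in this commit, and assuming both helpers live in onnx_diagnostic.helpers (the file edited here):

    # Hedged sketch; values are copied verbatim from the new tests above.
    from onnx_diagnostic.helpers import rename_dynamic_dimensions, rename_dynamic_expression

    replacements = rename_dynamic_dimensions({"a": {"B", "C"}}, {"B"})
    print(replacements)  # {"B": "B", "a": "B", "C": "B"}

    print(rename_dynamic_expression("a * 10 - a", {"a": "x"}))  # "x * 10 - x"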
