Skip to content

Commit cd111f5

Browse files
authored
extend coverage (#6)
* extend coverage * fix
1 parent fd6092f commit cd111f5

File tree

6 files changed

+100
-11
lines changed

6 files changed

+100
-11
lines changed

_unittests/ut_torch_export_patches/test_onnx_export_errors.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
requires_transformers,
66
skipif_ci_windows,
77
ignore_warnings,
8+
hide_stdout,
89
)
910
from onnx_diagnostic.helpers import string_type
1011
from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
@@ -16,6 +17,7 @@ class TestOnnxExportErrors(ExtTestCase):
1617
@requires_transformers("4.49.999")
1718
@skipif_ci_windows("not working on Windows")
1819
@ignore_warnings(UserWarning)
20+
@hide_stdout()
1921
def test_pytree_flatten_mamba_cache(self):
2022
import torch
2123
import torch.utils._pytree as py_pytree
@@ -31,7 +33,7 @@ def __init__(self):
3133

3234
cache = MambaCache(_config(), max_batch_size=1, device="cpu")
3335

34-
with bypass_export_some_errors():
36+
with bypass_export_some_errors(verbose=1):
3537
values, spec = py_pytree.tree_flatten(cache)
3638
cache2 = py_pytree.tree_unflatten(values, spec)
3739
self.assertEqual(cache.dtype, cache2.dtype)
@@ -46,6 +48,7 @@ def __init__(self):
4648
@requires_torch("2.7")
4749
@skipif_ci_windows("not working on Windows")
4850
@ignore_warnings(UserWarning)
51+
@hide_stdout()
4952
def test_exportable_mamba_cache(self):
5053
import torch
5154
from transformers.models.mamba.modeling_mamba import MambaCache
@@ -73,7 +76,7 @@ def forward(self, x: torch.Tensor, cache: MambaCache):
7376
model = Model()
7477
model(x, cache)
7578

76-
with bypass_export_some_errors():
79+
with bypass_export_some_errors(replace_dynamic_cache=True, verbose=1):
7780
cache = MambaCache(_config(), max_batch_size=1, device="cpu")
7881
torch.export.export(Model(), (x, cache))
7982

_unittests/ut_xrun_doc/test_args.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,30 @@ def test_args(self):
2525
self.assertEqual(args.with_mask, 0)
2626
self.assertEqual(args.optim, "")
2727

28+
def test_args_expose(self):
29+
try:
30+
args = get_parsed_args(
31+
"plot_custom_backend_llama",
32+
config=(
33+
"medium",
34+
"large or medium depending, large means closer to the real model",
35+
),
36+
num_hidden_layers=(1, "number of hidden layers"),
37+
with_mask=(0, "tries with a mask as a secondary input"),
38+
optim=("", "Optimization to apply, empty string for all"),
39+
description="doc",
40+
new_args=["--config", "m"],
41+
expose="repeat,warmup",
42+
)
43+
except SystemExit as e:
44+
raise AssertionError(f"SystemExist caught: {e}")
45+
self.assertEqual(args.config, "m")
46+
self.assertEqual(args.num_hidden_layers, 1)
47+
self.assertEqual(args.with_mask, 0)
48+
self.assertEqual(args.optim, "")
49+
self.assertEqual(args.repeat, 10)
50+
self.assertEqual(args.warmup, 5)
51+
2852

2953
if __name__ == "__main__":
3054
unittest.main(verbosity=2)

_unittests/ut_xrun_doc/test_helpers.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import onnx
66
import onnx.helper as oh
77
import torch
8-
from onnx_diagnostic.ext_test_case import ExtTestCase, skipif_ci_windows
8+
from onnx_diagnostic.ext_test_case import ExtTestCase, skipif_ci_windows, hide_stdout
99
from onnx_diagnostic.helpers import (
1010
string_type,
1111
string_sig,
@@ -127,7 +127,22 @@ def test_flatten(self):
127127
diff = max_diff(inputs, flat, flatten=True)
128128
self.assertEqual(diff["abs"], 0)
129129
d = string_diff(diff)
130-
print(d)
130+
self.assertIsInstance(d, str)
131+
132+
@hide_stdout()
133+
def test_max_diff_verbose(self):
134+
inputs = (
135+
torch.rand((3, 4), dtype=torch.float16),
136+
[
137+
torch.rand((5, 6), dtype=torch.float16),
138+
torch.rand((5, 6, 7), dtype=torch.float16),
139+
],
140+
)
141+
flat = flatten_object(inputs)
142+
diff = max_diff(inputs, flat, flatten=True, verbose=10)
143+
self.assertEqual(diff["abs"], 0)
144+
d = string_diff(diff)
145+
self.assertIsInstance(d, str)
131146

132147
def test_type_info(self):
133148
for tt in [
@@ -250,6 +265,38 @@ def test_string_type_one(self):
250265
self.assertEqual(string_type([4] * 100), "#100[int,...]")
251266
self.assertEqual(string_type((4,) * 100), "#100(int,...)")
252267

268+
def test_string_type_one_with_min_max_int(self):
269+
self.assertEqual(string_type(None, with_min_max=True), "None")
270+
self.assertEqual(string_type([4], with_min_max=True), "#1[int=4]")
271+
self.assertEqual(string_type((4, 5), with_min_max=True), "(int=4,int=5)")
272+
self.assertEqual(string_type([4] * 100, with_min_max=True), "#100[int=4,...][4,4:4.0]")
273+
self.assertEqual(
274+
string_type((4,) * 100, with_min_max=True), "#100(int=4,...)[4,4:A[4.0]]"
275+
)
276+
277+
def test_string_type_one_with_min_max_bool(self):
278+
self.assertEqual(string_type(None, with_min_max=True), "None")
279+
self.assertEqual(string_type([True], with_min_max=True), "#1[bool=True]")
280+
self.assertEqual(string_type((True, True), with_min_max=True), "(bool=True,bool=True)")
281+
self.assertEqual(
282+
string_type([True] * 100, with_min_max=True), "#100[bool=True,...][True,True:1.0]"
283+
)
284+
self.assertEqual(
285+
string_type((True,) * 100, with_min_max=True),
286+
"#100(bool=True,...)[True,True:A[1.0]]",
287+
)
288+
289+
def test_string_type_one_with_min_max_float(self):
290+
self.assertEqual(string_type(None, with_min_max=True), "None")
291+
self.assertEqual(string_type([4.5], with_min_max=True), "#1[float=4.5]")
292+
self.assertEqual(string_type((4.5, 5.5), with_min_max=True), "(float=4.5,float=5.5)")
293+
self.assertEqual(
294+
string_type([4.5] * 100, with_min_max=True), "#100[float=4.5,...][4.5,4.5:4.5]"
295+
)
296+
self.assertEqual(
297+
string_type((4.5,) * 100, with_min_max=True), "#100(float=4.5,...)[4.5,4.5:A[4.5]]"
298+
)
299+
253300
def test_string_type_at(self):
254301
self.assertEqual(string_type(None), "None")
255302
a = np.array([4, 5], dtype=np.float32)

_unittests/ut_xrun_doc/test_onnx_tools.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import onnx.numpy_helper as onh
55
from onnx import TensorProto
66
from onnx.checker import check_model
7-
from onnx_diagnostic.ext_test_case import ExtTestCase
7+
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
88
from onnx_diagnostic.onnx_tools import (
99
onnx_lighten,
1010
onnx_unlighten,
@@ -72,6 +72,7 @@ def test_onnx_find(self):
7272
self.assertIn("xm2", res[0].output)
7373
self.assertIn("xm2", res[1].input)
7474

75+
@hide_stdout()
7576
def test__validate_function(self):
7677
new_domain = "custom"
7778

@@ -87,7 +88,7 @@ def test__validate_function(self):
8788
[oh.make_opsetid("", 14)],
8889
[],
8990
)
90-
_validate_function(linear_regression)
91+
_validate_function(linear_regression, verbose=1)
9192

9293

9394
if __name__ == "__main__":

_unittests/ut_xrun_doc/test_unit_test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,24 @@ def test_measure_time(self):
8484
},
8585
)
8686

87+
def test_measure_time_max(self):
88+
res = measure_time(lambda: math.cos(0.5), max_time=0.1)
89+
self.assertIsInstance(res, dict)
90+
self.assertEqual(
91+
set(res),
92+
{
93+
"min_exec",
94+
"max_exec",
95+
"average",
96+
"warmup_time",
97+
"context_size",
98+
"deviation",
99+
"repeat",
100+
"ttime",
101+
"number",
102+
},
103+
)
104+
87105

88106
if __name__ == "__main__":
89107
unittest.main(verbosity=2)

onnx_diagnostic/helpers.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,8 @@ def size_type(dtype: Any) -> int:
6666
return 4
6767
if dtype == np.float16 or dtype == np.int16:
6868
return 2
69-
if dtype == np.int16:
70-
return 2
7169
if dtype == np.int32:
7270
return 4
73-
if dtype == np.int64:
74-
return 8
7571
if dtype == np.int8:
7672
return 1
7773
if hasattr(np, "uint64"):
@@ -226,7 +222,7 @@ def string_type(
226222
)
227223
if with_min_max and all(isinstance(_, (int, float, bool)) for _ in obj):
228224
mini, maxi, avg = min(obj), max(obj), sum(float(_) for _ in obj) / len(obj)
229-
return f"({tt},...)#{len(obj)}[{mini},{maxi}:A[{avg}]]"
225+
return f"#{len(obj)}({tt},...)[{mini},{maxi}:A[{avg}]]"
230226
return f"#{len(obj)}({tt},...)"
231227
if isinstance(obj, list):
232228
if len(obj) < limit:

0 commit comments

Comments
 (0)