Skip to content

Commit 33beb03

Browse files
committed
extend coverage
1 parent fd6092f commit 33beb03

File tree

5 files changed

+50
-8
lines changed

5 files changed

+50
-8
lines changed

_unittests/ut_torch_export_patches/test_onnx_export_errors.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
requires_transformers,
66
skipif_ci_windows,
77
ignore_warnings,
8+
hide_stdout,
89
)
910
from onnx_diagnostic.helpers import string_type
1011
from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
@@ -16,6 +17,7 @@ class TestOnnxExportErrors(ExtTestCase):
1617
@requires_transformers("4.49.999")
1718
@skipif_ci_windows("not working on Windows")
1819
@ignore_warnings(UserWarning)
20+
@hide_stdout()
1921
def test_pytree_flatten_mamba_cache(self):
2022
import torch
2123
import torch.utils._pytree as py_pytree
@@ -31,7 +33,7 @@ def __init__(self):
3133

3234
cache = MambaCache(_config(), max_batch_size=1, device="cpu")
3335

34-
with bypass_export_some_errors():
36+
with bypass_export_some_errors(verbose=1):
3537
values, spec = py_pytree.tree_flatten(cache)
3638
cache2 = py_pytree.tree_unflatten(values, spec)
3739
self.assertEqual(cache.dtype, cache2.dtype)
@@ -46,6 +48,7 @@ def __init__(self):
4648
@requires_torch("2.7")
4749
@skipif_ci_windows("not working on Windows")
4850
@ignore_warnings(UserWarning)
51+
@hide_stdout()
4952
def test_exportable_mamba_cache(self):
5053
import torch
5154
from transformers.models.mamba.modeling_mamba import MambaCache
@@ -73,7 +76,7 @@ def forward(self, x: torch.Tensor, cache: MambaCache):
7376
model = Model()
7477
model(x, cache)
7578

76-
with bypass_export_some_errors():
79+
with bypass_export_some_errors(replace_dynamic_cache=True, verbose=1):
7780
cache = MambaCache(_config(), max_batch_size=1, device="cpu")
7881
torch.export.export(Model(), (x, cache))
7982

_unittests/ut_xrun_doc/test_args.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,30 @@ def test_args(self):
2525
self.assertEqual(args.with_mask, 0)
2626
self.assertEqual(args.optim, "")
2727

28+
def test_args_expose(self):
29+
try:
30+
args = get_parsed_args(
31+
"plot_custom_backend_llama",
32+
config=(
33+
"medium",
34+
"large or medium depending, large means closer to the real model",
35+
),
36+
num_hidden_layers=(1, "number of hidden layers"),
37+
with_mask=(0, "tries with a mask as a secondary input"),
38+
optim=("", "Optimization to apply, empty string for all"),
39+
description="doc",
40+
new_args=["--config", "m"],
41+
expose="repeat,warmup",
42+
)
43+
except SystemExit as e:
44+
raise AssertionError(f"SystemExist caught: {e}")
45+
self.assertEqual(args.config, "m")
46+
self.assertEqual(args.num_hidden_layers, 1)
47+
self.assertEqual(args.with_mask, 0)
48+
self.assertEqual(args.optim, "")
49+
self.assertEqual(args.repeat, 10)
50+
self.assertEqual(args.warmup, 5)
51+
2852

2953
if __name__ == "__main__":
3054
unittest.main(verbosity=2)

_unittests/ut_xrun_doc/test_onnx_tools.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import onnx.numpy_helper as onh
55
from onnx import TensorProto
66
from onnx.checker import check_model
7-
from onnx_diagnostic.ext_test_case import ExtTestCase
7+
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
88
from onnx_diagnostic.onnx_tools import (
99
onnx_lighten,
1010
onnx_unlighten,
@@ -72,6 +72,7 @@ def test_onnx_find(self):
7272
self.assertIn("xm2", res[0].output)
7373
self.assertIn("xm2", res[1].input)
7474

75+
@hide_stdout()
7576
def test__validate_function(self):
7677
new_domain = "custom"
7778

@@ -87,7 +88,7 @@ def test__validate_function(self):
8788
[oh.make_opsetid("", 14)],
8889
[],
8990
)
90-
_validate_function(linear_regression)
91+
_validate_function(linear_regression, verbose=1)
9192

9293

9394
if __name__ == "__main__":

_unittests/ut_xrun_doc/test_unit_test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,24 @@ def test_measure_time(self):
8484
},
8585
)
8686

87+
def test_measure_time_max(self):
88+
res = measure_time(lambda: math.cos(0.5), max_time=0.1)
89+
self.assertIsInstance(res, dict)
90+
self.assertEqual(
91+
set(res),
92+
{
93+
"min_exec",
94+
"max_exec",
95+
"average",
96+
"warmup_time",
97+
"context_size",
98+
"deviation",
99+
"repeat",
100+
"ttime",
101+
"number",
102+
},
103+
)
104+
87105

88106
if __name__ == "__main__":
89107
unittest.main(verbosity=2)

onnx_diagnostic/helpers.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,8 @@ def size_type(dtype: Any) -> int:
6666
return 4
6767
if dtype == np.float16 or dtype == np.int16:
6868
return 2
69-
if dtype == np.int16:
70-
return 2
7169
if dtype == np.int32:
7270
return 4
73-
if dtype == np.int64:
74-
return 8
7571
if dtype == np.int8:
7672
return 1
7773
if hasattr(np, "uint64"):

0 commit comments

Comments
 (0)