Skip to content

Commit b633f51

Browse files
committed
More unit tests
1 parent 28fe237 commit b633f51

File tree

5 files changed

+57
-11
lines changed

5 files changed

+57
-11
lines changed

_unittests/ut_xrun_doc/test_command_lines_exe.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
import unittest
33
from contextlib import redirect_stdout
44
from io import StringIO
5-
from onnx_diagnostic.ext_test_case import ExtTestCase
5+
from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings
66
from onnx_diagnostic._command_lines_parser import main
7+
from onnx_diagnostic.helpers.log_helper import enumerate_csv_files
78

89

910
class TestCommandLines(ExtTestCase):
@@ -66,6 +67,20 @@ def test_parser_validate(self):
6667
text = st.getvalue()
6768
self.assertIn("model_clas", text)
6869

70+
@ignore_warnings(UserWarning)
71+
def test_parser_agg(self):
72+
path = os.path.abspath(
73+
os.path.join(os.path.dirname(__file__), "..", "ut_helpers", "data")
74+
)
75+
assert list(enumerate_csv_files([f"{path}/*.zip"]))
76+
output = self.get_dump_file("test_parser_agg.xlsx")
77+
st = StringIO()
78+
with redirect_stdout(st):
79+
main(["agg", output, f"{path}/*.zip", "--filter", ".*.csv", "-v", "1"])
80+
text = st.getvalue()
81+
self.assertIn("[CubeLogs.to_excel] plots 1 plots", text)
82+
self.assertExists(output)
83+
6984

7085
if __name__ == "__main__":
7186
unittest.main(verbosity=2)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import unittest
2+
from onnx_diagnostic.ext_test_case import ExtTestCase
3+
from onnx_diagnostic.doc import reset_torch_transformers
4+
5+
6+
class TestDocDoc(ExtTestCase):
7+
8+
def test_reset(self):
9+
reset_torch_transformers(None, None)
10+
11+
12+
if __name__ == "__main__":
13+
unittest.main(verbosity=2)

_unittests/ut_xrun_doc/test_unit_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
has_cuda,
1717
has_onnxscript,
1818
)
19+
from onnx_diagnostic.api import TensorLike
1920

2021

2122
class TestUnitTest(ExtTestCase):
@@ -110,6 +111,10 @@ def test_measure_time_max(self):
110111
},
111112
)
112113

114+
def test_exc(self):
115+
self.assertRaise(lambda: TensorLike().dtype, NotImplementedError)
116+
self.assertRaise(lambda: TensorLike().shape, NotImplementedError)
117+
113118

114119
if __name__ == "__main__":
115120
unittest.main(verbosity=2)

onnx_diagnostic/_command_lines_parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -765,7 +765,7 @@ def _cmd_agg(argv: List[Any]):
765765
args.inputs, verbose=args.verbose, filtering=lambda name: bool(reg.search(name))
766766
)
767767
)
768-
assert csv, f"No csv files in {args.inputs}, csv={csv}"
768+
assert csv, f"No csv files in {args.inputs}, args.filter={args.filter!r}, csv={csv}"
769769
if args.verbose:
770770
from tqdm import tqdm
771771

onnx_diagnostic/helpers/log_helper.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1253,7 +1253,11 @@ def to_excel(
12531253
df.to_excel(writer, sheet_name=main, freeze_panes=(1, 1))
12541254

12551255
for name, view in views.items():
1256+
if view is None:
1257+
continue
12561258
df, tview = self.view(view, return_view_def=True, verbose=max(verbose - 1, 0))
1259+
if tview is None:
1260+
continue
12571261
memory = df.memory_usage(deep=True).sum()
12581262
if verbose:
12591263
print(
@@ -1668,8 +1672,11 @@ def view(
16681672
:param verbose: verbosity level
16691673
:return: dataframe
16701674
"""
1675+
assert view_def is not None, "view_def is None, this is not allowed."
16711676
if isinstance(view_def, str):
16721677
view_def = self.make_view_def(view_def)
1678+
if view_def is None:
1679+
return (None, None) if return_view_def else None
16731680
return super().view(view_def, return_view_def=return_view_def, verbose=verbose)
16741681

16751682
def make_view_def(self, name: str) -> CubeViewDef:
@@ -1892,14 +1899,6 @@ def mean_geo(gr):
18921899
f_highlight=f_bucket,
18931900
order=order,
18941901
),
1895-
"cmd": lambda: CubeViewDef(
1896-
key_index=index_cols,
1897-
values=self._filter_column(["CMD"], self.values),
1898-
ignore_unique=True,
1899-
keep_columns_in_index=["suite"],
1900-
name="cmd",
1901-
order=order,
1902-
),
19031902
"onnx": lambda: CubeViewDef(
19041903
key_index=index_cols,
19051904
values=self._filter_column(
@@ -1927,11 +1926,25 @@ def mean_geo(gr):
19271926
no_index=True,
19281927
),
19291928
}
1930-
assert name in implemented_views, (
1929+
1930+
cmd_col = self._filter_column(["CMD"], self.values, can_be_empty=True)
1931+
if cmd_col:
1932+
implemented_views["cmd"] = lambda: CubeViewDef(
1933+
key_index=index_cols,
1934+
values=cmd_col,
1935+
ignore_unique=True,
1936+
keep_columns_in_index=["suite"],
1937+
name="cmd",
1938+
order=order,
1939+
)
1940+
1941+
assert name in implemented_views or name in {"cmd"}, (
19311942
f"Unknown view {name!r}, expected a name in {sorted(implemented_views)},"
19321943
f"\n--\nkeys={pprint.pformat(sorted(self.keys_time))}, "
19331944
f"\n--\nvalues={pprint.pformat(sorted(self.values))}"
19341945
)
1946+
if name not in implemented_views:
1947+
return None
19351948
return implemented_views[name]()
19361949

19371950
def post_load_process_piece(

0 commit comments

Comments
 (0)