Skip to content

Commit 57c920b

Browse files
authored
More unit tests (#167)
* More unit tests * mypy * mypy
1 parent af71ef2 commit 57c920b

File tree

5 files changed

+65
-16
lines changed

5 files changed

+65
-16
lines changed

_unittests/ut_xrun_doc/test_command_lines_exe.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
import unittest
33
from contextlib import redirect_stdout
44
from io import StringIO
5-
from onnx_diagnostic.ext_test_case import ExtTestCase
5+
from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings
66
from onnx_diagnostic._command_lines_parser import main
7+
from onnx_diagnostic.helpers.log_helper import enumerate_csv_files
78

89

910
class TestCommandLines(ExtTestCase):
@@ -66,6 +67,20 @@ def test_parser_validate(self):
6667
text = st.getvalue()
6768
self.assertIn("model_clas", text)
6869

70+
@ignore_warnings(UserWarning)
71+
def test_parser_agg(self):
72+
path = os.path.abspath(
73+
os.path.join(os.path.dirname(__file__), "..", "ut_helpers", "data")
74+
)
75+
assert list(enumerate_csv_files([f"{path}/*.zip"]))
76+
output = self.get_dump_file("test_parser_agg.xlsx")
77+
st = StringIO()
78+
with redirect_stdout(st):
79+
main(["agg", output, f"{path}/*.zip", "--filter", ".*.csv", "-v", "1"])
80+
text = st.getvalue()
81+
self.assertIn("[CubeLogs.to_excel] plots 1 plots", text)
82+
self.assertExists(output)
83+
6984

7085
if __name__ == "__main__":
7186
unittest.main(verbosity=2)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import unittest
2+
from onnx_diagnostic.ext_test_case import ExtTestCase
3+
from onnx_diagnostic.doc import reset_torch_transformers
4+
5+
6+
class TestDocDoc(ExtTestCase):
7+
8+
def test_reset(self):
9+
reset_torch_transformers(None, None)
10+
11+
12+
if __name__ == "__main__":
13+
unittest.main(verbosity=2)

_unittests/ut_xrun_doc/test_unit_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
has_cuda,
1717
has_onnxscript,
1818
)
19+
from onnx_diagnostic.api import TensorLike
1920

2021

2122
class TestUnitTest(ExtTestCase):
@@ -110,6 +111,10 @@ def test_measure_time_max(self):
110111
},
111112
)
112113

114+
def test_exc(self):
115+
self.assertRaise(lambda: TensorLike().dtype, NotImplementedError)
116+
self.assertRaise(lambda: TensorLike().shape, NotImplementedError)
117+
113118

114119
if __name__ == "__main__":
115120
unittest.main(verbosity=2)

onnx_diagnostic/_command_lines_parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -765,7 +765,7 @@ def _cmd_agg(argv: List[Any]):
765765
args.inputs, verbose=args.verbose, filtering=lambda name: bool(reg.search(name))
766766
)
767767
)
768-
assert csv, f"No csv files in {args.inputs}, csv={csv}"
768+
assert csv, f"No csv files in {args.inputs}, args.filter={args.filter!r}, csv={csv}"
769769
if args.verbose:
770770
from tqdm import tqdm
771771

onnx_diagnostic/helpers/log_helper.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1253,7 +1253,11 @@ def to_excel(
12531253
df.to_excel(writer, sheet_name=main, freeze_panes=(1, 1))
12541254

12551255
for name, view in views.items():
1256+
if view is None:
1257+
continue
12561258
df, tview = self.view(view, return_view_def=True, verbose=max(verbose - 1, 0))
1259+
if tview is None:
1260+
continue
12571261
memory = df.memory_usage(deep=True).sum()
12581262
if verbose:
12591263
print(
@@ -1654,10 +1658,12 @@ def unbiased_export(df):
16541658

16551659
def view(
16561660
self,
1657-
view_def: Union[str, CubeViewDef],
1661+
view_def: Optional[Union[str, CubeViewDef]],
16581662
return_view_def: bool = False,
16591663
verbose: int = 0,
1660-
) -> Union[pandas.DataFrame, Tuple[pandas.DataFrame, CubeViewDef]]:
1664+
) -> Union[
1665+
Optional[pandas.DataFrame], Tuple[Optional[pandas.DataFrame], Optional[CubeViewDef]]
1666+
]:
16611667
"""
16621668
Returns a dataframe, a pivot view.
16631669
@@ -1666,18 +1672,22 @@ def view(
16661672
:param view_def: view definition or a string
16671673
:param return_view_def: returns the view definition as well
16681674
:param verbose: verbosity level
1669-
:return: dataframe
1675+
:return: dataframe or a couple (dataframe, view definition),
1676+
both of them can be one if view_def cannot be interpreted
16701677
"""
1678+
assert view_def is not None, "view_def is None, this is not allowed."
16711679
if isinstance(view_def, str):
16721680
view_def = self.make_view_def(view_def)
1681+
if view_def is None:
1682+
return (None, None) if return_view_def else None
16731683
return super().view(view_def, return_view_def=return_view_def, verbose=verbose)
16741684

1675-
def make_view_def(self, name: str) -> CubeViewDef:
1685+
def make_view_def(self, name: str) -> Optional[CubeViewDef]:
16761686
"""
16771687
Returns a view definition.
16781688
16791689
:param name: name of the view
1680-
:return: a CubeViewDef
1690+
:return: a CubeViewDef or None if name does not make sense
16811691
16821692
Available views:
16831693
@@ -1892,14 +1902,6 @@ def mean_geo(gr):
18921902
f_highlight=f_bucket,
18931903
order=order,
18941904
),
1895-
"cmd": lambda: CubeViewDef(
1896-
key_index=index_cols,
1897-
values=self._filter_column(["CMD"], self.values),
1898-
ignore_unique=True,
1899-
keep_columns_in_index=["suite"],
1900-
name="cmd",
1901-
order=order,
1902-
),
19031905
"onnx": lambda: CubeViewDef(
19041906
key_index=index_cols,
19051907
values=self._filter_column(
@@ -1927,11 +1929,25 @@ def mean_geo(gr):
19271929
no_index=True,
19281930
),
19291931
}
1930-
assert name in implemented_views, (
1932+
1933+
cmd_col = self._filter_column(["CMD"], self.values, can_be_empty=True)
1934+
if cmd_col:
1935+
implemented_views["cmd"] = lambda: CubeViewDef(
1936+
key_index=index_cols,
1937+
values=cmd_col,
1938+
ignore_unique=True,
1939+
keep_columns_in_index=["suite"],
1940+
name="cmd",
1941+
order=order,
1942+
)
1943+
1944+
assert name in implemented_views or name in {"cmd"}, (
19311945
f"Unknown view {name!r}, expected a name in {sorted(implemented_views)},"
19321946
f"\n--\nkeys={pprint.pformat(sorted(self.keys_time))}, "
19331947
f"\n--\nvalues={pprint.pformat(sorted(self.values))}"
19341948
)
1949+
if name not in implemented_views:
1950+
return None
19351951
return implemented_views[name]()
19361952

19371953
def post_load_process_piece(

0 commit comments

Comments
 (0)