|
5 | 5 | import onnx |
6 | 6 | import onnx.helper as oh |
7 | 7 | import torch |
8 | | -from onnx_diagnostic.ext_test_case import ExtTestCase, skipif_ci_windows |
| 8 | +from onnx_diagnostic.ext_test_case import ExtTestCase, skipif_ci_windows, hide_stdout |
9 | 9 | from onnx_diagnostic.helpers import ( |
10 | 10 | string_type, |
11 | 11 | string_sig, |
@@ -127,7 +127,22 @@ def test_flatten(self): |
127 | 127 | diff = max_diff(inputs, flat, flatten=True) |
128 | 128 | self.assertEqual(diff["abs"], 0) |
129 | 129 | d = string_diff(diff) |
130 | | - print(d) |
| 130 | + self.assertIsInstance(d, str) |
| 131 | + |
| 132 | + @hide_stdout() |
| 133 | + def test_max_diff_verbose(self): |
| 134 | + inputs = ( |
| 135 | + torch.rand((3, 4), dtype=torch.float16), |
| 136 | + [ |
| 137 | + torch.rand((5, 6), dtype=torch.float16), |
| 138 | + torch.rand((5, 6, 7), dtype=torch.float16), |
| 139 | + ], |
| 140 | + ) |
| 141 | + flat = flatten_object(inputs) |
| 142 | + diff = max_diff(inputs, flat, flatten=True, verbose=10) |
| 143 | + self.assertEqual(diff["abs"], 0) |
| 144 | + d = string_diff(diff) |
| 145 | + self.assertIsInstance(d, str) |
131 | 146 |
|
132 | 147 | def test_type_info(self): |
133 | 148 | for tt in [ |
@@ -250,6 +265,38 @@ def test_string_type_one(self): |
250 | 265 | self.assertEqual(string_type([4] * 100), "#100[int,...]") |
251 | 266 | self.assertEqual(string_type((4,) * 100), "#100(int,...)") |
252 | 267 |
|
| 268 | + def test_string_type_one_with_min_max_int(self): |
| 269 | + self.assertEqual(string_type(None, with_min_max=True), "None") |
| 270 | + self.assertEqual(string_type([4], with_min_max=True), "#1[int=4]") |
| 271 | + self.assertEqual(string_type((4, 5), with_min_max=True), "(int=4,int=5)") |
| 272 | + self.assertEqual(string_type([4] * 100, with_min_max=True), "#100[int=4,...][4,4:4.0]") |
| 273 | + self.assertEqual( |
| 274 | + string_type((4,) * 100, with_min_max=True), "#100(int=4,...)[4,4:A[4.0]]" |
| 275 | + ) |
| 276 | + |
| 277 | + def test_string_type_one_with_min_max_bool(self): |
| 278 | + self.assertEqual(string_type(None, with_min_max=True), "None") |
| 279 | + self.assertEqual(string_type([True], with_min_max=True), "#1[bool=True]") |
| 280 | + self.assertEqual(string_type((True, True), with_min_max=True), "(bool=True,bool=True)") |
| 281 | + self.assertEqual( |
| 282 | + string_type([True] * 100, with_min_max=True), "#100[bool=True,...][True,True:1.0]" |
| 283 | + ) |
| 284 | + self.assertEqual( |
| 285 | + string_type((True,) * 100, with_min_max=True), |
| 286 | + "#100(bool=True,...)[True,True:A[1.0]]", |
| 287 | + ) |
| 288 | + |
| 289 | + def test_string_type_one_with_min_max_float(self): |
| 290 | + self.assertEqual(string_type(None, with_min_max=True), "None") |
| 291 | + self.assertEqual(string_type([4.5], with_min_max=True), "#1[float=4.5]") |
| 292 | + self.assertEqual(string_type((4.5, 5.5), with_min_max=True), "(float=4.5,float=5.5)") |
| 293 | + self.assertEqual( |
| 294 | + string_type([4.5] * 100, with_min_max=True), "#100[float=4.5,...][4.5,4.5:4.5]" |
| 295 | + ) |
| 296 | + self.assertEqual( |
| 297 | + string_type((4.5,) * 100, with_min_max=True), "#100(float=4.5,...)[4.5,4.5:A[4.5]]" |
| 298 | + ) |
| 299 | + |
253 | 300 | def test_string_type_at(self): |
254 | 301 | self.assertEqual(string_type(None), "None") |
255 | 302 | a = np.array([4, 5], dtype=np.float32) |
|
0 commit comments