Skip to content

Commit db67007

Browse files
committed
fix
1 parent 33beb03 commit db67007

File tree

2 files changed

+50
-3
lines changed

2 files changed

+50
-3
lines changed

_unittests/ut_xrun_doc/test_helpers.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import onnx
66
import onnx.helper as oh
77
import torch
8-
from onnx_diagnostic.ext_test_case import ExtTestCase, skipif_ci_windows
8+
from onnx_diagnostic.ext_test_case import ExtTestCase, skipif_ci_windows, hide_stdout
99
from onnx_diagnostic.helpers import (
1010
string_type,
1111
string_sig,
@@ -127,7 +127,22 @@ def test_flatten(self):
127127
diff = max_diff(inputs, flat, flatten=True)
128128
self.assertEqual(diff["abs"], 0)
129129
d = string_diff(diff)
130-
print(d)
130+
self.assertIsInstance(d, str)
131+
132+
@hide_stdout()
133+
def test_max_diff_verbose(self):
134+
inputs = (
135+
torch.rand((3, 4), dtype=torch.float16),
136+
[
137+
torch.rand((5, 6), dtype=torch.float16),
138+
torch.rand((5, 6, 7), dtype=torch.float16),
139+
],
140+
)
141+
flat = flatten_object(inputs)
142+
diff = max_diff(inputs, flat, flatten=True, verbose=10)
143+
self.assertEqual(diff["abs"], 0)
144+
d = string_diff(diff)
145+
self.assertIsInstance(d, str)
131146

132147
def test_type_info(self):
133148
for tt in [
@@ -250,6 +265,38 @@ def test_string_type_one(self):
250265
self.assertEqual(string_type([4] * 100), "#100[int,...]")
251266
self.assertEqual(string_type((4,) * 100), "#100(int,...)")
252267

268+
def test_string_type_one_with_min_max_int(self):
269+
self.assertEqual(string_type(None, with_min_max=True), "None")
270+
self.assertEqual(string_type([4], with_min_max=True), "#1[int=4]")
271+
self.assertEqual(string_type((4, 5), with_min_max=True), "(int=4,int=5)")
272+
self.assertEqual(string_type([4] * 100, with_min_max=True), "#100[int=4,...][4,4:4.0]")
273+
self.assertEqual(
274+
string_type((4,) * 100, with_min_max=True), "#100(int=4,...)[4,4:A[4.0]]"
275+
)
276+
277+
def test_string_type_one_with_min_max_bool(self):
278+
self.assertEqual(string_type(None, with_min_max=True), "None")
279+
self.assertEqual(string_type([True], with_min_max=True), "#1[bool=True]")
280+
self.assertEqual(string_type((True, True), with_min_max=True), "(bool=True,bool=True)")
281+
self.assertEqual(
282+
string_type([True] * 100, with_min_max=True), "#100[bool=True,...][True,True:1.0]"
283+
)
284+
self.assertEqual(
285+
string_type((True,) * 100, with_min_max=True),
286+
"#100(bool=True,...)[True,True:A[1.0]]",
287+
)
288+
289+
def test_string_type_one_with_min_max_float(self):
290+
self.assertEqual(string_type(None, with_min_max=True), "None")
291+
self.assertEqual(string_type([4.5], with_min_max=True), "#1[float=4.5]")
292+
self.assertEqual(string_type((4.5, 5.5), with_min_max=True), "(float=4.5,float=5.5)")
293+
self.assertEqual(
294+
string_type([4.5] * 100, with_min_max=True), "#100[float=4.5,...][4.5,4.5:4.5]"
295+
)
296+
self.assertEqual(
297+
string_type((4.5,) * 100, with_min_max=True), "#100(float=4.5,...)[4.5,4.5:A[4.5]]"
298+
)
299+
253300
def test_string_type_at(self):
254301
self.assertEqual(string_type(None), "None")
255302
a = np.array([4, 5], dtype=np.float32)

onnx_diagnostic/helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def string_type(
222222
)
223223
if with_min_max and all(isinstance(_, (int, float, bool)) for _ in obj):
224224
mini, maxi, avg = min(obj), max(obj), sum(float(_) for _ in obj) / len(obj)
225-
return f"({tt},...)#{len(obj)}[{mini},{maxi}:A[{avg}]]"
225+
return f"#{len(obj)}({tt},...)[{mini},{maxi}:A[{avg}]]"
226226
return f"#{len(obj)}({tt},...)"
227227
if isinstance(obj, list):
228228
if len(obj) < limit:

0 commit comments

Comments
 (0)