Skip to content

Commit 374538b

Browse files
committed
updtae string_diff
1 parent 69a42c4 commit 374538b

File tree

4 files changed

+37
-4
lines changed

4 files changed

+37
-4
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.5.0
55
+++++
66

7+
* :pr:`92`: support errors distribution in max_diff
78
* :pr:`91`: enable strings in ``guess_dynamic_shapes``
89
* :pr:`88`, :pr:`89`: extends ``steal_forward`` to dump input, outputs in onnx models
910
* :pr:`83`, :pr:`85`: improves the automated rewriting of control flow (test)

_unittests/ut_helpers/test_helper.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,23 @@ def test_max_diff_hist_array(self):
229229
},
230230
)
231231

232+
def test_max_diff_hist_array_string_diff(self):
233+
x = np.arange(12).reshape((3, 4)).astype(dtype=np.float32)
234+
y = x.copy()
235+
y[0, 1] += 0.1
236+
y[0, 2] += 0.01
237+
y[0, 3] += 0.001
238+
y[1, 1] += 0.0001
239+
y[1, 2] += 1
240+
y[2, 2] += 10
241+
y[1, 3] += 100
242+
y[2, 1] += 1000
243+
diff = max_diff(x, y, hist=True)
244+
s = string_diff(diff)
245+
self.assertEndsWith(
246+
"/#8>0.0-#8>0.0001-#6>0.001-#5>0.01-#5>0.1-#3>1.0-#2>10.0-#1>100.0", s
247+
)
248+
232249
def test_max_diff_hist_tensor(self):
233250
x = torch.arange(12).reshape((3, 4)).to(dtype=torch.float32)
234251
y = x.clone()

onnx_diagnostic/ext_test_case.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,6 +1070,11 @@ def assertStartsWith(self, prefix: str, full: str):
10701070
if not full.startswith(prefix):
10711071
raise AssertionError(f"prefix={prefix!r} does not start string {full!r}.")
10721072

1073+
def assertEndsWith(self, suffix: str, full: str):
1074+
"""In the name"""
1075+
if not full.endswith(suffix):
1076+
raise AssertionError(f"suffix={suffix!r} does not end string {full!r}.")
1077+
10731078
def capture(self, fct: Callable):
10741079
"""
10751080
Runs a function and capture standard output and error.

onnx_diagnostic/helpers/helper.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1458,10 +1458,20 @@ def max_diff(
14581458
def string_diff(diff: Dict[str, Any]) -> str:
14591459
"""Renders discrepancies return by :func:`max_diff` into one string."""
14601460
# dict(abs=, rel=, sum=, n=n_diff, dnan=)
1461+
suffix = ""
1462+
if "rep" in diff:
1463+
rows = []
1464+
for k, v in diff["rep"].items():
1465+
if v > 0:
1466+
rows.append(f"#{v}{k}")
1467+
suffix = "-".join(rows)
1468+
suffix = f"/{suffix}"
14611469
if diff.get("dnan", None):
14621470
if diff["abs"] == 0 or diff["rel"] == 0:
1463-
return f"abs={diff['abs']}, rel={diff['rel']}, dnan={diff['dnan']}"
1464-
return f"abs={diff['abs']}, rel={diff['rel']}, n={diff['n']}, dnan={diff['dnan']}"
1471+
return f"abs={diff['abs']}, rel={diff['rel']}, dnan={diff['dnan']}{suffix}"
1472+
return (
1473+
f"abs={diff['abs']}, rel={diff['rel']}, n={diff['n']}, dnan={diff['dnan']}{suffix}"
1474+
)
14651475
if diff["abs"] == 0 or diff["rel"] == 0:
1466-
return f"abs={diff['abs']}, rel={diff['rel']}"
1467-
return f"abs={diff['abs']}, rel={diff['rel']}, n={diff['n']}"
1476+
return f"abs={diff['abs']}, rel={diff['rel']}{suffix}"
1477+
return f"abs={diff['abs']}, rel={diff['rel']}, n={diff['n']}{suffix}"

0 commit comments

Comments
 (0)