Skip to content

Commit da83cea

Browse files
committed
fix git
1 parent c4bfe8c commit da83cea

File tree

2 files changed

+11
-6
lines changed

2 files changed

+11
-6
lines changed

_unittests/ut_helpers/test_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ def test_max_diff_hist_array_string_diff(self):
245245
diff = max_diff(x, y, hist=True)
246246
s = string_diff(diff)
247247
self.assertEndsWith(
248-
"/#8>0.0-#8>0.0001-#6>0.001-#5>0.01-#5>0.1-#3>1.0-#2>10.0-#1>100.0", s
248+
"/#8>0.0-#8>0.0001-#6>0.001-#5>0.01-#5>0.1-#3>1.0-#2>10.0-#1>100.0,amax=2,1", s
249249
)
250250

251251
def test_max_diff_hist_tensor(self):

onnx_diagnostic/helpers/helper.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import enum
33
import inspect
44
from dataclasses import is_dataclass, fields
5-
from typing import Any, Callable, Dict, List, Optional, Set, Union
5+
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
66
import numpy as np
77

88

@@ -872,7 +872,7 @@ def max_diff(
872872
_index: int = 0,
873873
allow_unique_tensor_with_list_of_one_element: bool = True,
874874
hist: Optional[Union[bool, List[float]]] = None,
875-
) -> Dict[str, float]:
875+
) -> Dict[str, Union[float, str, int, Tuple[int, ...]]]:
876876
"""
877877
Returns the maximum discrepancy.
878878
@@ -1221,7 +1221,7 @@ def max_diff(
12211221
f"_index={_index}"
12221222
)
12231223

1224-
res = dict(
1224+
res: Dict[str, Union[str, int, float, Tuple[int, ...]]] = dict(
12251225
abs=abs_diff, rel=rel_diff, sum=sum_diff, n=n_diff, dnan=nan_diff, argm=argm
12261226
)
12271227
if hist:
@@ -1332,7 +1332,7 @@ def max_diff(
13321332
f"_index={_index}"
13331333
)
13341334

1335-
res = dict(
1335+
res: Dict[str, Union[str, int, float, Tuple[int, ...]]] = dict(
13361336
abs=abs_diff, rel=rel_diff, sum=sum_diff, n=n_diff, dnan=nan_diff, argm=argm
13371337
)
13381338
if hist:
@@ -1488,7 +1488,12 @@ def string_diff(diff: Dict[str, Any]) -> str:
14881488
suffix = "-".join(rows)
14891489
suffix = f"/{suffix}"
14901490
if "argm" in diff:
1491-
suffix += f", argmax={diff['argm']}"
1491+
sa = (
1492+
",".join(map(str, diff["argm"]))
1493+
if isinstance(diff["argm"], tuple)
1494+
else str(diff["argm"])
1495+
)
1496+
suffix += f",amax={sa}"
14921497
if diff.get("dnan", None):
14931498
if diff["abs"] == 0 or diff["rel"] == 0:
14941499
return f"abs={diff['abs']}, rel={diff['rel']}, dnan={diff['dnan']}{suffix}"

0 commit comments

Comments
 (0)