|
2 | 2 | import enum |
3 | 3 | import inspect |
4 | 4 | from dataclasses import is_dataclass, fields |
5 | | -from typing import Any, Callable, Dict, List, Optional, Set, Union |
| 5 | +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union |
6 | 6 | import numpy as np |
7 | 7 |
|
8 | 8 |
|
@@ -872,7 +872,7 @@ def max_diff( |
872 | 872 | _index: int = 0, |
873 | 873 | allow_unique_tensor_with_list_of_one_element: bool = True, |
874 | 874 | hist: Optional[Union[bool, List[float]]] = None, |
875 | | -) -> Dict[str, float]: |
| 875 | +) -> Dict[str, Union[float, str, int, Tuple[int, ...]]]: |
876 | 876 | """ |
877 | 877 | Returns the maximum discrepancy. |
878 | 878 |
|
@@ -1221,7 +1221,7 @@ def max_diff( |
1221 | 1221 | f"_index={_index}" |
1222 | 1222 | ) |
1223 | 1223 |
|
1224 | | - res = dict( |
| 1224 | + res: Dict[str, Union[str, int, float, Tuple[int, ...]]] = dict( |
1225 | 1225 | abs=abs_diff, rel=rel_diff, sum=sum_diff, n=n_diff, dnan=nan_diff, argm=argm |
1226 | 1226 | ) |
1227 | 1227 | if hist: |
@@ -1332,7 +1332,7 @@ def max_diff( |
1332 | 1332 | f"_index={_index}" |
1333 | 1333 | ) |
1334 | 1334 |
|
1335 | | - res = dict( |
| 1335 | + res: Dict[str, Union[str, int, float, Tuple[int, ...]]] = dict( |
1336 | 1336 | abs=abs_diff, rel=rel_diff, sum=sum_diff, n=n_diff, dnan=nan_diff, argm=argm |
1337 | 1337 | ) |
1338 | 1338 | if hist: |
@@ -1488,7 +1488,12 @@ def string_diff(diff: Dict[str, Any]) -> str: |
1488 | 1488 | suffix = "-".join(rows) |
1489 | 1489 | suffix = f"/{suffix}" |
1490 | 1490 | if "argm" in diff: |
1491 | | - suffix += f", argmax={diff['argm']}" |
| 1491 | + sa = ( |
| 1492 | + ",".join(map(str, diff["argm"])) |
| 1493 | + if isinstance(diff["argm"], tuple) |
| 1494 | + else str(diff["argm"]) |
| 1495 | + ) |
| 1496 | + suffix += f",amax={sa}" |
1492 | 1497 | if diff.get("dnan", None): |
1493 | 1498 | if diff["abs"] == 0 or diff["rel"] == 0: |
1494 | 1499 | return f"abs={diff['abs']}, rel={diff['rel']}, dnan={diff['dnan']}{suffix}" |
|
0 commit comments