Skip to content

Commit 9588104

Browse files
committed
add argmax
1 parent 7936215 commit 9588104

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

onnx_diagnostic/helpers/helper.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1181,6 +1181,7 @@ def max_diff(
11811181
if exp_cpu.size == got_cpu.size
11821182
else (np.inf, np.inf, np.inf, 0, np.inf)
11831183
)
1184+
argm = None
11841185
else:
11851186
abs_diff, rel_diff, sum_diff, n_diff, nan_diff = (
11861187
float(diff.max()),
@@ -1189,6 +1190,7 @@ def max_diff(
11891190
float(diff.size),
11901191
float(ndiff.sum()),
11911192
)
1193+
argm = tuple(map(int, np.unravel_index(diff.argmax(), diff.shape)))
11921194
if verbose >= 10 and (abs_diff >= 10 or rel_diff >= 10):
11931195
# To understand the value it comes from.
11941196
if debug_info:
@@ -1219,7 +1221,9 @@ def max_diff(
12191221
f"_index={_index}"
12201222
)
12211223

1222-
res = dict(abs=abs_diff, rel=rel_diff, sum=sum_diff, n=n_diff, dnan=nan_diff)
1224+
res = dict(
1225+
abs=abs_diff, rel=rel_diff, sum=sum_diff, n=n_diff, dnan=nan_diff, argm=argm
1226+
)
12231227
if hist:
12241228
if isinstance(hist, bool):
12251229
hist = np.array([0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100], dtype=diff.dtype)
@@ -1284,8 +1288,10 @@ def max_diff(
12841288
float(diff.numel()),
12851289
float(ndiff.sum()),
12861290
)
1291+
argm = tuple(map(int, torch.unravel_index(diff.argmax(), diff.shape)))
12871292
elif got_cpu.numel() == exp_cpu.numel():
12881293
abs_diff, rel_diff, sum_diff, n_diff, nan_diff = (0.0, 0.0, 0.0, 0.0, 0.0)
1294+
argm = None
12891295
else:
12901296
abs_diff, rel_diff, sum_diff, n_diff, nan_diff = (
12911297
np.inf,
@@ -1294,6 +1300,7 @@ def max_diff(
12941300
np.inf,
12951301
np.inf,
12961302
)
1303+
argm = None
12971304

12981305
if verbose >= 10 and (abs_diff >= 10 or rel_diff >= 10):
12991306
# To understand the value it comes from.
@@ -1325,7 +1332,9 @@ def max_diff(
13251332
f"_index={_index}"
13261333
)
13271334

1328-
res = dict(abs=abs_diff, rel=rel_diff, sum=sum_diff, n=n_diff, dnan=nan_diff)
1335+
res = dict(
1336+
abs=abs_diff, rel=rel_diff, sum=sum_diff, n=n_diff, dnan=nan_diff, argm=argm
1337+
)
13291338
if hist:
13301339
if isinstance(hist, bool):
13311340
hist = torch.tensor(
@@ -1478,6 +1487,8 @@ def string_diff(diff: Dict[str, Any]) -> str:
14781487
rows.append(f"#{v}{k}")
14791488
suffix = "-".join(rows)
14801489
suffix = f"/{suffix}"
1490+
if "argm" in diff:
1491+
suffix += f", argmax={diff['argm']}"
14811492
if diff.get("dnan", None):
14821493
if diff["abs"] == 0 or diff["rel"] == 0:
14831494
return f"abs={diff['abs']}, rel={diff['rel']}, dnan={diff['dnan']}{suffix}"

0 commit comments

Comments
 (0)