Skip to content

Commit 2db1ba1

Browse files
authored
improve max_diff (#315)
1 parent fa18368 commit 2db1ba1

File tree

2 files changed

+22
-14
lines changed

2 files changed

+22
-14
lines changed

_doc/technical/plot_gemm_or_matmul_add.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
What an operator Gemm in :epkg:`onnxruntime`, the most simple
1111
way to represent a linear neural layer.
1212
13-
A model with three choices
14-
==========================
13+
A model with many choices
14+
=========================
1515
"""
1616

1717
import cpuinfo

onnx_diagnostic/helpers/helper.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1486,19 +1486,27 @@ def max_diff(
14861486
dev=dev,
14871487
)
14881488
if hist:
1489-
if isinstance(hist, bool):
1490-
hist = torch.tensor(
1491-
[0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100], dtype=diff.dtype
1492-
)
1493-
hist = hist.to(diff.device)
1494-
ind = torch.bucketize(diff.reshape((-1,)), hist, right=False)
1495-
cou = torch.bincount(ind, minlength=ind.shape[0] + 1)
1496-
res["rep"] = dict(
1497-
zip(
1498-
[f">{x}" for x in hist],
1499-
[int(i) for i in (cou.sum() - torch.cumsum(cou, 0))],
1489+
if isinstance(hist, list) and len(hist) == 1:
1490+
res["rep"] = {f">{hist[0]}": (diff > hist[0]).sum().item()}
1491+
elif isinstance(hist, list) and len(hist) == 2:
1492+
res["rep"] = {
1493+
f">{hist[0]}": (diff > hist[0]).sum().item(),
1494+
f">{hist[1]}": (diff > hist[1]).sum().item(),
1495+
}
1496+
else:
1497+
if isinstance(hist, bool):
1498+
hist = torch.tensor(
1499+
[0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100], dtype=diff.dtype
1500+
)
1501+
hist = torch.tensor(hist).to(diff.device)
1502+
ind = torch.bucketize(diff.reshape((-1,)), hist, right=False)
1503+
cou = torch.bincount(ind, minlength=ind.shape[0] + 1)
1504+
res["rep"] = dict(
1505+
zip(
1506+
[f">{x}" for x in hist],
1507+
[int(i) for i in (cou.sum() - torch.cumsum(cou, 0))],
1508+
)
15001509
)
1501-
)
15021510
return res # type: ignore
15031511

15041512
if isinstance(expected, int) and isinstance(got, torch.Tensor):

0 commit comments

Comments
 (0)