Skip to content

Commit 53fa6ed

Browse files
authored
Fixes a few inefficiencies in sbs (#319)
* fix a few inefficiencies in sbs * fix * better doc * to avoid any wrong commit * sbs * fix disc * style * fxi * changes * fix * study
1 parent 1d7735f commit 53fa6ed

File tree

7 files changed

+333
-36
lines changed

7 files changed

+333
-36
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ _cache/*
5151
.coverage
5252
dist/*
5353
build/*
54+
_sbs_*
5455
.eggs/*
5556
.olive-cache/*
5657
.hypothesis/*

CHANGELOGS.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ Change Logs
88
* :pr:`311`: use custom and local function to use PackedMultiHeadAttention from onnxruntime
99
* :pr:`310`: splits patches into multiple files
1010
* :pr:`308`: add option --save_ep to dump the exported program as well as torch input
11-
* :pr:`304`, :pr:`306`, :pr:`316`, :pr:`317`, :pr:`318`: improves side-by-side comparison, creates command line sbs
11+
* :pr:`304`, :pr:`306`, :pr:`316`, :pr:`317`, :pr:`318`, :pr:`319`: improves side-by-side comparison, creates command line sbs
1212

1313
0.8.2
1414
+++++

_unittests/ut_helpers/test_torch_helper.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
)
2929
from onnx_diagnostic.helpers.mini_onnx_builder import create_input_tensors_from_onnx_model
3030
from onnx_diagnostic.helpers.onnx_helper import from_array_extended, to_array_extended
31-
from onnx_diagnostic.helpers.torch_helper import to_tensor
31+
from onnx_diagnostic.helpers.torch_helper import to_tensor, study_discrepancies
3232

3333
TFLOAT = onnx.TensorProto.FLOAT
3434

@@ -425,6 +425,12 @@ def test_get_weight_type(self):
425425
dt = get_weight_type(model)
426426
self.assertEqual(torch.float32, dt)
427427

428+
def test_study_discrepancies(self):
429+
t1 = torch.rand((3, 4))
430+
t2 = torch.rand((3, 4))
431+
ax = study_discrepancies(t1, t2)
432+
self.assertEqual(ax.shape, ((3, 2)))
433+
428434

429435
if __name__ == "__main__":
430436
unittest.main(verbosity=2)

onnx_diagnostic/_command_lines_parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1234,7 +1234,7 @@ def get_parser_sbs() -> ArgumentParser:
12341234
"--replay-threshold",
12351235
type=float,
12361236
required=False,
1237-
default=1e6,
1237+
default=1e9,
12381238
help="Triggers the replay if the discrepancies are higher than this value.",
12391239
)
12401240
parser.add_argument(

onnx_diagnostic/helpers/torch_helper.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import contextlib
22
import ctypes
33
import inspect
4+
import math
45
import os
56
import sys
67
import warnings
@@ -1003,3 +1004,76 @@ def get_weight_type(model: torch.nn.Module) -> torch.dtype:
10031004
counts[dt] += 1
10041005
final = max(list(counts.items()))
10051006
return final[0]
1007+
1008+
1009+
def closest_factor_pair(n: int):
1010+
"""Tries to find ``a, b`` such as ``n == a * b``."""
1011+
assert n > 0, f"n={n} must be a positive integer"
1012+
start = math.isqrt(n)
1013+
for a in range(start, 0, -1):
1014+
if n % a == 0:
1015+
b = n // a
1016+
return a, b
1017+
return 1, n
1018+
1019+
1020+
def study_discrepancies(
1021+
t1: torch.Tensor,
1022+
t2: torch.Tensor,
1023+
bins: int = 50,
1024+
figsize: Optional[Tuple[int, int]] = (15, 15),
1025+
title: Optional[str] = None,
1026+
name: Optional[str] = None,
1027+
) -> "matplotlib.axes.Axes": # noqa: F821
1028+
"""
1029+
Computes different metrics for the discrepancies.
1030+
Returns graphs.
1031+
"""
1032+
assert t1.dtype == t2.dtype, f"Type mismatch {t1.dtype} != {t2.dtype}"
1033+
assert t1.shape == t2.shape, f"Shape mismatch {t1.shape} != {t2.shape}"
1034+
d1, d2 = (
1035+
(t1, t2) if t1.dtype == torch.float64 else (t1.to(torch.float32), t2.to(torch.float32))
1036+
)
1037+
1038+
d1 = d1.squeeze()
1039+
d2 = d2.squeeze()
1040+
if len(d1.shape) == 1:
1041+
new_shape = closest_factor_pair(d1.shape[0])
1042+
d1, d2 = d1.reshape(new_shape), d2.reshape(new_shape)
1043+
elif len(d1.shape) > 2:
1044+
new_shape = (-1, max(d1.shape))
1045+
d1, d2 = d1.reshape(new_shape), d2.reshape(new_shape)
1046+
1047+
import matplotlib.pyplot as plt
1048+
1049+
fig, ax = plt.subplots(3, 2, figsize=figsize)
1050+
vmin, vmax = d1.min().item(), d1.max().item()
1051+
ax[0, 0].imshow(d1.detach().cpu().numpy(), cmap="Greys", vmin=vmin, vmax=vmax)
1052+
ax[0, 0].set_title(
1053+
f"Color plot of the first tensor in\n[{vmin}, {vmax}]\n{t1.shape} -> {d1.shape}"
1054+
)
1055+
1056+
diff = d2 - d1
1057+
vmin, vmax = diff.min().item(), diff.max().item()
1058+
ax[0, 1].imshow(diff.detach().cpu().numpy(), cmap="seismic", vmin=vmin, vmax=vmax)
1059+
ax[0, 1].set_title(f"Color plot of the differences in \n[{vmin}, {vmax}]")
1060+
1061+
ax[1, 0].hist(d1.detach().cpu().numpy().ravel(), bins=bins)
1062+
ax[1, 0].set_title("Distribution of the first tensor")
1063+
1064+
ax[1, 1].hist(diff.detach().cpu().numpy().ravel(), bins=bins)
1065+
ax[1, 1].set_title("Distribution of the differences")
1066+
1067+
tf1 = d1.ravel()
1068+
td1 = diff.ravel()
1069+
ax[2, 1].plot(tf1.detach().cpu().numpy(), td1.detach().cpu().numpy(), ".")
1070+
ax[2, 1].set_title("Graph XY")
1071+
ax[2, 1].set_xlabel("First tensor values")
1072+
ax[2, 1].set_ylabel("Difference values")
1073+
1074+
if title:
1075+
fig.suptitle(title)
1076+
fig.tight_layout()
1077+
if name:
1078+
fig.savefig(name)
1079+
return ax

0 commit comments

Comments
 (0)