Skip to content

Commit 8715162

Browse files
committed
style
1 parent 67af115 commit 8715162

File tree

2 files changed

+90
-1
lines changed

2 files changed

+90
-1
lines changed

onnx_diagnostic/helpers/torch_helper.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import contextlib
22
import ctypes
33
import inspect
4+
import math
45
import os
56
import sys
67
import warnings
@@ -1003,3 +1004,74 @@ def get_weight_type(model: torch.nn.Module) -> torch.dtype:
10031004
counts[dt] += 1
10041005
final = max(list(counts.items()))
10051006
return final[0]
1007+
1008+
1009+
def closest_factor_pair(n: int):
1010+
"""Tries to find ``a, b`` such as ``n == a * b``."""
1011+
assert n > 0, f"n={n} must be a positive integer"
1012+
start = math.isqrt(n)
1013+
for a in range(start, 0, -1):
1014+
if n % a == 0:
1015+
b = n // a
1016+
return a, b
1017+
return 1, n
1018+
1019+
1020+
def study_discrepancies(
1021+
t1: torch.Tensor,
1022+
t2: torch.Tensor,
1023+
bins: int = 50,
1024+
figsize: Optional[Tuple[int, int]] = (15, 15),
1025+
title: Optional[str] = None,
1026+
name: Optional[str] = None,
1027+
) -> "Axes": # noqa: F821
1028+
"""
1029+
Computes different metrics for the discrepancies.
1030+
Returns graphs.
1031+
"""
1032+
assert t1.dtype == t2.dtype, f"Type mismatch {t1.dtype} != {t2.dtype}"
1033+
assert t1.shape == t2.shape, f"Shape mismatch {t1.shape} != {t2.shape}"
1034+
d1, d2 = (
1035+
(t1, t2) if t1.dtype == torch.float64 else (t1.to(torch.float32), t2.to(torch.float32))
1036+
)
1037+
1038+
d1 = d1.squeeze()
1039+
d2 = d2.squeeze()
1040+
if len(d1.shape) == 1:
1041+
new_shape = closest_factor_pair(d1.shape[0])
1042+
d1, d2 = d1.reshape(new_shape), d2.reshape(new_shape)
1043+
elif len(d1.shape) > 2:
1044+
new_shape = (-1, max(d1.shape))
1045+
d1, d2 = d1.reshape(new_shape), d2.reshape(new_shape)
1046+
1047+
import matplotlib.pyplot as plt
1048+
1049+
fig, ax = plt.subplots(3, 2, figsize=figsize)
1050+
vmin, vmax = d1.min().item(), d1.max().item()
1051+
ax[0, 0].imshow(d1.detach().cpu().numpy(), cmap="Greys", vmin=vmin, vmax=vmax)
1052+
ax[0, 0].set_title(f"Color plot of the first tensor in\n[{vmin}, {vmax}]")
1053+
1054+
diff = d2 - d1
1055+
vmin, vmax = diff.min().item(), diff.max().item()
1056+
ax[0, 1].imshow(diff.detach().cpu().numpy(), cmap="seismic", vmin=vmin, vmax=vmax)
1057+
ax[0, 1].set_title(f"Color plot of the differences in \n[{vmin}, {vmax}]")
1058+
1059+
ax[1, 0].hist(d1.detach().cpu().numpy().ravel(), bins=bins)
1060+
ax[1, 0].set_title("Distribution of the first tensor")
1061+
1062+
ax[1, 1].hist(diff.detach().cpu().numpy().ravel(), bins=bins)
1063+
ax[1, 1].set_title("Distribution of the differences")
1064+
1065+
tf1 = d1.ravel()
1066+
td1 = diff.ravel()
1067+
ax[2, 1].plot(tf1.detach().cpu().numpy(), td1.detach().cpu().numpy(), ".")
1068+
ax[2, 1].set_title("Graph XY")
1069+
ax[2, 1].set_xlabel("First tensor values")
1070+
ax[2, 1].set_ylabel("Difference values")
1071+
1072+
if title:
1073+
fig.suptitle(title)
1074+
fig.tight_layout()
1075+
if name:
1076+
fig.savefig(name)
1077+
return ax

onnx_diagnostic/torch_onnx/sbs.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,13 @@
33
import textwrap
44
import time
55
from dataclasses import dataclass
6-
from typing import Any, Callable, Dict, Iterator, List, Optional, Self, Set, Tuple, Union
6+
from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Union
7+
8+
try:
9+
from typing import Self
10+
except ImportError:
11+
# python <= 3.10
12+
Self = "Self"
713
import onnx
814
import onnx.helper as oh
915
import numpy as np
@@ -246,6 +252,7 @@ def get_replay_code(self) -> str:
246252
import onnx
247253
import torch
248254
from onnx_diagnostic.helpers import max_diff, string_diff, string_type
255+
from onnx_diagnostic.helpers.torch_helper import study_discrepancies
249256
from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
250257
from onnx_diagnostic.reference import OnnxruntimeEvaluator
251258
@@ -294,6 +301,16 @@ def get_replay_code(self) -> str:
294301
print(f"-- obtained={string_type(obtained, **skws)}")
295302
diff = max_diff(expected, tuple(obtained))
296303
print(f"-- diff: {string_diff(diff)}")
304+
305+
print("-- plots")
306+
for i in range(len(expected)):
307+
study_discrepancies(
308+
expected[i],
309+
obtained[i],
310+
title=f"study output {i}",
311+
name=f"disc{i}.png",
312+
bins=50,
313+
)
297314
"""
298315
)
299316

0 commit comments

Comments
 (0)