|
1 | 1 | import contextlib |
2 | 2 | import ctypes |
3 | 3 | import inspect |
| 4 | +import math |
4 | 5 | import os |
5 | 6 | import sys |
6 | 7 | import warnings |
@@ -1003,3 +1004,74 @@ def get_weight_type(model: torch.nn.Module) -> torch.dtype: |
1003 | 1004 | counts[dt] += 1 |
1004 | 1005 | final = max(list(counts.items())) |
1005 | 1006 | return final[0] |
| 1007 | + |
| 1008 | + |
def closest_factor_pair(n: int) -> Tuple[int, int]:
    """
    Returns the factorization ``n == a * b`` whose two factors are the closest.

    The search starts at ``floor(sqrt(n))`` and walks downwards, so the first
    divisor found is the largest one not exceeding the square root.
    For a prime number, the result is ``(1, n)``.

    :param n: positive integer to factorize
    :return: tuple ``(a, b)`` such that ``a * b == n`` and ``a <= b``
    :raises AssertionError: if ``n`` is not strictly positive
    """
    assert n > 0, f"n={n} must be a positive integer"
    a = math.isqrt(n)
    # 1 divides every positive integer, so this loop always stops
    # at a >= 1 and no fallback value is needed afterwards.
    while n % a:
        a -= 1
    return a, n // a
| 1018 | + |
| 1019 | + |
def study_discrepancies(
    t1: torch.Tensor,
    t2: torch.Tensor,
    bins: int = 50,
    figsize: Optional[Tuple[int, int]] = (15, 15),
    title: Optional[str] = None,
    name: Optional[str] = None,
) -> "Axes":  # noqa: F821
    """
    Draws several graphs to study the discrepancies between two tensors.

    Both tensors are converted to ``float32`` (unless they are already
    ``float64``), squeezed and reshaped into a matrix so that they can be
    displayed as images. The figure is a 3x2 grid of axes:

    * ``[0, 0]``: color plot of the first tensor,
    * ``[0, 1]``: color plot of the differences ``t2 - t1``,
    * ``[1, 0]``: histogram of the first tensor,
    * ``[1, 1]``: histogram of the differences,
    * ``[2, 1]``: differences as a function of the first tensor values.

    ``[2, 0]`` is not drawn on by this function.

    :param t1: first tensor, used as the reference
    :param t2: second tensor, same dtype and same shape as *t1*
    :param bins: number of bins of both histograms
    :param figsize: figure size given to :func:`matplotlib.pyplot.subplots`
    :param title: if not empty, title of the whole figure
    :param name: if not empty, the figure is saved under this name
    :return: the array of axes created by :func:`matplotlib.pyplot.subplots`
    :raises AssertionError: if dtypes or shapes of the two tensors differ
    """
    assert t1.dtype == t2.dtype, f"Type mismatch {t1.dtype} != {t2.dtype}"
    assert t1.shape == t2.shape, f"Shape mismatch {t1.shape} != {t2.shape}"
    d1, d2 = (
        (t1, t2) if t1.dtype == torch.float64 else (t1.to(torch.float32), t2.to(torch.float32))
    )

    d1 = d1.squeeze()
    d2 = d2.squeeze()

    # imshow needs a matrix: pick a 2D shape depending on the squeezed rank.
    rank = len(d1.shape)
    if rank == 0:
        # A single-element tensor squeezes down to a scalar,
        # which imshow rejects: show it as a 1x1 image.
        new_shape = (1, 1)
    elif rank == 1:
        # A vector is folded into the most square matrix possible.
        new_shape = closest_factor_pair(d1.shape[0])
    elif rank > 2:
        # The largest dimension becomes the number of columns.
        new_shape = (-1, max(d1.shape))
    else:
        new_shape = None
    if new_shape is not None:
        d1, d2 = d1.reshape(new_shape), d2.reshape(new_shape)

    import matplotlib.pyplot as plt

    # Each tensor is moved to numpy once and reused by every graph.
    first = d1.detach().cpu().numpy()
    diff = (d2 - d1).detach().cpu().numpy()
    flat_first = first.ravel()
    flat_diff = diff.ravel()

    fig, ax = plt.subplots(3, 2, figsize=figsize)
    vmin, vmax = first.min().item(), first.max().item()
    ax[0, 0].imshow(first, cmap="Greys", vmin=vmin, vmax=vmax)
    ax[0, 0].set_title(f"Color plot of the first tensor in\n[{vmin}, {vmax}]")

    vmin, vmax = diff.min().item(), diff.max().item()
    ax[0, 1].imshow(diff, cmap="seismic", vmin=vmin, vmax=vmax)
    ax[0, 1].set_title(f"Color plot of the differences in \n[{vmin}, {vmax}]")

    ax[1, 0].hist(flat_first, bins=bins)
    ax[1, 0].set_title("Distribution of the first tensor")

    ax[1, 1].hist(flat_diff, bins=bins)
    ax[1, 1].set_title("Distribution of the differences")

    ax[2, 1].plot(flat_first, flat_diff, ".")
    ax[2, 1].set_title("Graph XY")
    ax[2, 1].set_xlabel("First tensor values")
    ax[2, 1].set_ylabel("Difference values")

    if title:
        fig.suptitle(title)
    fig.tight_layout()
    if name:
        fig.savefig(name)
    return ax
0 commit comments