|
22 | 22 | class UnitGaussianNormalizer(nn.Module): |
23 | 23 | def __init__( |
24 | 24 | self, |
25 | | - eps=1e-5, |
| 25 | + eps=1e-7, |
26 | 26 | data: Union[torch.Tensor, np.ndarray] = None, |
27 | 27 | ): |
28 | 28 | super().__init__() |
@@ -91,18 +91,16 @@ def forward(self, *args, **kwargs): |
91 | 91 | @staticmethod |
92 | 92 | def _align_shapes(x, mean, std, **kwargs): |
93 | 93 | """ |
94 | | - x: (N, n, n, C) or (N, n, n) |
| 94 | + x: (bsz, m, m, C) or (bsz, m, m) or (bsz, C, m, m) |
| 95 | + mean: (n, n, C) or (n, n) or (C, n, n) |
95 | 96 | """ |
| 97 | + # print(x.shape) |
96 | 98 | _, *size = x.shape |
97 | | - h_x, w_x = size[0], size[1] |
98 | | - if h_x != mean.shape[0] or w_x != mean.shape[1]: |
99 | | - mean = mean.permute(2, 0, 1) |
100 | | - std = std.permute(2, 0, 1) |
101 | | - mean = F.interpolate(mean[None, ...], size=(h_x, w_x), **kwargs) |
102 | | - std = F.interpolate(std[None, ...], size=(h_x, w_x), **kwargs) |
103 | | - mean = mean.permute(0, 2, 3, 1).squeeze(0) |
104 | | - std = std.permute(0, 2, 3, 1).squeeze(0) |
105 | | - return mean, std |
| 99 | + if len(size) != mean.ndim or any([s != m for s, m in zip(size, mean.shape)]): |
| 100 | + mean = F.interpolate(mean[None, None, ...], size=size, **kwargs) |
| 101 | + std = F.interpolate(std[None, None, ...], size=size, **kwargs) |
| 102 | + |
| 103 | + return mean.squeeze(), std.squeeze() |
106 | 104 |
|
107 | 105 |
|
108 | 106 | class SpatialGaussianNormalizer(UnitGaussianNormalizer): |
|
0 commit comments