Skip to content

Commit 2cc47e5

Browse files
committed
fixed a bug when align_shapes has (N, T, n, n) inputs
1 parent cf5aff7 commit 2cc47e5

File tree

1 file changed

+9
-11
lines changed

1 file changed

+9
-11
lines changed

fno/datasets.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
class UnitGaussianNormalizer(nn.Module):
2323
def __init__(
2424
self,
25-
eps=1e-5,
25+
eps=1e-7,
2626
data: Union[torch.Tensor, np.ndarray] = None,
2727
):
2828
super().__init__()
@@ -91,18 +91,16 @@ def forward(self, *args, **kwargs):
9191
@staticmethod
9292
def _align_shapes(x, mean, std, **kwargs):
9393
"""
94-
x: (N, n, n, C) or (N, n, n)
94+
x: (bsz, m, m, C) or (bsz, m, m) or (bsz, C, m, m)
95+
mean: (n, n, C) or (n, n) or (C, n, n)
9596
"""
97+
# print(x.shape)
9698
_, *size = x.shape
97-
h_x, w_x = size[0], size[1]
98-
if h_x != mean.shape[0] or w_x != mean.shape[1]:
99-
mean = mean.permute(2, 0, 1)
100-
std = std.permute(2, 0, 1)
101-
mean = F.interpolate(mean[None, ...], size=(h_x, w_x), **kwargs)
102-
std = F.interpolate(std[None, ...], size=(h_x, w_x), **kwargs)
103-
mean = mean.permute(0, 2, 3, 1).squeeze(0)
104-
std = std.permute(0, 2, 3, 1).squeeze(0)
105-
return mean, std
99+
if len(size) != mean.ndim or any([s != m for s, m in zip(size, mean.shape)]):
100+
mean = F.interpolate(mean[None, None, ...], size=size, **kwargs)
101+
std = F.interpolate(std[None, None, ...], size=size, **kwargs)
102+
103+
return mean.squeeze(), std.squeeze()
106104

107105

108106
class SpatialGaussianNormalizer(UnitGaussianNormalizer):

0 commit comments

Comments
 (0)