Commit d940d1f

adding type hints and a pytest

1 parent 579a764 commit d940d1f

6 files changed: +486 −141 lines changed
examples/check_SFNO_shapes.py

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
```python
import torch
from fno.sfno import SFNO


if __name__ == "__main__":
    """
    Test arbitrary-size inference for both the spatial
    and temporal dimensions of SFNO.
    """
    modes = 8
    modes_t = 2
    width = 10
    bsz = 5
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    sizes = [(n, n, n_t) for (n, n_t) in zip([64, 128, 256], [5, 10, 20])]
    model = SFNO(modes, modes, modes_t, width,
                 latent_steps=3).to(device)
    x = torch.randn(bsz, *sizes[0]).to(device)
    _ = model(x)

    try:
        from torchinfo import summary

        # torchinfo has not resolved the complex-number problem yet
        summary(model, input_size=(bsz, *sizes[-1]))
    except ImportError as e:
        raise ImportError(
            "torchinfo is not installed; please install it to get the model summary"
        ) from e
    del model

    print("\n" * 3)
    for size in sizes:
        torch.cuda.empty_cache()
        model = SFNO(modes, modes, modes_t, width, latent_steps=3).to(device)
        model.add_latent_hook("activations")
        x = torch.randn(bsz, *size).to(device)
        pred = model(x)
        print(f"\n\ninput shape: {list(x.size())}")
        print(f"output shape: {list(pred.size())}")
        for name, tensor in model.latent_tensors.items():
            print(name, list(tensor.shape))
        del model

    print("\n")
    # test evaluation speed
    from time import time

    torch.cuda.empty_cache()
    model = SFNO(modes, modes, modes_t, width, latent_steps=3).to(device)
    model.eval()
    x = torch.randn(bsz, *sizes[1]).to(device)
    start_time = time()
    for _ in range(100):
        pred = model(x)
    end_time = time()
    print(f"Average eval time: {(end_time - start_time) / 100:.6f} seconds")
    del model
```
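The commit message mentions a pytest, but the test file itself is not among the diffs shown here. A minimal pytest-style version of the shape check above might look like the following sketch (the hyperparameters and the `fno.sfno` import path are taken from the script; the assertion that the output shape matches the input shape is an assumption based on the default-`out_steps` behavior described in the docstring removed from `fno/base.py`):

```python
import pytest
import torch

from fno.sfno import SFNO


@pytest.mark.parametrize("n, n_t", [(64, 5), (128, 10), (256, 20)])
def test_sfno_arbitrary_sizes(n: int, n_t: int):
    # same hyperparameters as in examples/check_SFNO_shapes.py
    model = SFNO(8, 8, 2, 10, latent_steps=3)
    x = torch.randn(2, n, n, n_t)
    pred = model(x)
    assert pred.shape == x.shape  # assumes default out_steps follows the input
```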

fno/README.md

Lines changed: 10 additions & 0 deletions
````diff
@@ -38,6 +38,16 @@ Generate the isotropic turbulence in [1] with the inverse cascade frequency sign
 
 ## Training and evaluation scripts
 
+### VSCode workspace for development
+Please add the following settings to your VSCode workspace settings:
+```json
+"settings": {
+    "terminal.integrated.env.osx": {"PYTHONPATH": "${workspaceFolder}"},
+    "terminal.integrated.env.linux": {"PYTHONPATH": "${workspaceFolder}"},
+    "jupyter.notebookFileRoot": "${workspaceFolder}"
+}
+```
+
 
 ### Testing the arbitrary input and output discretization sizes (including time)
 Run the part below `__name__ == "__main__"` in [`sfno.py`](sfno.py)
````
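Outside VSCode, the same import behavior can be obtained without the `PYTHONPATH` setting by putting the repo root on `sys.path` at the top of a script. A minimal sketch, assuming the script sits one directory below the repo root (as `examples/check_SFNO_shapes.py` does):

```python
import sys
from pathlib import Path

# repo root = parent of the examples/ directory this script lives in
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from fno.sfno import SFNO  # noqa: E402, resolvable without PYTHONPATH
```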

fno/base.py

Lines changed: 32 additions & 22 deletions
```diff
@@ -13,7 +13,7 @@
 from copy import deepcopy
 
 from functools import partial
-from typing import List
+from typing import List, Union, Tuple
 
 import torch
 import torch.fft as fft
```
```diff
@@ -24,6 +24,17 @@
 
 conv_dict = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
 
+ACTIVATION_FUNCTIONS = [
+    'CELU', 'ELU', 'GELU', 'GLU', 'Hardtanh', 'Hardshrink', 'Hardsigmoid',
+    'Hardswish', 'LeakyReLU', 'LogSigmoid', 'PReLU',
+    'ReLU', 'ReLU6', 'RReLU', 'SELU', 'SiLU', 'Sigmoid', 'Softplus',
+    'Softmax', 'Softmax2d', 'Softshrink', 'Softsign', 'Tanh', 'Tanhshrink',
+    'Threshold', 'Mish'
+]
+
+# Type hint for activation-function names
+ActivationType = str
+
 
 class LayerNormnd(nn.GroupNorm):
     """
```
```diff
@@ -50,28 +61,31 @@ def forward(self, v: torch.Tensor):
         return super().forward(v)
 
 
-class MLP(nn.Module):
+class PointwiseFFN(nn.Module):
+    """
+    Pointwise-applied two-layer FFN with a channel expansion.
+    """
+
     def __init__(
         self,
-        in_channels,
-        out_channels,
-        mid_channels,
-        activation: str = "GELU",
+        in_channels: int,
+        out_channels: int,
+        mid_channels: int,
+        activation: ActivationType = "ReLU",
         dim: int = 3,
     ):
-        super(MLP, self).__init__()
+        super().__init__()
 
         if dim not in conv_dict:
             raise ValueError(f"Unsupported dimension: {dim}, expected 1, 2, or 3")
 
         Conv = conv_dict[dim]
-        self.mlp1 = Conv(in_channels, mid_channels, 1)
-        self.mlp2 = Conv(mid_channels, out_channels, 1)
+        self.linear1 = Conv(in_channels, mid_channels, 1)
+        self.linear2 = Conv(mid_channels, out_channels, 1)
         self.activation = getattr(nn, activation)()
 
     def forward(self, v: torch.Tensor):
-        for block in [self.mlp1, self.activation, self.mlp2]:
-            v = block(v)
+        for b in [self.linear1, self.activation, self.linear2]:
+            v = b(v)
         return v
```
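Because both layers are kernel-size-1 convolutions, `PointwiseFFN` mixes channels only and leaves every spatial/temporal location in place, so input and output grids always match. A minimal usage sketch, assuming `PointwiseFFN` is importable from `fno.base` as shown in this diff:

```python
import torch

from fno.base import PointwiseFFN

# 10 -> 40 -> 10 channels, applied pointwise on a 3D (x, y, t) grid
ffn = PointwiseFFN(in_channels=10, out_channels=10, mid_channels=40, dim=3)
v = torch.randn(5, 10, 64, 64, 5)  # (batch, channels, x, y, t)
print(ffn(v).shape)  # torch.Size([5, 10, 64, 64, 5])
```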

```diff
@@ -169,13 +183,13 @@ def forward(self, v, out_mesh_size=None, **kwargs):
         return v
 
 
-class FNO(nn.Module):
+class FNOBase(nn.Module):
     def __init__(
         self,
         *,
         num_spectral_layers: int = 4,
         fft_norm="backward",
-        activation: str = "ReLU",
+        activation: ActivationType = "ReLU",
         spatial_padding: int = 0,
         channel_expansion: int = 4,
         spatial_random_feats: bool = False,
```
```diff
@@ -199,7 +213,7 @@ def __init__(
 
         self.spatial_padding = spatial_padding
         self.fft_norm = fft_norm
-        self.activation_name = activation
+        self.activation = activation
         self.spatial_random_feats = spatial_random_feats
         self.lift_activation = lift_activation
         self.channel_expansion = channel_expansion
```
```diff
@@ -228,10 +242,10 @@ def _set_spectral_layers(
         num_layers: int,
         modes: List[int],
         width: int,
-        activation: str,
-        spectral_conv: nn.Module,
-        mlp: nn.Module,
-        linear: nn.Module,
+        activation: ActivationType,
+        spectral_conv: SpectralConv,
+        mlp: PointwiseFFN,
+        linear: Union[nn.Conv1d, nn.Conv2d, nn.Conv3d],
         channel_expansion: int = 4,
     ) -> None:
         """
```
```diff
@@ -283,8 +297,4 @@ def double(self):
         return self
 
     def forward(self, *args, **kwargs):
-        """
-        if out_steps is None, it will try to use self.out_steps
-        if self.out_steps is None, it will use the temporal dimension of the input
-        """
         raise NotImplementedError("Subclasses of FNOBase must implement the forward method")
```
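As the `NotImplementedError` indicates, `FNOBase` is an abstract base: concrete models such as `SFNO` (imported from `fno.sfno` in the example script) supply the `forward` pass. A toy illustration of the contract only (this subclass is hypothetical, not the actual SFNO implementation):

```python
import torch

from fno.base import FNOBase


class IdentityFNO(FNOBase):
    """Hypothetical subclass: satisfies the forward() contract trivially."""

    def forward(self, v: torch.Tensor) -> torch.Tensor:
        return v
```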
