from copy import deepcopy

from functools import partial
-from typing import List
+from typing import List, Union, Tuple

import torch
import torch.fft as fft

conv_dict = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}

+ACTIVATION_FUNCTIONS = [
+    'CELU', 'ELU', 'GELU', 'GLU', 'Hardtanh', 'Hardshrink', 'Hardsigmoid',
+    'Hardswish', 'LeakyReLU', 'LogSigmoid', 'MultiheadAttention', 'PReLU',
+    'ReLU', 'ReLU6', 'RReLU', 'SELU', 'SiLU', 'Sigmoid', 'Softplus',
+    'Softmax', 'Softmax2d', 'Softshrink', 'Softsign', 'Tanh', 'Tanhshrink',
+    'Threshold', 'Mish'
+]
+
+# Type hint for activation functions
+ActivationType = Union[str]
+

class LayerNormnd(nn.GroupNorm):
    """
@@ -50,28 +61,31 @@ def forward(self, v: torch.Tensor):
        return super().forward(v)


-class MLP(nn.Module):
+class PointwiseFFN(nn.Module):
    def __init__(
        self,
-        in_channels,
-        out_channels,
-        mid_channels,
-        activation: str = "GELU",
+        in_channels: int,
+        out_channels: int,
+        mid_channels: int,
+        activation: ActivationType = "ReLU",
        dim: int = 3,
    ):
-        super(MLP, self).__init__()
+        """
+        Two-layer feed-forward network with a channel expansion,
+        applied pointwise via 1x1 convolutions.
+        """
+        super().__init__()

        if dim not in conv_dict:
            raise ValueError(f"Unsupported dimension: {dim}, expected 1, 2, or 3")

        Conv = conv_dict[dim]
-        self.mlp1 = Conv(in_channels, mid_channels, 1)
-        self.mlp2 = Conv(mid_channels, out_channels, 1)
+        self.linear1 = Conv(in_channels, mid_channels, 1)
+        self.linear2 = Conv(mid_channels, out_channels, 1)
        self.activation = getattr(nn, activation)()

    def forward(self, v: torch.Tensor):
-        for block in [self.mlp1, self.activation, self.mlp2]:
-            v = block(v)
+        for b in [self.linear1, self.activation, self.linear2]:
+            v = b(v)
        return v


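A minimal usage sketch of the renamed PointwiseFFN (the channel counts and the GELU choice below are illustrative, not taken from this PR): the two 1x1 convolutions expand and then project the channel dimension, and the activation module is looked up by name via getattr(nn, activation)().

    # assumes PointwiseFFN is imported from the module changed in this PR
    import torch

    ffn = PointwiseFFN(in_channels=32, out_channels=32, mid_channels=128,
                       activation="GELU", dim=2)
    v = torch.randn(4, 32, 64, 64)   # (batch, channels, H, W)
    out = ffn(v)                     # 1x1 Conv2d 32->128, GELU, 1x1 Conv2d 128->32
    assert out.shape == v.shape
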
@@ -169,13 +183,13 @@ def forward(self, v, out_mesh_size=None, **kwargs):
        return v


-class FNO(nn.Module):
+class FNOBase(nn.Module):
    def __init__(
        self,
        *,
        num_spectral_layers: int = 4,
        fft_norm="backward",
-        activation: str = "ReLU",
+        activation: ActivationType = "ReLU",
        spatial_padding: int = 0,
        channel_expansion: int = 4,
        spatial_random_feats: bool = False,
@@ -199,7 +213,7 @@ def __init__(

        self.spatial_padding = spatial_padding
        self.fft_norm = fft_norm
-        self.activation_name = activation
+        self.activation = activation
        self.spatial_random_feats = spatial_random_feats
        self.lift_activation = lift_activation
        self.channel_expansion = channel_expansion
@@ -228,10 +242,10 @@ def _set_spectral_layers(
        num_layers: int,
        modes: List[int],
        width: int,
-        activation: str,
-        spectral_conv: nn.Module,
-        mlp: nn.Module,
-        linear: nn.Module,
+        activation: ActivationType,
+        spectral_conv: SpectralConv,
+        mlp: PointwiseFFN,
+        linear: Union[nn.Conv1d, nn.Conv2d, nn.Conv3d],
        channel_expansion: int = 4,
    ) -> None:
        """
@@ -283,8 +297,4 @@ def double(self):
        return self

    def forward(self, *args, **kwargs):
-        """
-        if out_steps is None, it will try to use self.out_steps
-        if self.out_steps is None, it will use the temporal dimension of the input
-        """
        raise NotImplementedError("Subclasses of FNO must implement the forward method")
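
Since PointwiseFFN instantiates its activation with getattr(nn, activation)(), each name in ACTIVATION_FUNCTIONS has to match a torch.nn class exactly and be constructible with no arguments. A sketch of how a caller could validate a name up front (this helper is hypothetical, not part of the PR):

    import torch.nn as nn

    def resolve_activation(name: str) -> nn.Module:
        # hypothetical helper: only accept names that are both whitelisted
        # and actually present in torch.nn
        if name not in ACTIVATION_FUNCTIONS or not hasattr(nn, name):
            raise ValueError(f"Unsupported activation: {name}")
        return getattr(nn, name)()

    act = resolve_activation("SiLU")   # -> nn.SiLU()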