
Commit dccb7b7

Update README, improve typing hints, add a new loss, remove dupes

1 parent d940d1f commit dccb7b7

File tree: 7 files changed (+286, -149 lines)

README.md

Lines changed: 5 additions & 4 deletions
@@ -14,7 +14,8 @@ The main changes are documented in the `README.md` under the [`torch_cfd` direct
 - All ops take into consideration the batch dimension of tensors `(b, *, n, m)` regardless of `*` dimension, for example, `(b, T, C, n, n)`, which is similar to PyTorch behavior, not a single trajectory like Google's original Jax-CFD package.
 
 ### Part II: Spectral-Refiner: Neural Operator-Assisted Navier-Stokes Equations simulator.
-- The **Spatiotemporal Fourier Neural Operator** (SFNO) is a spacetime tensor-to-tensor learner (or trajectory-to-trajectory), available in the [`fno` directory](./fno). Different components of FNO have been re-implemented keeping the conciseness of the original implementation while allowing modern expansions. We draw inspiration from the [3D FNO in Nvidia's Neural Operator repo](https://github.com/neuraloperator/neuraloperator), [Transformers-based neural operators](https://github.com/thuml/Neural-Solver-Library), as well as Temam's book on functional analysis for NSE. Major architectural changes can be found in [the documentation of the `SFNO` class](./fno/sfno.py#L485).
+- The **Spatiotemporal Fourier Neural Operator** (SFNO) is a spacetime tensor-to-tensor learner (or trajectory-to-trajectory), available in the [`fno` directory](./fno). Different components of FNO have been re-implemented keeping the conciseness of the original implementation while allowing modern expansions. We draw inspiration from the [3D FNO in Nvidia's Neural Operator repo](https://github.com/neuraloperator/neuraloperator), [Transformers-based neural operators](https://github.com/thuml/Neural-Solver-Library), as well as Temam's book on functional analysis for the NSE.
+- Major architectural changes: learnable spatiotemporal positional encodings, layernorm to replace a hard-coded global Gaussian normalizer, and many others. For more details please see [the documentation of the `SFNO` class](./fno/sfno.py#L485).
 - Data generation for the meta-example of the isotropic turbulence in [McWilliams1984]. After the warmup phase, the energy spectra match the inverse cascade of Kolmogorov flow in a periodic box.
 - Pipelines for the *a posteriori* error estimation to fine-tune the SFNO to reach the scientific computing level of accuracy ($\le 10^{-6}$) in Bochner norm using FLOPs on par with a single evaluation, and only a fraction of FLOPs of a single `.backward()`.
 - [Examples](#examples) can be found below.
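The batch-dimension convention in the first bullet above can be illustrated with a minimal sketch (not from this commit; the helper `spectral_dx` is hypothetical): spectral ops act on the trailing mesh axes `(n, m)` and broadcast over any leading dimensions such as `(b, T, C)`.

```python
import torch
import torch.fft as fft

def spectral_dx(v: torch.Tensor, diam: float = 1.0) -> torch.Tensor:
    """d/dx via FFT: acts on the trailing (n, m) axes, broadcasts the rest."""
    n = v.size(-2)
    kx = fft.fftfreq(n, d=diam / n)[:, None]  # (n, 1), broadcasts over m
    vhat = fft.fft2(v, dim=(-2, -1))          # FFT over the mesh dims only
    return fft.ifft2(2j * torch.pi * kx * vhat, dim=(-2, -1)).real

v = torch.randn(4, 10, 2, 64, 64)             # (b, T, C, n, n)
assert spectral_dx(v).shape == v.shape        # leading dims pass through untouched
```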
@@ -26,7 +27,7 @@ To install `torch-cfd`'s current release, simply do:
 ```bash
 pip install torch-cfd
 ```
-If one wants to play with the neural operator part, it is recommended to clone this repo and play it locally by creating a venv using `requirements.txt`. Note: using PyTorch version >=2.0.0 for the broadcasting semantics.
+If one wants to play with the neural operator part, it is recommended to clone this repo and play it locally by creating a venv using `requirements.txt`. Note: using PyTorch version >=2.0.0 is recommended for the broadcasting semantics.
 ```bash
 python3.11 -m venv venv
 source venv/bin/activate
@@ -50,7 +51,7 @@ Data generation instructions are available in the [SFNO folder](./fno).
 - [Baseline of FNO3d for fixed step size that requires preloading a normalizer](./examples/ex2_FNO3d_train_normalized.ipynb)
 
 ## Licenses
-The Apache 2.0 License in the root folder applies to the `torch-cfd` folder of the repo that is inherited from Google's original license file for `Jax-cfd`. The `sfno` folder has the MIT license inherited from [NVIDIA's Neural Operator repo](https://github.com/neuraloperator/neuraloperator). Note: the license(s) in the subfolder takes precedence.
+The Apache 2.0 License in the root folder applies to the `torch-cfd` folder of the repo and is inherited from Google's original license file for `Jax-cfd`. The `fno` folder has the MIT license inherited from [NVIDIA's Neural Operator repo](https://github.com/neuraloperator/neuraloperator). Note: the license(s) in the subfolders take precedence.
 
 ## Contributions
 PR welcome. Currently, the port of `torch-cfd` includes:
@@ -77,4 +78,4 @@ If you like to use `torch-cfd` please use the following [paper](https://arxiv.or
 I am grateful for the support from [Long Chen (UC Irvine)](https://github.com/lyc102/ifem) and
 [Ludmil Zikatanov (Penn State)](https://github.com/HAZmathTeam/hazmath) over the years, and their efforts in open-sourcing scientific computing codes. I also appreciate the support from the National Science Foundation (NSF) to junior researchers. I also want to thank the free A6000 credits at the SSE ML cluster from the University of Missouri.
 
-For individual paper's acknowledgement please see [here](./fno/README.md).
+For individual paper's acknowledgment please see [here](./fno/README.md).

fno/README.md

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 This is a new concise implementation of the Fourier Neural Operator; see [`base.py`](./base.py#L172) for a template class.
 
 ## Learning maps between Bochner spaces
-SFNO now can learn a `trajectory-to-trajectory` map that inputs arbitrary-length trajectory, and outputs arbitrary-lengthed trajectory (if length is not specified, then the output length is the same with the input).
+SFNO can now learn a `trajectory-to-trajectory` map that takes an arbitrary-length trajectory as input and outputs an arbitrary-length trajectory (if the output length is not specified, it defaults to the input length). The tests on its trajectory-to-trajectory shapes can be found in [`sfno_pytest.py`](sfno_pytest.py) and [`check_SFNO_shapes.py`](../examples/check_SFNO_shapes.py).
 
 ## Data generation
 
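A minimal sketch of the shape contract those tests verify; the constructor arguments and the output-length keyword below are illustrative assumptions, not the actual `SFNO` signature:

```python
import torch
from fno.sfno import SFNO  # module path as in this repo

# NOTE: hypothetical hyperparameters; consult the SFNO docstring for real ones
model = SFNO(modes=8, modes_t=4, width=16)

v = torch.randn(2, 10, 64, 64)   # (b, T_in, n, n) input trajectory
out = model(v)                   # length unspecified -> matches the input
assert out.shape == v.shape

out = model(v, out_steps=40)     # hypothetical kwarg for a longer output
assert out.shape == (2, 40, 64, 64)
```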
fno/base.py

Lines changed: 61 additions & 7 deletions
@@ -13,7 +13,7 @@
 from copy import deepcopy
 
 from functools import partial
-from typing import List, Union, Tuple
+from typing import List, Tuple, Union
 
 import torch
 import torch.fft as fft
@@ -25,11 +25,33 @@
 conv_dict = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
 
 ACTIVATION_FUNCTIONS = [
-    'CELU', 'ELU', 'GELU', 'GLU', 'Hardtanh', 'Hardshrink', 'Hardsigmoid',
-    'Hardswish', 'LeakyReLU', 'LogSigmoid', 'MultiheadAttention', 'PReLU',
-    'ReLU', 'ReLU6', 'RReLU', 'SELU', 'SiLU', 'Sigmoid', 'SoftPlus',
-    'Softmax', 'Softmax2d', 'Softshrink', 'Softsign', 'Tanh', 'Tanhshrink',
-    'Threshold', 'Mish'
+    "CELU",
+    "ELU",
+    "GELU",
+    "GLU",
+    "Hardtanh",
+    "Hardshrink",
+    "Hardsigmoid",
+    "Hardswish",
+    "LeakyReLU",
+    "LogSigmoid",
+    "MultiheadAttention",
+    "PReLU",
+    "ReLU",
+    "ReLU6",
+    "RReLU",
+    "SELU",
+    "SiLU",
+    "Sigmoid",
+    "SoftPlus",
+    "Softmax",
+    "Softmax2d",
+    "Softshrink",
+    "Softsign",
+    "Tanh",
+    "Tanhshrink",
+    "Threshold",
+    "Mish",
 ]
 
 # Type hint for activation functions
@@ -157,7 +179,7 @@ def complex_matmul(x, w, **kwargs):
         Implement this method in subclass to return complex matmul function
         this is a general implmentation of arbitrary dimension
         (b, c_i, *mesh_dims), (c_i, c_o, *mesh_dims) -> (b, c_o, *mesh_dims)
-        for pure einsum benchmark, ellipsis version runs about 30% slower, 
+        for pure einsum benchmark, ellipsis version runs about 30% slower,
         however, when being implemented in FNO, the performance difference is negligible
         one can implement a more specific einsum for the dimension
         1D: (b, c_i, x), (c_i, c_o, x) -> (b, c_o, x)
@@ -166,6 +188,38 @@ def complex_matmul(x, w, **kwargs):
         """
         return torch.einsum("bi...,io...->bo...", x, w)
 
+    def _set_complex_matmul_nd(self, dim: int = None):
+        """
+        Generate einsum string based on dimension.
+        1D: "bix,iox->box"
+        2D: "bixy,ioxy->boxy"
+        3D: "bixyz,ioxyz->boxyz"
+        4D: "biwxyz, iowxyz->bowxyz"
+
+        Args:
+            dim: The dimension of the data
+
+        Returns:
+            str: The appropriate einsum string
+        """
+        dim = self.dim if dim is None else dim
+        assert dim >= 1
+
+        # Start with the basic components
+        inp = "bi"
+        w = "io"
+        out = "bo"
+
+        # Add dimension-specific characters
+        mesh_dims = "".join([chr(ord("z") + i) for i in range(1 - dim, 1)])
+
+        inp += mesh_dims
+        w += mesh_dims
+        out += mesh_dims
+
+        equation = f"{inp},{w}->{out}"
+        self.complex_matmul = partial(torch.einsum, equation)
+
     @abstractmethod
     def spectral_conv(self, vhat, *fft_mesh_size, **kwargs):
         raise NotImplementedError(
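The einsum-equation generation added above can be sanity-checked in isolation. A standalone sketch of the same string-building recipe (the function name is a local stand-in, not part of the module):

```python
from functools import partial

import torch

def make_complex_matmul(dim: int):
    # Same recipe as _set_complex_matmul_nd: take the `dim` letters ending at
    # "z" (dim=3 -> "xyz") and contract the channel index i against o.
    mesh_dims = "".join(chr(ord("z") + i) for i in range(1 - dim, 1))
    equation = f"bi{mesh_dims},io{mesh_dims}->bo{mesh_dims}"
    return partial(torch.einsum, equation)

# dim=2 yields "biyz,ioyz->boyz"; the letters differ from the docstring's
# "bixy,ioxy->boxy" but the contraction is identical.
matmul_2d = make_complex_matmul(2)
x = torch.randn(4, 3, 8, 8, dtype=torch.cfloat)  # (b, c_i, *mesh)
w = torch.randn(3, 5, 8, 8, dtype=torch.cfloat)  # (c_i, c_o, *mesh)
assert matmul_2d(x, w).shape == (4, 5, 8, 8)     # (b, c_o, *mesh)
```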

fno/data_gen/solvers.py

Lines changed: 54 additions & 89 deletions
@@ -1,81 +1,37 @@
 import math
+from datetime import datetime
 from functools import partial
 from typing import Callable, Tuple, Union
-from datetime import datetime
 
 import torch
 import torch.fft as fft
-import torch.nn as nn
 import torch.nn.functional as F
-from einops import pack, rearrange, repeat
+from einops import repeat
 from torch.linalg import norm
-
 from torch_cfd.equations import *
+
 TQDM_ITERS = 200
 
+# TODO
+# [x] removes dupes with torch_cfd module
 
-def backdiff(x, order:int=3):
+
+def backdiff(x, order: int = 3):
     """
     bdf scheme for x: (b, *, x, y, t)
     """
-    bdf_weights ={
+    if order > 5:
+        raise NotImplementedError("only bdf order <= 5 is implemented")
+    bdf_weights = {
         1: [1, -1],
-        2: [3/2, -2, 0.5],
-        3: [11/6, -3, 3/2, -1/3],
-        4: [25/12, -4, 3, -4/3, 1/4],
-        5: [137/60, -5, 5, -10/3, 5/4, -1/5]
+        2: [3 / 2, -2, 0.5],
+        3: [11 / 6, -3, 3 / 2, -1 / 3],
+        4: [25 / 12, -4, 3, -4 / 3, 1 / 4],
+        5: [137 / 60, -5, 5, -10 / 3, 5 / 4, -1 / 5],
     }
     weights = torch.as_tensor(bdf_weights[order]).to(x.device)
-    x_t = x[...,-(order+1):].flip(-1)*weights
+    x_t = x[..., -(order + 1) :].flip(-1) * weights
     return x_t.sum(-1)
-
-def fft_mesh_2d(n, diam, device=None):
-    kx, ky = [fft.fftfreq(n, d=diam/n) for _ in range(2)]
-    kx, ky = torch.meshgrid([kx, ky], indexing="ij")
-    return kx.to(device), ky.to(device)
-
-def fft_expand_dims(fft_mesh, batch_size):
-    kx, ky = fft_mesh
-    kx, ky = [repeat(z, "x y -> b x y 1", b=batch_size) for z in [kx, ky]]
-    return kx, ky
-
-def spectral_div_2d(vhat, fft_mesh):
-    r"""
-    Computes the 2D divergence in the Fourier basis.
-    TODO: this is a dupe function with torch_cfd module
-    needed cleaning up and some refactoring
-    """
-    uhat, vhat = vhat
-    kx, ky = fft_mesh
-    return 2j * torch.pi * (uhat * kx + vhat * ky)
-
-def spectral_grad_2d(vhat, rfft_mesh):
-    kx, ky = rfft_mesh
-    return 2j * torch.pi * kx * vhat, 2j * torch.pi * ky * vhat
-
-def spectral_laplacian_2d(fft_mesh, device=None):
-    """
-    TODO: this is a dupe function with torch_cfd module
-    """
-    kx, ky = fft_mesh
-    lap = -4 * (torch.pi**2) * (abs(kx) ** 2 + abs(ky) ** 2)
-    lap[..., 0, 0] = 1
-    return lap.to(device)
-
-def get_freq_spacetime(n, n_t=None, delta_t=None, device=None):
-    n_t = n if n_t is None else n_t
-    delta_t = 1 / n_t if delta_t is None else delta_t
-    kx = fft.fftfreq(n, d=1 / n)
-    ky = fft.fftfreq(n, d=1 / n)
-    kt = fft.fftfreq(n_t, d=delta_t)
-    kx, ky, kt = torch.meshgrid([kx, ky, kt], indexing="ij")
-    return kx.to(device), ky.to(device), kt.to(device)
-
-def spectral_laplacian_spacetime(n, n_t=None, device=None):
-    kx, ky, _ = get_freq_spacetime(n, n_t)
-    lap = -4 * (torch.pi**2) * (kx**2 + ky**2)
-    lap[0, 0] = 1
-    return lap.to(device)
 
 
 def interp2d(x, **kwargs):
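The `backdiff` weights above are the standard BDF coefficients: dividing the weighted combination by the step size gives a k-th-order approximation of the time derivative at the last slice. A quick self-contained check on a smooth signal over a uniform time grid (a sketch, independent of this module):

```python
import torch

dt = 1e-3
t = torch.arange(0, 1, dt)
x = torch.sin(2 * torch.pi * t)  # time is the trailing axis, as in backdiff

# BDF3: x'(t_n) ~ (11/6 x_n - 3 x_{n-1} + 3/2 x_{n-2} - 1/3 x_{n-3}) / dt
weights = torch.tensor([11 / 6, -3, 3 / 2, -1 / 3])
dxdt = (x[..., -4:].flip(-1) * weights).sum(-1) / dt

exact = 2 * torch.pi * torch.cos(2 * torch.pi * t[-1])
assert torch.isclose(dxdt, exact, rtol=1e-3)
```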
@@ -89,9 +45,17 @@ def interp2d(x, **kwargs):
         x = x.unsqueeze(0)
     return F.interpolate(x, **kwargs).squeeze()
 
+
 def update_residual(
-    w_h, w_h_t, f_h, visc, rfftmesh, laplacian,
-    dealias_filter=None, dealias=True, **kwargs
+    w_h,
+    w_h_t,
+    f_h,
+    visc,
+    rfftmesh,
+    laplacian,
+    dealias_filter=None,
+    dealias=True,
+    **kwargs,
 ):
     """
     compute the residual of an input w in the frequency domain
@@ -136,7 +100,7 @@ def imex_crank_nicolson_step(
     dealias: bool = False,
     output_rfft: bool = False,
     debug=False,
-    **kwargs
+    **kwargs,
 ):
     """
     inputs:
@@ -160,7 +124,6 @@
     n = size[-2]
     k_max = math.floor(n / 2.0)
 
-
     if rfftmesh is None:
         kx = fft.fftfreq(n, d=diam / n)
         ky = fft.fftfreq(n, d=diam / n)
@@ -179,17 +142,18 @@
     laplacian[..., 0, 0] = 1.0
 
     if dealias_filter is None:
-        dealias_filter = (
-            torch.logical_and(
-                torch.abs(ky) <= (2.0 / 3.0) * k_max,
-                torch.abs(kx) <= (2.0 / 3.0) * k_max,
-            ))
+        dealias_filter = torch.logical_and(
+            torch.abs(ky) <= (2.0 / 3.0) * k_max,
+            torch.abs(kx) <= (2.0 / 3.0) * k_max,
+        )
 
     if f.ndim < w.ndim:
         f = f.unsqueeze(0)
 
     # Stream function in Fourier space: solve Poisson equation
-    psi_h = -w / laplacian  # valid for w: (b, *, n, n//2+1, n_t) and lap: (n, n//2+1, n_t)
+    psi_h = (
+        -w / laplacian
+    )  # valid for w: (b, *, n, n//2+1, n_t) and lap: (n, n//2+1, n_t)
 
     # Velocity field in x-direction = psi_y
     u = 2 * math.pi * ky * 1j * psi_h
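Two numerical ingredients appear in the hunk above: the 2/3-rule dealiasing mask (keep only modes with $|k| \le \frac{2}{3}k_{\max}$ before evaluating the quadratic nonlinearity) and the spectral Poisson solve for the stream function, $-\hat{\Delta}\hat{\psi} = \hat{w}$. A minimal sketch of both under the same Fourier-mesh conventions (the names are local stand-ins):

```python
import math

import torch
import torch.fft as fft

n, diam = 64, 1.0
k_max = math.floor(n / 2.0)
kx = fft.fftfreq(n, d=diam / n)
ky = fft.fftfreq(n, d=diam / n)
kx, ky = torch.meshgrid(kx, ky, indexing="ij")

# 2/3 rule: True on the modes kept when forming the nonlinear term
dealias_filter = (ky.abs() <= (2.0 / 3.0) * k_max) & (kx.abs() <= (2.0 / 3.0) * k_max)

# Stream function: -lap(psi) = w, i.e. psi_h = -w_h / lap_h in Fourier space
lap = -4 * (torch.pi**2) * (kx**2 + ky**2)
lap[0, 0] = 1.0                       # avoid dividing by zero at the mean mode
w_h = fft.fft2(torch.randn(n, n))
psi_h = -w_h / lap
u_h = 2 * math.pi * ky * 1j * psi_h   # u = d(psi)/dy, as in the hunk above
```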
@@ -222,17 +186,18 @@
         return w_next, dwdt, w, psi_h, res_h, (kx, ky), laplacian, dealias_filter
     else:
         return w_next, dwdt, w, psi_h, res_h
-
+
+
 def get_trajectory_rk4(
     equation: ImplicitExplicitODE,
     w0: Array,
     dt: float,
     num_steps: int = 1,
     record_every_steps: int = 1,
-    pbar=False,
-    pbar_desc="generating trajectories using RK4",
-    require_grad=False,
-    dtype=torch.complex64,
+    pbar: bool = False,
+    pbar_desc: str = "generating trajectories using RK4",
+    require_grad: bool = False,
+    dtype: torch.dtype = torch.complex64,
 ):
     """
     vorticity stacked in the time dimension
@@ -253,14 +218,14 @@ def get_trajectory_rk4(
     w = w0
     n = w0.size(-1)
     tqdm_iters = num_steps if TQDM_ITERS > num_steps else TQDM_ITERS
-    update_iters = num_steps // tqdm_iters
+    update_every_iters = num_steps // tqdm_iters
     with tqdm(total=num_steps, disable=not pbar) as pb:
         for t_step in range(num_steps):
             w, dwdt = equation.forward(w, dt=dt)
             w.requires_grad_(require_grad)
             dwdt.requires_grad_(require_grad)
 
-            if t_step % update_iters == 0:
+            if t_step % update_every_iters == 0:
                 res = equation.residual(w, dwdt)
                 res_ = fft.irfft2(res).real
                 w_ = fft.irfft2(w).real
@@ -275,7 +240,7 @@ def get_trajectory_rk4(
                     + res_desc
                 )
                 pb.set_description(desc)
-            pb.update(update_iters)
+            pb.update(update_every_iters)
 
             if t_step % record_every_steps == 0:
                 _, psi = vorticity_to_velocity(equation.grid, w)
@@ -303,15 +268,15 @@
 def get_trajectory_imex_crank_nicolson(
     w0,
     f,
-    visc=1e-3,
-    T=1,
-    delta_t=1e-3,
-    record_steps=1,
-    diam=1,
-    dealias=True,
-    subsample=1,
-    dtype=None,
-    pbar=True,
+    visc: float = 1e-3,
+    T: float = 1,
+    delta_t: float = 1e-3,
+    record_steps: int = 1,
+    diam: float = 1,
+    dealias: bool = True,
+    subsample: int = 1,
+    dtype: torch.dtype = None,
+    pbar: bool = True,
     **kwargs,
 ):
     """
@@ -482,13 +447,13 @@ def get_trajectory_imex_crank_nicolson(
         t_steps=t_steps,
     )
 
+
 if __name__ == "__main__":
     n = 256
     bsz = 4
-    w = torch.randn(bsz, n, n//2+1).to(torch.complex128)
-    f = torch.randn(n, n//2+1).to(torch.complex128)
+    w = torch.randn(bsz, n, n // 2 + 1).to(torch.complex128)
+    f = torch.randn(n, n // 2 + 1).to(torch.complex128)
     result = imex_crank_nicolson_step(w, f, 1e-3, 1e-3)
     for v in result:
         if isinstance(v, torch.Tensor):
             print(v.shape, v.dtype, v.device)
-
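For reference, the "IMEX Crank-Nicolson" named above treats the viscous term implicitly with Crank-Nicolson while the nonlinear transport and the forcing are explicit. In Fourier space, with $\hat{\Delta} = -4\pi^2|k|^2$ and viscosity $\nu$, one step has the form below; this is a sketch of the standard scheme consistent with the signatures above, and the file's exact handling of dealiasing and forcing may differ.

$$
\hat{w}^{\,n+1} = \frac{\hat{w}^{\,n} + \Delta t\,\big(\widehat{N}(w^{n}) + \hat{f}\big) + \tfrac{\Delta t}{2}\,\nu\,\hat{\Delta}\,\hat{w}^{\,n}}{1 - \tfrac{\Delta t}{2}\,\nu\,\hat{\Delta}},
\qquad
\widehat{N}(w^{n}) = -\widehat{(u^{n}\cdot\nabla)\,w^{n}} \ \text{(dealiased)}.
$$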