Skip to content

Commit 8dbb1d5

Browse files
committed
Changed forcing to an nn.Module template class
1 parent 20608ac commit 8dbb1d5

File tree

5 files changed

+268
-64
lines changed

5 files changed

+268
-64
lines changed

example_Kolmogrov2d_rk4_cn_forced_turbulence.ipynb

Lines changed: 21 additions & 17 deletions
Large diffs are not rendered by default.

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
numpy==1.24.4
2-
torch>=2.0.1
2+
torch>=2.2.0
33
xarray>=2023.1.0
4+
tqdm>=4.62.0

torch_cfd/equations.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# Modifications copyright (C) 2024 S.Cao
1616
# ported Google's Jax-CFD functional template to PyTorch's tensor ops
1717

18-
from typing import Callable, Dict, Optional
18+
from typing import Tuple, Callable, Dict, Optional
1919

2020
import torch
2121
import torch.nn as nn
@@ -148,7 +148,7 @@ def crank_nicolson_rk4(
148148
)
149149

150150

151-
class NavierStokes2D(nn.Module):
151+
class NavierStokes2DSpectral(nn.Module):
152152
"""Breaks the Navier-Stokes equation into implicit and explicit parts.
153153
154154
Implicit parts are the linear terms and explicit parts are the non-linear
@@ -201,14 +201,24 @@ def brick_wall_filter_2d(grid: Grid):
201201
return filter_
202202

203203
@staticmethod
204-
def spectral_curl_2d(mesh, velocity_hat):
205-
"""Computes the 2D curl in the Fourier basis."""
206-
kx, ky = mesh
207-
uhat, vhat = velocity_hat
204+
def spectral_curl_2d(vhat, rfft_mesh):
205+
r"""
206+
Computes the 2D curl in the Fourier basis.
207+
det [d_x d_y \\ u v]
208+
"""
209+
uhat, vhat = vhat
210+
kx, ky = rfft_mesh
208211
return 2j * torch.pi * (vhat * kx - uhat * ky)
209212

210213
@staticmethod
211-
def vorticity_to_velocity(grid: Grid, w_hat: Array):
214+
def spectral_grad_2d(vhat, rfft_mesh):
215+
kx, ky = rfft_mesh
216+
return 2j * torch.pi * kx * vhat, 2j * torch.pi * ky * vhat
217+
218+
@staticmethod
219+
def vorticity_to_velocity(
220+
grid: Grid, w_hat: Array, rfft_mesh: Optional[Tuple[Array, Array]] = None
221+
):
212222
"""Constructs a function for converting vorticity to velocity, both in Fourier domain.
213223
214224
Solves for the stream function and then uses the stream function to compute
@@ -229,19 +239,24 @@ def vorticity_to_velocity(grid: Grid, w_hat: Array):
229239
Pages 509-520, ISSN 0045-7930,
230240
https://doi.org/10.1016/j.compfluid.2003.06.003.
231241
"""
232-
device = w_hat.device
233-
kx, ky = grid.rfft_mesh()
234-
kx, ky = kx.to(device), ky.to(device)
242+
kx, ky = rfft_mesh if rfft_mesh is not None else grid.rfft_mesh()
235243
two_pi_i = 2 * torch.pi * 1j
236244
laplace = two_pi_i**2 * (abs(kx) ** 2 + abs(ky) ** 2)
237245
laplace[0, 0] = 1
238246
psi_hat = -1 / laplace * w_hat
239247
vxhat = two_pi_i * ky * psi_hat
240248
vyhat = -two_pi_i * kx * psi_hat
241249
return vxhat, vyhat
250+
251+
def residual(self,
252+
vort_hat: Array,
253+
vort_t_hat: Array,
254+
):
255+
residual = vort_t_hat - self.explicit_terms(vort_hat) - self.viscosity * self.implicit_terms(vort_hat)
256+
return residual
242257

243258
def _explicit_terms(self, vort_hat):
244-
vxhat, vyhat = self.vorticity_to_velocity(self.grid, vort_hat)
259+
vxhat, vyhat = self.vorticity_to_velocity(self.grid, vort_hat, (self.kx, self.ky))
245260
vx, vy = fft.irfft2(vxhat), fft.irfft2(vyhat)
246261

247262
grad_x_hat = 2j * torch.pi * self.kx * vort_hat
@@ -251,15 +266,15 @@ def _explicit_terms(self, vort_hat):
251266
advection = -(grad_x * vx + grad_y * vy)
252267
advection_hat = fft.rfft2(advection)
253268

254-
if self.smooth is not None:
269+
if self.smooth:
255270
advection_hat *= self.filter
256271

257272
terms = advection_hat
258273

259274
if self.forcing_fn is not None:
260275
fx, fy = self.forcing_fn(self.grid, (vx, vy))
261276
fx_hat, fy_hat = fft.rfft2(fx.data), fft.rfft2(fy.data)
262-
terms += self.spectral_curl_2d((self.kx, self.ky), (fx_hat, fy_hat))
277+
terms += self.spectral_curl_2d((fx_hat, fy_hat), (self.kx, self.ky))
263278

264279
return terms
265280

@@ -276,7 +291,7 @@ def get_trajectory(
276291
self,
277292
w0: Array,
278293
dt: float,
279-
time_steps: int,
294+
T: float,
280295
record_every_steps=1,
281296
pbar=False,
282297
pbar_desc="",
@@ -288,7 +303,9 @@ def get_trajectory(
288303
w_all = []
289304
v_all = []
290305
dwdt_all = []
306+
res_all = []
291307
w = w0
308+
time_steps = int(T / dt)
292309
update_iters = time_steps // TQDM_ITERS
293310
with tqdm(total=time_steps) as pbar:
294311
for t in range(time_steps):
@@ -304,14 +321,18 @@ def get_trajectory(
304321
w_ = w.detach().clone()
305322
dwdt_ = dwdt.detach().clone()
306323
v = self.vorticity_to_velocity(self.grid, w_)
324+
res = self.residual(w_, dwdt_)
325+
307326
v = torch.stack(v, dim=0)
308327
w_all.append(w_)
309328
v_all.append(v)
310329
dwdt_all.append(dwdt_)
330+
res_all.append(res)
331+
311332
result = {
312333
var_name: torch.stack(var, dim=0)
313334
for var_name, var in zip(
314-
["vorticity", "velocity", "vort_t"], [w_all, v_all, dwdt_all]
335+
["vorticity", "velocity", "vort_t", "residual"], [w_all, v_all, dwdt_all, res_all]
315336
)
316337
}
317338
return result

torch_cfd/forcings.py

Lines changed: 203 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,36 +15,213 @@
1515
# Modifications copyright (C) 2024 S.Cao
1616
# ported Google's Jax-CFD functional template to PyTorch's tensor ops
1717

18-
from typing import Tuple, Optional
18+
from typing import Optional, Tuple
19+
1920
import torch
21+
import torch.nn as nn
22+
2023
from . import grids
2124

2225
Array = torch.Tensor
2326
Grid = grids.Grid
2427
GridArray = grids.GridArray
2528

26-
def kolmogorov_forcing(
27-
grid: Grid,
28-
v: Tuple[Array, Array],
29-
scale: float = 1,
30-
k: int = 2,
31-
swap_xy: bool = False,
32-
offsets: Optional[Tuple[Tuple[float, ...], ...]] = None,
33-
device: Optional[torch.device] = None,
34-
) -> Array:
35-
"""Returns the Kolmogorov forcing function for turbulence in 2D."""
36-
if offsets is None:
37-
offsets = grid.cell_faces
38-
if grid.device is None and device is not None:
39-
grid.device = device
40-
if swap_xy:
41-
x = grid.mesh(offsets[1])[0]
42-
v = GridArray(scale * torch.sin(k * x), offsets[1], grid)
43-
u = GridArray(torch.zeros_like(v.data), (1, 1 / 2), grid)
44-
f = (u, v)
45-
else:
46-
y = grid.mesh(offsets[0])[1]
47-
u = GridArray(scale * torch.sin(k * y), offsets[0], grid)
48-
v = GridArray(torch.zeros_like(u.data), (1 / 2, 1), grid)
49-
f = (u, v)
50-
return f
29+
30+
class ForcingFn(nn.Module):
31+
"""
32+
A meta class for forcing functions
33+
"""
34+
35+
def __init__(
36+
self,
37+
grid: Grid,
38+
scale: float = 1,
39+
k: int = 1,
40+
diam: float = 1.0,
41+
swap_xy: bool = False,
42+
offsets: Optional[Tuple[Tuple[float, ...], ...]] = None,
43+
device: Optional[torch.device] = None,
44+
**kwargs,
45+
):
46+
super().__init__()
47+
self.grid = grid
48+
self.scale = scale
49+
self.k = k
50+
self.diam = diam
51+
self.swap_xy = swap_xy
52+
self.offsets = grid.cell_faces if offsets is None else offsets
53+
self.device = grid.device if device is None else device
54+
55+
56+
class KolmogorovForcing(ForcingFn):
57+
"""
58+
The Kolmogorov forcing function used in
59+
Sets up the flow that is used in Kochkov et al. [1].
60+
which is based on Boffetta et al. [2].
61+
62+
Note in the port: this forcing belongs a larger class
63+
of isotropic turbulence. See [3].
64+
65+
References:
66+
[1] Machine learning-accelerated computational fluid dynamics. Dmitrii
67+
Kochkov, Jamie A. Smith, Ayya Alieva, Qing Wang, Michael P. Brenner, Stephan
68+
Hoyer Proceedings of the National Academy of Sciences May 2021, 118 (21)
69+
e2101784118; DOI: 10.1073/pnas.2101784118.
70+
https://doi.org/10.1073/pnas.2101784118
71+
72+
[2] Boffetta, Guido, and Robert E. Ecke. "Two-dimensional turbulence."
73+
Annual review of fluid mechanics 44 (2012): 427-451.
74+
https://doi.org/10.1146/annurev-fluid-120710-101240
75+
76+
[3] McWilliams, J. C. (1984). "The emergence of isolated coherent vortices
77+
in turbulent flow". Journal of Fluid Mechanics, 146, 21-43.
78+
"""
79+
80+
def __init__(
81+
self,
82+
diam=2 * torch.pi,
83+
offsets=((0, 0), (0, 0)),
84+
*args,
85+
**kwargs,
86+
):
87+
super().__init__(
88+
*args,
89+
diam=diam,
90+
offsets=offsets,
91+
**kwargs,
92+
)
93+
94+
def forward(
95+
self,
96+
grid: Optional[Grid],
97+
velocity: Optional[Tuple[Array, Array]] = None,
98+
) -> Tuple[Array, Array]:
99+
offsets = self.offsets
100+
grid = self.grid if grid is None else grid
101+
domain_factor = 2 * torch.pi / self.diam
102+
103+
if self.swap_xy:
104+
x = grid.mesh(offsets[1])[0]
105+
v = GridArray(
106+
self.scale * torch.sin(self.k * domain_factor * x), offsets[1], grid
107+
)
108+
u = GridArray(torch.zeros_like(v.data), (1, 1 / 2), grid)
109+
f = (u, v)
110+
else:
111+
y = grid.mesh(offsets[0])[1]
112+
u = GridArray(
113+
self.scale * torch.sin(self.k * domain_factor * y), offsets[0], grid
114+
)
115+
v = GridArray(torch.zeros_like(u.data), (1 / 2, 1), grid)
116+
f = (u, v)
117+
return f
118+
119+
def potential_template(potential_func):
120+
def wrapper(cls, x: Array, y: Array, s: float, k: float) -> Array:
121+
return potential_func(x, y, s, k)
122+
return wrapper
123+
124+
125+
class SimpleSolenoidalForcing(ForcingFn):
126+
"""
127+
A simple solenoidal (rotating, divergence free) forcing function template.
128+
The template forcing is F = (-psi, psi) such that
129+
130+
Args:
131+
grid: grid on which to simulate the flow
132+
scale: a in the equation above, amplitude of the forcing
133+
k: k in the equation above, wavenumber of the forcing
134+
"""
135+
136+
def __init__(
137+
self,
138+
scale=1,
139+
diam=1.0,
140+
k=1.0,
141+
offsets=((0, 0), (0, 0)),
142+
*args,
143+
**kwargs,
144+
):
145+
super().__init__(
146+
*args,
147+
scale=scale,
148+
diam=diam,
149+
k=k,
150+
offsets=offsets,
151+
**kwargs,
152+
)
153+
154+
155+
@potential_template
156+
def potential(*args, **kwargs) -> Array:
157+
raise NotImplementedError
158+
159+
def forward(
160+
self,
161+
grid: Optional[Grid],
162+
velocity: Optional[Tuple[Array, Array]] = None,
163+
) -> Tuple[Array, Array]:
164+
offsets = self.offsets
165+
grid = self.grid if grid is None else grid
166+
domain_factor = 2 * torch.pi / self.diam
167+
k = self.k * domain_factor
168+
scale = 0.5 * self.scale / (2 * torch.pi) / self.k
169+
170+
if self.swap_xy:
171+
x = grid.mesh(offsets[1])[0]
172+
y = grid.mesh(offsets[0])[1]
173+
rot = self.potential(x, y, scale, k)
174+
v = GridArray(rot, offsets[1], grid)
175+
u = GridArray(-rot, (1, 1 / 2), grid)
176+
f = (u, v)
177+
else:
178+
x = grid.mesh(offsets[0])[0]
179+
y = grid.mesh(offsets[1])[1]
180+
rot = self.potential(x, y, scale, k)
181+
u = GridArray(rot, offsets[0], grid)
182+
v = GridArray(-rot, (1 / 2, 1), grid)
183+
f = (u, v)
184+
return f
185+
186+
187+
class SinCosForcing(SimpleSolenoidalForcing):
188+
"""
189+
The solenoidal (divergence free) forcing function used in [4].
190+
191+
Note: in the vorticity-streamfunction formulation, the forcing
192+
is actually the curl of the velocity field, which
193+
is a*(sin(2*pi*k*(x+y)) + cos(2*pi*k*(x+y)))
194+
a=0.1, k=1 in [4]
195+
196+
References:
197+
[4] Li, Zongyi, et al. "Fourier Neural Operator for
198+
Parametric Partial Differential Equations."
199+
ICLR. 2020.
200+
201+
Args:
202+
grid: grid on which to simulate the flow
203+
scale: a in the equation above, amplitude of the forcing
204+
k: k in the equation above, wavenumber of the forcing
205+
"""
206+
207+
def __init__(
208+
self,
209+
scale=0.1,
210+
diam=1.0,
211+
k=1.0,
212+
offsets=((0, 0), (0, 0)),
213+
*args,
214+
**kwargs,
215+
):
216+
super().__init__(
217+
*args,
218+
scale=scale,
219+
diam=diam,
220+
k=k,
221+
offsets=offsets,
222+
**kwargs,
223+
)
224+
225+
@potential_template
226+
def potential(x: Array, y: Array, s: float, k: float) -> Array:
227+
return s * (torch.sin(k * (x + y)) - torch.cos(k * (x + y)))

0 commit comments

Comments
 (0)