Skip to content

Commit d1cc89c

Browse files
committed
Add vorticity to forcing and initial conditions directly
1 parent 8dbb1d5 commit d1cc89c

File tree

4 files changed

+377
-69
lines changed

4 files changed

+377
-69
lines changed

example_Kolmogrov2d_rk4_cn_forced_turbulence.ipynb

Lines changed: 60 additions & 42 deletions
Large diffs are not rendered by default.

torch_cfd/equations.py

Lines changed: 139 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,44 @@
2323
from . import grids
2424
from tqdm.auto import tqdm
2525

26-
TQDM_ITERS = 100
26+
TQDM_ITERS = 500
2727

2828
Array = torch.Tensor
2929
Grid = grids.Grid
3030

3131

32+
def stable_time_step(
33+
dx: float = None,
34+
dt: float = None,
35+
max_velocity: float = 1.0,
36+
max_courant_number: float = 0.5,
37+
viscosity: float = 1e-3,
38+
implicit_diffusion: bool = True,
39+
ndim: int = 2,
40+
) -> float:
41+
"""
42+
Calculate a stable time step satisfying the CFL condition
43+
for the explicit advection term
44+
if the diffusion is explicit, the time step is the smaller
45+
of the advection and diffusion time steps.
46+
47+
Args:
48+
max_velocity: maximum velocity.
49+
max_courant_number: the Courant number used to choose the time step. Smaller
50+
numbers will lead to more stable simulations. Typically this should be in
51+
the range [0.5, 1).
52+
dx: spatial mesh size, can be min(grid.step).
53+
dt: time step.
54+
"""
55+
dt_diffusion = dx
56+
57+
if not implicit_diffusion:
58+
dt_diffusion = dx ** 2 / (viscosity * 2 ** (ndim))
59+
dt_advection = max_courant_number * dx / max_velocity
60+
dt = dt_advection if dt is None else dt
61+
return min(dt_diffusion, dt_advection, dt)
62+
63+
3264
class ImplicitExplicitODE(nn.Module):
3365
"""Describes a set of ODEs with implicit & explicit terms.
3466
@@ -61,6 +93,90 @@ def implicit_solve(
6193
raise NotImplementedError
6294

6395

96+
def backward_forward_euler(
97+
u: torch.Tensor,
98+
dt: float,
99+
equation: ImplicitExplicitODE,
100+
) -> Array:
101+
"""Time stepping via forward and backward Euler methods.
102+
103+
This method is first order accurate.
104+
105+
Args:
106+
equation: equation to solve.
107+
time_step: time step.
108+
109+
Returns:
110+
Function that performs a time step.
111+
"""
112+
F = equation.explicit_terms
113+
G_inv = equation.implicit_solve
114+
115+
g = u + dt * F(u)
116+
u = G_inv(g, dt)
117+
118+
return u
119+
120+
121+
def imex_crank_nicolson(
122+
u: torch.Tensor,
123+
dt: float,
124+
equation: ImplicitExplicitODE,
125+
) -> Array:
126+
"""Time stepping via forward and backward Euler methods.
127+
128+
This method is first order accurate.
129+
130+
Args:
131+
equation: equation to solve.
132+
time_step: time step.
133+
134+
Returns:
135+
Function that performs a time step.
136+
"""
137+
F = equation.explicit_terms
138+
G = equation.implicit_terms
139+
G_inv = equation.implicit_solve
140+
141+
g = u + dt * F(u) + 0.5 * dt * G(u)
142+
u = G_inv(g, 0.5 * dt)
143+
144+
return u
145+
146+
147+
def rk2_crank_nicolson(
148+
u: torch.Tensor,
149+
dt: float,
150+
equation: ImplicitExplicitODE,
151+
) -> Array:
152+
"""Time stepping via Crank-Nicolson and 2nd order Runge-Kutta (Heun).
153+
154+
This method is second order accurate.
155+
156+
Args:
157+
equation: equation to solve.
158+
time_step: time step.
159+
160+
Returns:
161+
Function that performs a time step.
162+
163+
Reference:
164+
Chandler, G. J. & Kerswell, R. R. Invariant recurrent solutions embedded in
165+
a turbulent two-dimensional Kolmogorov flow. J. Fluid Mech. 722, 554–595
166+
(2013). https://doi.org/10.1017/jfm.2013.122 (Section 3)
167+
"""
168+
F = equation.explicit_terms
169+
G = equation.implicit_terms
170+
G_inv = equation.implicit_solve
171+
172+
g = u + 0.5 * dt * G(u)
173+
h = F(u)
174+
u = G_inv(g + dt * h, 0.5 * dt)
175+
h = 0.5 * (F(u) + h)
176+
u = G_inv(g + dt * h, 0.5 * dt)
177+
return u
178+
179+
64180
def low_storage_runge_kutta_crank_nicolson(
65181
u: torch.Tensor,
66182
dt: float,
@@ -143,8 +259,8 @@ def crank_nicolson_rk4(
143259
return low_storage_runge_kutta_crank_nicolson(
144260
u,
145261
dt=dt,
146-
params=params,
147262
equation=equation,
263+
params=params,
148264
)
149265

150266

@@ -247,16 +363,23 @@ def vorticity_to_velocity(
247363
vxhat = two_pi_i * ky * psi_hat
248364
vyhat = -two_pi_i * kx * psi_hat
249365
return vxhat, vyhat
250-
251-
def residual(self,
366+
367+
def residual(
368+
self,
252369
vort_hat: Array,
253370
vort_t_hat: Array,
254371
):
255-
residual = vort_t_hat - self.explicit_terms(vort_hat) - self.viscosity * self.implicit_terms(vort_hat)
372+
residual = (
373+
vort_t_hat
374+
- self.explicit_terms(vort_hat)
375+
- self.viscosity * self.implicit_terms(vort_hat)
376+
)
256377
return residual
257378

258379
def _explicit_terms(self, vort_hat):
259-
vxhat, vyhat = self.vorticity_to_velocity(self.grid, vort_hat, (self.kx, self.ky))
380+
vxhat, vyhat = self.vorticity_to_velocity(
381+
self.grid, vort_hat, (self.kx, self.ky)
382+
)
260383
vx, vy = fft.irfft2(vxhat), fft.irfft2(vyhat)
261384

262385
grad_x_hat = 2j * torch.pi * self.kx * vort_hat
@@ -272,10 +395,14 @@ def _explicit_terms(self, vort_hat):
272395
terms = advection_hat
273396

274397
if self.forcing_fn is not None:
275-
fx, fy = self.forcing_fn(self.grid, (vx, vy))
276-
fx_hat, fy_hat = fft.rfft2(fx.data), fft.rfft2(fy.data)
277-
terms += self.spectral_curl_2d((fx_hat, fy_hat), (self.kx, self.ky))
278-
398+
if not self.forcing_fn.vorticity:
399+
fx, fy = self.forcing_fn(self.grid, (vx, vy))
400+
fx_hat, fy_hat = fft.rfft2(fx.data), fft.rfft2(fy.data)
401+
terms += self.spectral_curl_2d((fx_hat, fy_hat), (self.kx, self.ky))
402+
else:
403+
f = self.forcing_fn(self.grid, vort_hat)
404+
f_hat = fft.rfft2(f.data)
405+
terms += f_hat
279406
return terms
280407

281408
def explicit_terms(self, vort_hat):
@@ -332,7 +459,8 @@ def get_trajectory(
332459
result = {
333460
var_name: torch.stack(var, dim=0)
334461
for var_name, var in zip(
335-
["vorticity", "velocity", "vort_t", "residual"], [w_all, v_all, dwdt_all, res_all]
462+
["vorticity", "velocity", "vort_t", "residual"],
463+
[w_all, v_all, dwdt_all, res_all],
336464
)
337465
}
338466
return result

0 commit comments

Comments
 (0)