Skip to content

Commit 2f5dea8

Browse files
committed
added batch dimension in solver class
1 parent c5f842b commit 2f5dea8

File tree

4 files changed

+378
-199
lines changed

4 files changed

+378
-199
lines changed

example_Kolmogrov2d_rk4_cn_forced_turbulence.ipynb

Lines changed: 191 additions & 83 deletions
Large diffs are not rendered by default.

torch_cfd/equations.py

Lines changed: 40 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,15 @@
1515
# Modifications copyright (C) 2024 S.Cao
1616
# ported Google's Jax-CFD functional template to PyTorch's tensor ops
1717

18-
from typing import Tuple, Callable, Dict, Optional
18+
from typing import Callable, Dict, Optional, Tuple
1919

2020
import torch
21-
import torch.nn as nn
2221
import torch.fft as fft
23-
from . import grids
22+
import torch.nn as nn
2423
from tqdm.auto import tqdm
2524

25+
from . import grids
26+
2627
TQDM_ITERS = 500
2728

2829
Array = torch.Tensor
@@ -55,7 +56,7 @@ def stable_time_step(
5556
dt_diffusion = dx
5657

5758
if not implicit_diffusion:
58-
dt_diffusion = dx ** 2 / (viscosity * 2 ** (ndim))
59+
dt_diffusion = dx**2 / (viscosity * 2 ** (ndim))
5960
dt_advection = max_courant_number * dx / max_velocity
6061
dt = dt_advection if dt is None else dt
6162
return min(dt_diffusion, dt_advection, dt)
@@ -264,7 +265,7 @@ def crank_nicolson_rk4(
264265
)
265266

266267

267-
class NavierStokes2DSpectral(nn.Module):
268+
class NavierStokes2DSpectral(ImplicitExplicitODE):
268269
"""Breaks the Navier-Stokes equation into implicit and explicit parts.
269270
270271
Implicit parts are the linear terms and explicit parts are the non-linear
@@ -357,12 +358,13 @@ def vorticity_to_velocity(
357358
"""
358359
kx, ky = rfft_mesh if rfft_mesh is not None else grid.rfft_mesh()
359360
two_pi_i = 2 * torch.pi * 1j
361+
assert kx.shape[-2:] == w_hat.shape[-2:]
360362
laplace = two_pi_i**2 * (abs(kx) ** 2 + abs(ky) ** 2)
361-
laplace[0, 0] = 1
363+
laplace[..., 0, 0] = 1
362364
psi_hat = -1 / laplace * w_hat
363-
vxhat = two_pi_i * ky * psi_hat
364-
vyhat = -two_pi_i * kx * psi_hat
365-
return vxhat, vyhat
365+
u_hat = two_pi_i * ky * psi_hat
366+
v_hat = -two_pi_i * kx * psi_hat
367+
return u_hat, v_hat
366368

367369
def residual(
368370
self,
@@ -426,8 +428,15 @@ def get_trajectory(
426428
):
427429
"""
428430
vorticity stacked in the time dimension
431+
all inputs and outputs are in the frequency domain
432+
input: w0 (*, n, n)
433+
output:
434+
435+
vorticity (*, n_t, kx, ky)
436+
velocity: tuple of (*, n_t, kx, ky)
429437
"""
430438
w_all = []
439+
u_all = []
431440
v_all = []
432441
dwdt_all = []
433442
res_all = []
@@ -445,30 +454,40 @@ def get_trajectory(
445454
pbar.update(update_iters)
446455

447456
if t % record_every_steps == 0:
448-
w_ = w.detach().clone()
449-
dwdt_ = dwdt.detach().clone()
450-
v = self.vorticity_to_velocity(self.grid, w_)
451-
res = self.residual(w_, dwdt_)
457+
u, v = self.vorticity_to_velocity(self.grid, w)
458+
res = self.residual(w, dwdt)
459+
460+
w_, dwdt_, u, v, res = [
461+
var.detach().cpu().clone() for var in [w, dwdt, u, v, res]
462+
]
452463

453-
v = torch.stack(v, dim=0)
454464
w_all.append(w_)
465+
u_all.append(u)
455466
v_all.append(v)
456467
dwdt_all.append(dwdt_)
457468
res_all.append(res)
458469

459470
result = {
460-
var_name: torch.stack(var, dim=0)
471+
var_name: torch.stack(var, dim=-3)
461472
for var_name, var in zip(
462-
["vorticity", "velocity", "vort_t", "residual"],
463-
[w_all, v_all, dwdt_all, res_all],
473+
["vorticity", "u", "v", "vort_t", "residual"],
474+
[w_all, u_all, v_all, dwdt_all, res_all],
464475
)
465476
}
466477
return result
467478

468479
def step(self, *args, **kwargs):
469480
return self.forward(*args, **kwargs)
470481

471-
def forward(self, vort_hat, dt):
472-
vort_hat_new = self.solver(vort_hat, dt, self)
473-
dvortdt_hat = 1 / dt * (vort_hat_new - vort_hat)
474-
return vort_hat_new, dvortdt_hat
482+
def forward(self, vort_hat, dt, steps=1):
483+
"""
484+
vort_hat: (B, kx, ky) or (n_t, kx, ky) or (kx, ky)
485+
- if rfft2 is used then the shape is (*, kx, ky//2+1)
486+
- if (n_t, kx, ky), then the time step marches in the time
487+
dimension in parallel.
488+
"""
489+
vort_old = vort_hat
490+
for _ in range(steps):
491+
vort_hat = self.solver(vort_hat, dt, self)
492+
dvortdt_hat = 1 / (steps * dt) * (vort_hat - vort_old)
493+
return vort_hat, dvortdt_hat

torch_cfd/initial_conditions.py

Lines changed: 16 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@
1616
# ported Google's Jax-CFD functional template to PyTorch's tensor ops
1717

1818
"""Prepare initial conditions for simulations."""
19-
from typing import Callable, Optional, Sequence
2019
import math
20+
from typing import Callable, Optional, Sequence
21+
2122
import torch
2223
import torch.fft as fft
23-
from . import grids
24-
from . import finite_differences as fd
25-
from . import fast_diagonalization as solver
24+
25+
from . import grids, pressure
2626

2727
Array = torch.Tensor
2828
GridArray = grids.GridArray
@@ -45,6 +45,7 @@ def wrap_velocities(
4545
for u, offset, bc in zip(v, grid.cell_faces, bcs)
4646
)
4747

48+
4849
def wrap_vorticity(
4950
w: Array,
5051
grid: grids.Grid,
@@ -57,7 +58,11 @@ def wrap_vorticity(
5758

5859

5960
def _log_normal_density(k, mode: float, variance=0.25):
60-
"""Unscaled PDF for a log normal given `mode` and log variance 1."""
61+
"""
62+
Unscaled PDF for a log normal given `mode` and log variance 1.
63+
64+
65+
"""
6166
mean = math.log(mode) + variance
6267
logk = torch.log(k)
6368
return torch.exp(-((mean - logk) ** 2) / 2 / variance - logk)
@@ -74,6 +79,7 @@ def McWilliams_density(k, mode: float, tau: float = 1.0):
7479
"""
7580
return (k * (tau**2 + (k / mode) ** 4)) ** (-1)
7681

82+
7783
def _angular_frequency_magnitude(grid: grids.Grid) -> Array:
7884
frequencies = [
7985
2 * torch.pi * fft.fftfreq(size, step)
@@ -95,103 +101,19 @@ def spectral_filter(
95101
# real, because our spectral density only depends on norm(k).
96102
return fft.ifftn(fft.fftn(v) * filters).real
97103

104+
98105
def streamfunc_normalize(k, psi):
99-
# only half the spectrum for real ffts, needs spectral normalisation
100106
nx, ny = psi.shape
101107
psih = fft.fft2(psi)
102-
uh = k * psih
103-
kinetic_energy = (2 * uh.abs() ** 2 / (nx * ny) ** 2).sum()
108+
uh_mag = k * psih
109+
kinetic_energy = (2 * uh_mag.abs() ** 2 / (nx * ny) ** 2).sum()
104110
return psi / kinetic_energy.sqrt()
105111

106-
def _rhs_transform(
107-
u: GridArray,
108-
bc: BoundaryConditions,
109-
) -> Array:
110-
"""Transform the RHS of pressure projection equation for stability.
111-
112-
In case of poisson equation, the kernel is subtracted from RHS for stability.
113-
114-
Args:
115-
u: a GridArray that solves ∇²x = u.
116-
bc: specifies boundary of x.
117-
118-
Returns:
119-
u' s.t. u = u' + kernel of the laplacian.
120-
"""
121-
u_data = u.data
122-
for axis in range(u.grid.ndim):
123-
if (
124-
bc.types[axis][0] == grids.BCType.NEUMANN
125-
and bc.types[axis][1] == grids.BCType.NEUMANN
126-
):
127-
# if all sides are neumann, poisson solution has a kernel of constant
128-
# functions. We substact the mean to ensure consistency.
129-
u_data = u_data - torch.mean(u_data)
130-
return u_data
131-
132-
133-
def solve_fast_diag(
134-
v: GridVariableVector,
135-
q0: Optional[GridVariable] = None,
136-
pressure_bc: Optional[grids.ConstantBoundaryConditions] = None,
137-
implementation: Optional[str] = None,
138-
) -> GridArray:
139-
"""Solve for pressure using the fast diagonalization approach."""
140-
del q0 # unused
141-
if pressure_bc is None:
142-
pressure_bc = grids.get_pressure_bc_from_velocity(v)
143-
if grids.has_all_periodic_boundary_conditions(*v):
144-
circulant = True
145-
else:
146-
circulant = False
147-
# only matmul implementation supports non-circulant matrices
148-
implementation = "matmul"
149-
grid = grids.consistent_grid(*v)
150-
rhs = fd.divergence(v)
151-
laplacians = list(map(fd.laplacian_matrix, grid.shape, grid.step))
152-
laplacians = [lap.to(grid.device) for lap in laplacians]
153-
rhs_transformed = _rhs_transform(rhs, pressure_bc)
154-
pinv = solver.pseudoinverse(
155-
rhs_transformed,
156-
laplacians,
157-
rhs_transformed.dtype,
158-
hermitian=True,
159-
circulant=circulant,
160-
implementation=implementation,
161-
)
162-
# return applied(pinv)(rhs_transformed)
163-
return GridArray(pinv, rhs.offset, rhs.grid)
164-
165-
166-
def projection(
167-
v: GridVariableVector,
168-
solve: Callable = solve_fast_diag,
169-
) -> GridVariableVector:
170-
"""
171-
Apply pressure projection (a discrete Helmholtz decomposition)
172-
to make a velocity field divergence free.
173-
174-
Note by S.Cao: this was originally implemented by the jax-cfd team
175-
but using FDM results having a non-negligible error in fp32.
176-
One resolution is to use fp64 then cast back to fp32.
177-
"""
178-
grid = grids.consistent_grid(*v)
179-
pressure_bc = grids.get_pressure_bc_from_velocity(v)
180-
181-
q0 = GridArray(torch.zeros(grid.shape), grid.cell_center, grid)
182-
q0 = pressure_bc.impose_bc(q0)
183-
184-
q = solve(v, q0, pressure_bc)
185-
q = pressure_bc.impose_bc(q)
186-
q_grad = fd.forward_difference(q)
187-
v_projected = tuple(u.bc.impose_bc(u.array - q_g) for u, q_g in zip(v, q_grad))
188-
return v_projected
189-
190112

191113
def project_and_normalize(
192114
v: GridVariableVector, maximum_velocity: float = 1
193115
) -> GridVariableVector:
194-
v = projection(v)
116+
v = pressure.projection(v)
195117
vmax = torch.linalg.norm(torch.stack([u.data for u in v]), dim=0).max()
196118
v = tuple(GridVariable(maximum_velocity * u.array / vmax, u.bc) for u in v)
197119
return v
@@ -256,7 +178,6 @@ def vorticity_field(
256178
Args:
257179
rng_key: key for seeding the random initial vorticity field.
258180
grid: the grid on which the vorticity field is defined.
259-
maximum_velocity: the maximum speed in the velocity field.
260181
peak_wavenumber: the velocity field will be filtered so that the largest
261182
magnitudes are associated with this wavenumber.
262183
@@ -277,4 +198,4 @@ def spectral_density(k):
277198
boundary_condition = grids.periodic_boundary_conditions(grid.ndim)
278199
vorticity = wrap_vorticity(vorticity, grid, boundary_condition)
279200

280-
return vorticity
201+
return vorticity

0 commit comments

Comments
 (0)