Skip to content

Commit e898a2a

Browse files
committed
refactor RK4 stepper to nn.Module with params
1 parent dccb7b7 commit e898a2a

File tree

3 files changed

+66
-62
lines changed

3 files changed

+66
-62
lines changed

fno/data_gen/data_gen_Kolmogorov2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def main(args):
120120
drag=0.1,
121121
smooth=True,
122122
forcing_fn=forcing_fn,
123-
solver=rk4_crank_nicolson,
123+
solver=RK4CrankNicholson,
124124
).to(device)
125125

126126
num_batches = total_samples // batch_size

fno/data_gen/data_gen_McWilliams2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def main(args):
117117
drag=0,
118118
smooth=True,
119119
forcing_fn=None,
120-
solver=rk4_crank_nicolson,
120+
solver=RK4CrankNicholson,
121121
).to(device)
122122

123123
num_batches = total_samples // batch_size

torch_cfd/equations.py

Lines changed: 64 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
Array = torch.Tensor
2929
Grid = grids.Grid
30+
Params = Union[nn.ParameterDict, Dict]
3031

3132

3233
def fft_mesh_2d(n, diam, device=None):
@@ -277,15 +278,11 @@ def rk2_crank_nicolson(
277278
return u
278279

279280

280-
def low_storage_runge_kutta_crank_nicolson(
281-
u: torch.Tensor,
282-
dt: float,
283-
params: Dict,
284-
equation: ImplicitExplicitODE,
285-
) -> Array:
281+
class RK4CrankNicholson(nn.Module):
286282
"""
287-
ported from jax functional programming to be tensor2tensor
283+
RK4CrankNicholson is ported from jax functional programming to follow the standard tensor2tensor format of nn.Module
288284
Time stepping via "low-storage" Runge-Kutta and Crank-Nicolson steps.
285+
https://github.com/google/jax-cfd/blob/main/jax_cfd/spectral/time_stepping.py#L117
289286
290287
These scheme are second order accurate for the implicit terms, but potentially
291288
higher order accurate for the explicit terms. This seems to be a favorable
@@ -304,64 +301,69 @@ def low_storage_runge_kutta_crank_nicolson(
304301
equation.implicit_solve: implicit solver, when evaluates at an input (B, n, n), outputs (B, n, n).
305302
dt: time step.
306303
307-
Input: w^{t_i} (B, n, n)
308-
Returns: w^{t_{i+1}} (B, n, n)
309-
310304
Reference:
311305
Canuto, C., Yousuff Hussaini, M., Quarteroni, A. & Zang, T. A.
312306
Spectral Methods: Evolution to Complex Geometries and Applications to
313307
Fluid Dynamics. (Springer Berlin Heidelberg, 2007).
314308
https://doi.org/10.1007/978-3-540-30728-0 (Appendix D.3)
315309
"""
316-
dt = dt
317-
alphas = params["alphas"]
318-
betas = params["betas"]
319-
gammas = params["gammas"]
320-
F = equation.explicit_terms
321-
G = equation.implicit_terms
322-
G_inv = equation.implicit_solve
323-
324-
if len(alphas) - 1 != len(betas) != len(gammas):
325-
raise ValueError("number of RK coefficients does not match")
326-
327-
h = 0
328-
for k in range(len(betas)):
329-
h = F(u) + betas[k] * h
330-
mu = 0.5 * dt * (alphas[k + 1] - alphas[k])
331-
u = G_inv(u + gammas[k] * dt * h + mu * G(u), mu)
332-
return u
310+
def __init__(self,
311+
requires_grad: bool = False,
312+
*args,
313+
**kwargs):
314+
super().__init__(*args, **kwargs)
315+
self.params = nn.ParameterDict(
316+
{'alphas': nn.Parameter(torch.tensor([
317+
0,
318+
0.1496590219993,
319+
0.3704009573644,
320+
0.6222557631345,
321+
0.9582821306748,
322+
1,
323+
])),
324+
'betas': nn.Parameter(torch.tensor([0, -0.4178904745, -1.192151694643, -1.697784692471, -1.514183444257])),
325+
'gammas': nn.Parameter(torch.tensor([
326+
0.1496590219993,
327+
0.3792103129999,
328+
0.8229550293869,
329+
0.6994504559488,
330+
0.1530572479681,
331+
]))})
332+
if not requires_grad:
333+
for k, v in self.params.items():
334+
v.requires_grad = False
335+
336+
def forward(
337+
self,
338+
u: Array,
339+
dt: float,
340+
equation: ImplicitExplicitODE,
341+
params: Optional[Params] = None,
342+
) -> Array:
343+
"""
344+
Input:
345+
- w^{t_i} (B, n, n)
346+
- dt: time step
347+
- params: RK coefficients optional to override
348+
Returns: w^{t_{i+1}} (B, n, n)
349+
"""
350+
params = self.params if params is None else params
351+
alphas = params["alphas"]
352+
betas = params["betas"]
353+
gammas = params["gammas"]
354+
F = equation.explicit_terms
355+
G = equation.implicit_terms
356+
G_inv = equation.implicit_solve
333357

358+
if len(alphas) - 1 != len(betas) != len(gammas):
359+
raise ValueError("number of RK coefficients does not match")
334360

335-
def rk4_crank_nicolson(
336-
u: Array,
337-
dt: float,
338-
equation: ImplicitExplicitODE,
339-
) -> Array:
340-
"""Time stepping via Crank-Nicolson and RK4 ("Carpenter-Kennedy")."""
341-
params = dict(
342-
alphas=[
343-
0,
344-
0.1496590219993,
345-
0.3704009573644,
346-
0.6222557631345,
347-
0.9582821306748,
348-
1,
349-
],
350-
betas=[0, -0.4178904745, -1.192151694643, -1.697784692471, -1.514183444257],
351-
gammas=[
352-
0.1496590219993,
353-
0.3792103129999,
354-
0.8229550293869,
355-
0.6994504559488,
356-
0.1530572479681,
357-
],
358-
)
359-
return low_storage_runge_kutta_crank_nicolson(
360-
u,
361-
dt=dt,
362-
equation=equation,
363-
params=params,
364-
)
361+
h = 0
362+
for k in range(len(betas)):
363+
h = F(u) + betas[k] * h
364+
mu = 0.5 * dt * (alphas[k + 1] - alphas[k])
365+
u = G_inv(u + gammas[k] * dt * h + mu * G(u), mu)
366+
return u
365367

366368

367369
class NavierStokes2DSpectral(ImplicitExplicitODE):
@@ -385,15 +387,17 @@ def __init__(
385387
drag: float = 0.0,
386388
smooth: bool = True,
387389
forcing_fn: Optional[Callable] = None,
388-
solver: Optional[Callable] = rk4_crank_nicolson,
390+
solver: Optional[Callable] = RK4CrankNicholson,
391+
requires_grad: bool = False,
392+
**solver_kwargs,
389393
):
390394
super().__init__()
391395
self.viscosity = viscosity
392396
self.grid = grid
393397
self.drag = drag
394398
self.smooth = smooth
395399
self.forcing_fn = forcing_fn
396-
self.solver = solver
400+
self.solver = solver(requires_grad=requires_grad, **solver_kwargs)
397401
self._initialize()
398402

399403
def _initialize(self):
@@ -457,7 +461,7 @@ def step(self, *args, **kwargs):
457461
def forward(self, vort_hat, dt, steps=1):
458462
"""
459463
vort_hat: (B, kx, ky) or (n_t, kx, ky) or (kx, ky)
460-
- if rfft2 is used then the shape is (*, kx, ky//2+1)
464+
- if rfft2 is used then the shape is (*, nx, ny//2+1)
461465
- if (n_t, kx, ky), then the time step marches in the time
462466
dimension in parallel.
463467
"""

0 commit comments

Comments
 (0)