2727
2828Array = torch .Tensor
2929Grid = grids .Grid
30+ Params = Union [nn .ParameterDict , Dict ]
3031
3132
3233def fft_mesh_2d (n , diam , device = None ):
@@ -277,15 +278,11 @@ def rk2_crank_nicolson(
277278 return u
278279
279280
280- def low_storage_runge_kutta_crank_nicolson (
281- u : torch .Tensor ,
282- dt : float ,
283- params : Dict ,
284- equation : ImplicitExplicitODE ,
285- ) -> Array :
281+ class RK4CrankNicholson (nn .Module ):
286282 """
287- ported from jax functional programming to be tensor2tensor
283+ RK4CrankNicholson is ported from jax functional programming to follow the standard tensor2tensor format of nn.Module
288284 Time stepping via "low-storage" Runge-Kutta and Crank-Nicolson steps.
285+ https://github.com/google/jax-cfd/blob/main/jax_cfd/spectral/time_stepping.py#L117
289286
290287 These scheme are second order accurate for the implicit terms, but potentially
291288 higher order accurate for the explicit terms. This seems to be a favorable
@@ -304,64 +301,69 @@ def low_storage_runge_kutta_crank_nicolson(
304301 equation.implicit_solve: implicit solver, when evaluates at an input (B, n, n), outputs (B, n, n).
305302 dt: time step.
306303
307- Input: w^{t_i} (B, n, n)
308- Returns: w^{t_{i+1}} (B, n, n)
309-
310304 Reference:
311305 Canuto, C., Yousuff Hussaini, M., Quarteroni, A. & Zang, T. A.
312306 Spectral Methods: Evolution to Complex Geometries and Applications to
313307 Fluid Dynamics. (Springer Berlin Heidelberg, 2007).
314308 https://doi.org/10.1007/978-3-540-30728-0 (Appendix D.3)
315309 """
316- dt = dt
317- alphas = params ["alphas" ]
318- betas = params ["betas" ]
319- gammas = params ["gammas" ]
320- F = equation .explicit_terms
321- G = equation .implicit_terms
322- G_inv = equation .implicit_solve
323-
324- if len (alphas ) - 1 != len (betas ) != len (gammas ):
325- raise ValueError ("number of RK coefficients does not match" )
326-
327- h = 0
328- for k in range (len (betas )):
329- h = F (u ) + betas [k ] * h
330- mu = 0.5 * dt * (alphas [k + 1 ] - alphas [k ])
331- u = G_inv (u + gammas [k ] * dt * h + mu * G (u ), mu )
332- return u
310+ def __init__ (self ,
311+ requires_grad : bool = False ,
312+ * args ,
313+ ** kwargs ):
314+ super ().__init__ (* args , ** kwargs )
315+ self .params = nn .ParameterDict (
316+ {'alphas' : nn .Parameter (torch .tensor ([
317+ 0 ,
318+ 0.1496590219993 ,
319+ 0.3704009573644 ,
320+ 0.6222557631345 ,
321+ 0.9582821306748 ,
322+ 1 ,
323+ ])),
324+ 'betas' : nn .Parameter (torch .tensor ([0 , - 0.4178904745 , - 1.192151694643 , - 1.697784692471 , - 1.514183444257 ])),
325+ 'gammas' : nn .Parameter (torch .tensor ([
326+ 0.1496590219993 ,
327+ 0.3792103129999 ,
328+ 0.8229550293869 ,
329+ 0.6994504559488 ,
330+ 0.1530572479681 ,
331+ ]))})
332+ if not requires_grad :
333+ for k , v in self .params .items ():
334+ v .requires_grad = False
335+
336+ def forward (
337+ self ,
338+ u : Array ,
339+ dt : float ,
340+ equation : ImplicitExplicitODE ,
341+ params : Optional [Params ] = None ,
342+ ) -> Array :
343+ """
344+ Input:
345+ - w^{t_i} (B, n, n)
346+ - dt: time step
347+ - params: RK coefficients optional to override
348+ Returns: w^{t_{i+1}} (B, n, n)
349+ """
350+ params = self .params if params is None else params
351+ alphas = params ["alphas" ]
352+ betas = params ["betas" ]
353+ gammas = params ["gammas" ]
354+ F = equation .explicit_terms
355+ G = equation .implicit_terms
356+ G_inv = equation .implicit_solve
333357
358+ if len (alphas ) - 1 != len (betas ) != len (gammas ):
359+ raise ValueError ("number of RK coefficients does not match" )
334360
335- def rk4_crank_nicolson (
336- u : Array ,
337- dt : float ,
338- equation : ImplicitExplicitODE ,
339- ) -> Array :
340- """Time stepping via Crank-Nicolson and RK4 ("Carpenter-Kennedy")."""
341- params = dict (
342- alphas = [
343- 0 ,
344- 0.1496590219993 ,
345- 0.3704009573644 ,
346- 0.6222557631345 ,
347- 0.9582821306748 ,
348- 1 ,
349- ],
350- betas = [0 , - 0.4178904745 , - 1.192151694643 , - 1.697784692471 , - 1.514183444257 ],
351- gammas = [
352- 0.1496590219993 ,
353- 0.3792103129999 ,
354- 0.8229550293869 ,
355- 0.6994504559488 ,
356- 0.1530572479681 ,
357- ],
358- )
359- return low_storage_runge_kutta_crank_nicolson (
360- u ,
361- dt = dt ,
362- equation = equation ,
363- params = params ,
364- )
361+ h = 0
362+ for k in range (len (betas )):
363+ h = F (u ) + betas [k ] * h
364+ mu = 0.5 * dt * (alphas [k + 1 ] - alphas [k ])
365+ u = G_inv (u + gammas [k ] * dt * h + mu * G (u ), mu )
366+ return u
365367
366368
367369class NavierStokes2DSpectral (ImplicitExplicitODE ):
@@ -385,15 +387,17 @@ def __init__(
385387 drag : float = 0.0 ,
386388 smooth : bool = True ,
387389 forcing_fn : Optional [Callable ] = None ,
388- solver : Optional [Callable ] = rk4_crank_nicolson ,
390+ solver : Optional [Callable ] = RK4CrankNicholson ,
391+ requires_grad : bool = False ,
392+ ** solver_kwargs ,
389393 ):
390394 super ().__init__ ()
391395 self .viscosity = viscosity
392396 self .grid = grid
393397 self .drag = drag
394398 self .smooth = smooth
395399 self .forcing_fn = forcing_fn
396- self .solver = solver
400+ self .solver = solver ( requires_grad = requires_grad , ** solver_kwargs )
397401 self ._initialize ()
398402
399403 def _initialize (self ):
@@ -457,7 +461,7 @@ def step(self, *args, **kwargs):
457461 def forward (self , vort_hat , dt , steps = 1 ):
458462 """
459463 vort_hat: (B, kx, ky) or (n_t, kx, ky) or (kx, ky)
460- - if rfft2 is used then the shape is (*, kx, ky //2+1)
464+ - if rfft2 is used then the shape is (*, nx, ny //2+1)
461465 - if (n_t, kx, ky), then the time step marches in the time
462466 dimension in parallel.
463467 """
0 commit comments