2323from . import grids
2424from tqdm .auto import tqdm
2525
26- TQDM_ITERS = 100
26+ TQDM_ITERS = 500
2727
2828Array = torch .Tensor
2929Grid = grids .Grid
3030
3131
32+ def stable_time_step (
33+ dx : float = None ,
34+ dt : float = None ,
35+ max_velocity : float = 1.0 ,
36+ max_courant_number : float = 0.5 ,
37+ viscosity : float = 1e-3 ,
38+ implicit_diffusion : bool = True ,
39+ ndim : int = 2 ,
40+ ) -> float :
41+ """
42+ Calculate a stable time step satisfying the CFL condition
43+ for the explicit advection term
44+ if the diffusion is explicit, the time step is the smaller
45+ of the advection and diffusion time steps.
46+
47+ Args:
48+ max_velocity: maximum velocity.
49+ max_courant_number: the Courant number used to choose the time step. Smaller
50+ numbers will lead to more stable simulations. Typically this should be in
51+ the range [0.5, 1).
52+ dx: spatial mesh size, can be min(grid.step).
53+ dt: time step.
54+ """
55+ dt_diffusion = dx
56+
57+ if not implicit_diffusion :
58+ dt_diffusion = dx ** 2 / (viscosity * 2 ** (ndim ))
59+ dt_advection = max_courant_number * dx / max_velocity
60+ dt = dt_advection if dt is None else dt
61+ return min (dt_diffusion , dt_advection , dt )
62+
63+
3264class ImplicitExplicitODE (nn .Module ):
3365 """Describes a set of ODEs with implicit & explicit terms.
3466
@@ -61,6 +93,90 @@ def implicit_solve(
6193 raise NotImplementedError
6294
6395
96+ def backward_forward_euler (
97+ u : torch .Tensor ,
98+ dt : float ,
99+ equation : ImplicitExplicitODE ,
100+ ) -> Array :
101+ """Time stepping via forward and backward Euler methods.
102+
103+ This method is first order accurate.
104+
105+ Args:
106+ equation: equation to solve.
107+ time_step: time step.
108+
109+ Returns:
110+ Function that performs a time step.
111+ """
112+ F = equation .explicit_terms
113+ G_inv = equation .implicit_solve
114+
115+ g = u + dt * F (u )
116+ u = G_inv (g , dt )
117+
118+ return u
119+
120+
121+ def imex_crank_nicolson (
122+ u : torch .Tensor ,
123+ dt : float ,
124+ equation : ImplicitExplicitODE ,
125+ ) -> Array :
126+ """Time stepping via forward and backward Euler methods.
127+
128+ This method is first order accurate.
129+
130+ Args:
131+ equation: equation to solve.
132+ time_step: time step.
133+
134+ Returns:
135+ Function that performs a time step.
136+ """
137+ F = equation .explicit_terms
138+ G = equation .implicit_terms
139+ G_inv = equation .implicit_solve
140+
141+ g = u + dt * F (u ) + 0.5 * dt * G (u )
142+ u = G_inv (g , 0.5 * dt )
143+
144+ return u
145+
146+
147+ def rk2_crank_nicolson (
148+ u : torch .Tensor ,
149+ dt : float ,
150+ equation : ImplicitExplicitODE ,
151+ ) -> Array :
152+ """Time stepping via Crank-Nicolson and 2nd order Runge-Kutta (Heun).
153+
154+ This method is second order accurate.
155+
156+ Args:
157+ equation: equation to solve.
158+ time_step: time step.
159+
160+ Returns:
161+ Function that performs a time step.
162+
163+ Reference:
164+ Chandler, G. J. & Kerswell, R. R. Invariant recurrent solutions embedded in
165+ a turbulent two-dimensional Kolmogorov flow. J. Fluid Mech. 722, 554–595
166+ (2013). https://doi.org/10.1017/jfm.2013.122 (Section 3)
167+ """
168+ F = equation .explicit_terms
169+ G = equation .implicit_terms
170+ G_inv = equation .implicit_solve
171+
172+ g = u + 0.5 * dt * G (u )
173+ h = F (u )
174+ u = G_inv (g + dt * h , 0.5 * dt )
175+ h = 0.5 * (F (u ) + h )
176+ u = G_inv (g + dt * h , 0.5 * dt )
177+ return u
178+
179+
64180def low_storage_runge_kutta_crank_nicolson (
65181 u : torch .Tensor ,
66182 dt : float ,
@@ -143,8 +259,8 @@ def crank_nicolson_rk4(
143259 return low_storage_runge_kutta_crank_nicolson (
144260 u ,
145261 dt = dt ,
146- params = params ,
147262 equation = equation ,
263+ params = params ,
148264 )
149265
150266
@@ -247,16 +363,23 @@ def vorticity_to_velocity(
247363 vxhat = two_pi_i * ky * psi_hat
248364 vyhat = - two_pi_i * kx * psi_hat
249365 return vxhat , vyhat
250-
251- def residual (self ,
366+
367+ def residual (
368+ self ,
252369 vort_hat : Array ,
253370 vort_t_hat : Array ,
254371 ):
255- residual = vort_t_hat - self .explicit_terms (vort_hat ) - self .viscosity * self .implicit_terms (vort_hat )
372+ residual = (
373+ vort_t_hat
374+ - self .explicit_terms (vort_hat )
375+ - self .viscosity * self .implicit_terms (vort_hat )
376+ )
256377 return residual
257378
258379 def _explicit_terms (self , vort_hat ):
259- vxhat , vyhat = self .vorticity_to_velocity (self .grid , vort_hat , (self .kx , self .ky ))
380+ vxhat , vyhat = self .vorticity_to_velocity (
381+ self .grid , vort_hat , (self .kx , self .ky )
382+ )
260383 vx , vy = fft .irfft2 (vxhat ), fft .irfft2 (vyhat )
261384
262385 grad_x_hat = 2j * torch .pi * self .kx * vort_hat
@@ -272,10 +395,14 @@ def _explicit_terms(self, vort_hat):
272395 terms = advection_hat
273396
274397 if self .forcing_fn is not None :
275- fx , fy = self .forcing_fn (self .grid , (vx , vy ))
276- fx_hat , fy_hat = fft .rfft2 (fx .data ), fft .rfft2 (fy .data )
277- terms += self .spectral_curl_2d ((fx_hat , fy_hat ), (self .kx , self .ky ))
278-
398+ if not self .forcing_fn .vorticity :
399+ fx , fy = self .forcing_fn (self .grid , (vx , vy ))
400+ fx_hat , fy_hat = fft .rfft2 (fx .data ), fft .rfft2 (fy .data )
401+ terms += self .spectral_curl_2d ((fx_hat , fy_hat ), (self .kx , self .ky ))
402+ else :
403+ f = self .forcing_fn (self .grid , vort_hat )
404+ f_hat = fft .rfft2 (f .data )
405+ terms += f_hat
279406 return terms
280407
281408 def explicit_terms (self , vort_hat ):
@@ -332,7 +459,8 @@ def get_trajectory(
332459 result = {
333460 var_name : torch .stack (var , dim = 0 )
334461 for var_name , var in zip (
335- ["vorticity" , "velocity" , "vort_t" , "residual" ], [w_all , v_all , dwdt_all , res_all ]
462+ ["vorticity" , "velocity" , "vort_t" , "residual" ],
463+ [w_all , v_all , dwdt_all , res_all ],
336464 )
337465 }
338466 return result
0 commit comments