2121import torch .nn as nn
2222import torch .fft as fft
2323from . import grids
24+ from tqdm .auto import tqdm
2425
26+ TQDM_ITERS = 100
2527
2628Array = torch .Tensor
2729Grid = grids .Grid
@@ -59,6 +61,93 @@ def implicit_solve(
5961 raise NotImplementedError
6062
6163
64+ def low_storage_runge_kutta_crank_nicolson (
65+ u : torch .Tensor ,
66+ dt : float ,
67+ params : Dict ,
68+ equation : ImplicitExplicitODE ,
69+ ) -> Array :
70+ """
71+ ported from jax functional programming to be tensor2tensor
72+ Time stepping via "low-storage" Runge-Kutta and Crank-Nicolson steps.
73+
74+ These scheme are second order accurate for the implicit terms, but potentially
75+ higher order accurate for the explicit terms. This seems to be a favorable
76+ tradeoff when the explicit terms dominate, e.g., for modeling turbulent
77+ fluids.
78+
79+ Per Canuto: "[these methods] have been widely used for the time-discretization
80+ in applications of spectral methods."
81+
82+ Args:
83+ alphas: alpha coefficients.
84+ betas: beta coefficients.
85+ gammas: gamma coefficients.
86+ equation.F: explicit terms (convection, rhs, drag).
87+ equation.G: implicit terms (diffusion).
88+ equation.implicit_solve: implicit solver, when evaluates at an input (B, n, n), outputs (B, n, n).
89+ dt: time step.
90+
91+ Input: w^{t_i} (B, n, n)
92+ Returns: w^{t_{i+1}} (B, n, n)
93+
94+ Reference:
95+ Canuto, C., Yousuff Hussaini, M., Quarteroni, A. & Zang, T. A.
96+ Spectral Methods: Evolution to Complex Geometries and Applications to
97+ Fluid Dynamics. (Springer Berlin Heidelberg, 2007).
98+ https://doi.org/10.1007/978-3-540-30728-0 (Appendix D.3)
99+ """
100+ dt = dt
101+ alphas = params ["alphas" ]
102+ betas = params ["betas" ]
103+ gammas = params ["gammas" ]
104+ F = equation .explicit_terms
105+ G = equation .implicit_terms
106+ G_inv = equation .implicit_solve
107+
108+ if len (alphas ) - 1 != len (betas ) != len (gammas ):
109+ raise ValueError ("number of RK coefficients does not match" )
110+
111+ h = 0
112+ for k in range (len (betas )):
113+ h = F (u ) + betas [k ] * h
114+ mu = 0.5 * dt * (alphas [k + 1 ] - alphas [k ])
115+ u = G_inv (u + gammas [k ] * dt * h + mu * G (u ), mu )
116+ return u
117+
118+
119+ def crank_nicolson_rk4 (
120+ u : Array ,
121+ dt : float ,
122+ equation : ImplicitExplicitODE ,
123+ ) -> Array :
124+ """Time stepping via Crank-Nicolson and RK4 ("Carpenter-Kennedy")."""
125+ params = dict (
126+ alphas = [
127+ 0 ,
128+ 0.1496590219993 ,
129+ 0.3704009573644 ,
130+ 0.6222557631345 ,
131+ 0.9582821306748 ,
132+ 1 ,
133+ ],
134+ betas = [0 , - 0.4178904745 , - 1.192151694643 , - 1.697784692471 , - 1.514183444257 ],
135+ gammas = [
136+ 0.1496590219993 ,
137+ 0.3792103129999 ,
138+ 0.8229550293869 ,
139+ 0.6994504559488 ,
140+ 0.1530572479681 ,
141+ ],
142+ )
143+ return low_storage_runge_kutta_crank_nicolson (
144+ u ,
145+ dt = dt ,
146+ params = params ,
147+ equation = equation ,
148+ )
149+
150+
62151class NavierStokes2D (nn .Module ):
63152 """Breaks the Navier-Stokes equation into implicit and explicit parts.
64153
@@ -80,13 +169,15 @@ def __init__(
80169 drag : float = 0.0 ,
81170 smooth : bool = True ,
82171 forcing_fn : Optional [Callable ] = None ,
172+ solver : Optional [Callable ] = crank_nicolson_rk4 ,
83173 ):
84174 super ().__init__ ()
85175 self .viscosity = viscosity
86176 self .grid = grid
87177 self .drag = drag
88178 self .smooth = smooth
89179 self .forcing_fn = forcing_fn
180+ self .solver = solver
90181 self ._initialize ()
91182
92183 def _initialize (self ):
@@ -172,109 +263,63 @@ def _explicit_terms(self, vort_hat):
172263
173264 return terms
174265
175- def explicit_terms (self ):
176- return lambda vort_hat : self ._explicit_terms (vort_hat )
266+ def explicit_terms (self , vort_hat ):
267+ return self ._explicit_terms (vort_hat )
177268
178- def implicit_terms (self ):
179- return lambda vort_hat : self .linear_term * vort_hat
269+ def implicit_terms (self , vort_hat ):
270+ return self .linear_term * vort_hat
180271
181- def implicit_solve (self , time_step ):
182- return lambda vort_hat : 1 / (1 - time_step * self .linear_term ) * vort_hat
272+ def implicit_solve (self , vort_hat , dt ):
273+ return 1 / (1 - dt * self .linear_term ) * vort_hat
183274
184- def step (self , time_step ):
275+ def get_trajectory (
276+ self ,
277+ w0 : Array ,
278+ dt : float ,
279+ time_steps : int ,
280+ record_every_steps = 1 ,
281+ pbar = False ,
282+ pbar_desc = "" ,
283+ require_grad = False ,
284+ ):
185285 """
186- this is for tests
286+ vorticity stacked in the time dimension
187287 """
188- return lambda w : self .explicit_terms ()(w ) + self .implicit_solve (
189- time_step = time_step
190- )(w )
288+ w_all = []
289+ v_all = []
290+ dwdt_all = []
291+ w = w0
292+ update_iters = time_steps // TQDM_ITERS
293+ with tqdm (total = time_steps ) as pbar :
294+ for t in range (time_steps ):
295+ w , dwdt = self .forward (w , dt = dt )
296+ w .requires_grad_ (require_grad )
297+ dwdt .requires_grad_ (require_grad )
298+
299+ if t % update_iters == 0 :
300+ pbar .set_description (pbar_desc )
301+ pbar .update (update_iters )
302+
303+ if t % record_every_steps == 0 :
304+ w_ = w .detach ().clone ()
305+ dwdt_ = dwdt .detach ().clone ()
306+ v = self .vorticity_to_velocity (self .grid , w_ )
307+ v = torch .stack (v , dim = 0 )
308+ w_all .append (w_ )
309+ v_all .append (v )
310+ dwdt_all .append (dwdt_ )
311+ result = {
312+ var_name : torch .stack (var , dim = 0 )
313+ for var_name , var in zip (
314+ ["vorticity" , "velocity" , "vort_t" ], [w_all , v_all , dwdt_all ]
315+ )
316+ }
317+ return result
318+
319+ def step (self , * args , ** kwargs ):
320+ return self .forward (* args , ** kwargs )
191321
192322 def forward (self , vort_hat , dt ):
193- return crank_nicolson_rk4 (self , vort_hat , dt )
194-
195-
196- def low_storage_runge_kutta_crank_nicolson (
197- u : torch .Tensor ,
198- params : Dict ,
199- equation : ImplicitExplicitODE ,
200- time_step : float ,
201- ) -> Array :
202- """
203- ported from jax functional programming to be tensor2tensor
204- Time stepping via "low-storage" Runge-Kutta and Crank-Nicolson steps.
205-
206- These scheme are second order accurate for the implicit terms, but potentially
207- higher order accurate for the explicit terms. This seems to be a favorable
208- tradeoff when the explicit terms dominate, e.g., for modeling turbulent
209- fluids.
210-
211- Per Canuto: "[these methods] have been widely used for the time-discretization
212- in applications of spectral methods."
213-
214- Args:
215- alphas: alpha coefficients.
216- betas: beta coefficients.
217- gammas: gamma coefficients.
218- equation.F: explicit terms (convection, rhs, drag).
219- equation.G: implicit terms (diffusion).
220- equation.implicit_solve: implicit solver, when evaluates at an input (B, n, n), outputs (B, n, n).
221- time_step: time step.
222-
223- Input: w^{t_i} (B, n, n)
224- Returns: w^{t_{i+1}} (B, n, n)
225-
226- Reference:
227- Canuto, C., Yousuff Hussaini, M., Quarteroni, A. & Zang, T. A.
228- Spectral Methods: Evolution to Complex Geometries and Applications to
229- Fluid Dynamics. (Springer Berlin Heidelberg, 2007).
230- https://doi.org/10.1007/978-3-540-30728-0 (Appendix D.3)
231- """
232- dt = time_step
233- alphas = params ["alphas" ]
234- betas = params ["betas" ]
235- gammas = params ["gammas" ]
236- F = equation .explicit_terms ()
237- G = equation .implicit_terms ()
238- G_inv = equation .implicit_solve
239-
240- if len (alphas ) - 1 != len (betas ) != len (gammas ):
241- raise ValueError ("number of RK coefficients does not match" )
242-
243- h = 0
244- for k in range (len (betas )):
245- h = F (u ) + betas [k ] * h
246- mu = 0.5 * dt * (alphas [k + 1 ] - alphas [k ])
247- u = G_inv (mu )(u + gammas [k ] * dt * h + mu * G (u ))
248- return u
249-
250-
251- def crank_nicolson_rk4 (
252- equation : ImplicitExplicitODE ,
253- u : Array ,
254- time_step : float ,
255- ) -> Array :
256- """Time stepping via Crank-Nicolson and RK4 ("Carpenter-Kennedy")."""
257- params = dict (
258- alphas = [
259- 0 ,
260- 0.1496590219993 ,
261- 0.3704009573644 ,
262- 0.6222557631345 ,
263- 0.9582821306748 ,
264- 1 ,
265- ],
266- betas = [0 , - 0.4178904745 , - 1.192151694643 , - 1.697784692471 , - 1.514183444257 ],
267- gammas = [
268- 0.1496590219993 ,
269- 0.3792103129999 ,
270- 0.8229550293869 ,
271- 0.6994504559488 ,
272- 0.1530572479681 ,
273- ],
274- )
275- return low_storage_runge_kutta_crank_nicolson (
276- u ,
277- params = params ,
278- equation = equation ,
279- time_step = time_step ,
280- )
323+ vort_hat_new = self .solver (vort_hat , dt , self )
324+ dvortdt_hat = 1 / dt * (vort_hat_new - vort_hat )
325+ return vort_hat_new , dvortdt_hat
0 commit comments