1515# Modifications copyright (C) 2024 S.Cao
1616# ported Google's Jax-CFD functional template to PyTorch's tensor ops
1717
18- from typing import Tuple , Optional
18+ from typing import Optional , Tuple
19+
1920import torch
21+ import torch .nn as nn
22+
2023from . import grids
2124
2225Array = torch .Tensor
2326Grid = grids .Grid
2427GridArray = grids .GridArray
2528
26- def kolmogorov_forcing (
27- grid : Grid ,
28- v : Tuple [Array , Array ],
29- scale : float = 1 ,
30- k : int = 2 ,
31- swap_xy : bool = False ,
32- offsets : Optional [Tuple [Tuple [float , ...], ...]] = None ,
33- device : Optional [torch .device ] = None ,
34- ) -> Array :
35- """Returns the Kolmogorov forcing function for turbulence in 2D."""
36- if offsets is None :
37- offsets = grid .cell_faces
38- if grid .device is None and device is not None :
39- grid .device = device
40- if swap_xy :
41- x = grid .mesh (offsets [1 ])[0 ]
42- v = GridArray (scale * torch .sin (k * x ), offsets [1 ], grid )
43- u = GridArray (torch .zeros_like (v .data ), (1 , 1 / 2 ), grid )
44- f = (u , v )
45- else :
46- y = grid .mesh (offsets [0 ])[1 ]
47- u = GridArray (scale * torch .sin (k * y ), offsets [0 ], grid )
48- v = GridArray (torch .zeros_like (u .data ), (1 / 2 , 1 ), grid )
49- f = (u , v )
50- return f
29+
30+ class ForcingFn (nn .Module ):
31+ """
32+ A meta class for forcing functions
33+ """
34+
35+ def __init__ (
36+ self ,
37+ grid : Grid ,
38+ scale : float = 1 ,
39+ k : int = 1 ,
40+ diam : float = 1.0 ,
41+ swap_xy : bool = False ,
42+ offsets : Optional [Tuple [Tuple [float , ...], ...]] = None ,
43+ device : Optional [torch .device ] = None ,
44+ ** kwargs ,
45+ ):
46+ super ().__init__ ()
47+ self .grid = grid
48+ self .scale = scale
49+ self .k = k
50+ self .diam = diam
51+ self .swap_xy = swap_xy
52+ self .offsets = grid .cell_faces if offsets is None else offsets
53+ self .device = grid .device if device is None else device
54+
55+
56+ class KolmogorovForcing (ForcingFn ):
57+ """
58+ The Kolmogorov forcing function used in
59+ Sets up the flow that is used in Kochkov et al. [1].
60+ which is based on Boffetta et al. [2].
61+
62+ Note in the port: this forcing belongs a larger class
63+ of isotropic turbulence. See [3].
64+
65+ References:
66+ [1] Machine learning-accelerated computational fluid dynamics. Dmitrii
67+ Kochkov, Jamie A. Smith, Ayya Alieva, Qing Wang, Michael P. Brenner, Stephan
68+ Hoyer Proceedings of the National Academy of Sciences May 2021, 118 (21)
69+ e2101784118; DOI: 10.1073/pnas.2101784118.
70+ https://doi.org/10.1073/pnas.2101784118
71+
72+ [2] Boffetta, Guido, and Robert E. Ecke. "Two-dimensional turbulence."
73+ Annual review of fluid mechanics 44 (2012): 427-451.
74+ https://doi.org/10.1146/annurev-fluid-120710-101240
75+
76+ [3] McWilliams, J. C. (1984). "The emergence of isolated coherent vortices
77+ in turbulent flow". Journal of Fluid Mechanics, 146, 21-43.
78+ """
79+
80+ def __init__ (
81+ self ,
82+ diam = 2 * torch .pi ,
83+ offsets = ((0 , 0 ), (0 , 0 )),
84+ * args ,
85+ ** kwargs ,
86+ ):
87+ super ().__init__ (
88+ * args ,
89+ diam = diam ,
90+ offsets = offsets ,
91+ ** kwargs ,
92+ )
93+
94+ def forward (
95+ self ,
96+ grid : Optional [Grid ],
97+ velocity : Optional [Tuple [Array , Array ]] = None ,
98+ ) -> Tuple [Array , Array ]:
99+ offsets = self .offsets
100+ grid = self .grid if grid is None else grid
101+ domain_factor = 2 * torch .pi / self .diam
102+
103+ if self .swap_xy :
104+ x = grid .mesh (offsets [1 ])[0 ]
105+ v = GridArray (
106+ self .scale * torch .sin (self .k * domain_factor * x ), offsets [1 ], grid
107+ )
108+ u = GridArray (torch .zeros_like (v .data ), (1 , 1 / 2 ), grid )
109+ f = (u , v )
110+ else :
111+ y = grid .mesh (offsets [0 ])[1 ]
112+ u = GridArray (
113+ self .scale * torch .sin (self .k * domain_factor * y ), offsets [0 ], grid
114+ )
115+ v = GridArray (torch .zeros_like (u .data ), (1 / 2 , 1 ), grid )
116+ f = (u , v )
117+ return f
118+
119+ def potential_template (potential_func ):
120+ def wrapper (cls , x : Array , y : Array , s : float , k : float ) -> Array :
121+ return potential_func (x , y , s , k )
122+ return wrapper
123+
124+
125+ class SimpleSolenoidalForcing (ForcingFn ):
126+ """
127+ A simple solenoidal (rotating, divergence free) forcing function template.
128+ The template forcing is F = (-psi, psi) such that
129+
130+ Args:
131+ grid: grid on which to simulate the flow
132+ scale: a in the equation above, amplitude of the forcing
133+ k: k in the equation above, wavenumber of the forcing
134+ """
135+
136+ def __init__ (
137+ self ,
138+ scale = 1 ,
139+ diam = 1.0 ,
140+ k = 1.0 ,
141+ offsets = ((0 , 0 ), (0 , 0 )),
142+ * args ,
143+ ** kwargs ,
144+ ):
145+ super ().__init__ (
146+ * args ,
147+ scale = scale ,
148+ diam = diam ,
149+ k = k ,
150+ offsets = offsets ,
151+ ** kwargs ,
152+ )
153+
154+
155+ @potential_template
156+ def potential (* args , ** kwargs ) -> Array :
157+ raise NotImplementedError
158+
159+ def forward (
160+ self ,
161+ grid : Optional [Grid ],
162+ velocity : Optional [Tuple [Array , Array ]] = None ,
163+ ) -> Tuple [Array , Array ]:
164+ offsets = self .offsets
165+ grid = self .grid if grid is None else grid
166+ domain_factor = 2 * torch .pi / self .diam
167+ k = self .k * domain_factor
168+ scale = 0.5 * self .scale / (2 * torch .pi ) / self .k
169+
170+ if self .swap_xy :
171+ x = grid .mesh (offsets [1 ])[0 ]
172+ y = grid .mesh (offsets [0 ])[1 ]
173+ rot = self .potential (x , y , scale , k )
174+ v = GridArray (rot , offsets [1 ], grid )
175+ u = GridArray (- rot , (1 , 1 / 2 ), grid )
176+ f = (u , v )
177+ else :
178+ x = grid .mesh (offsets [0 ])[0 ]
179+ y = grid .mesh (offsets [1 ])[1 ]
180+ rot = self .potential (x , y , scale , k )
181+ u = GridArray (rot , offsets [0 ], grid )
182+ v = GridArray (- rot , (1 / 2 , 1 ), grid )
183+ f = (u , v )
184+ return f
185+
186+
187+ class SinCosForcing (SimpleSolenoidalForcing ):
188+ """
189+ The solenoidal (divergence free) forcing function used in [4].
190+
191+ Note: in the vorticity-streamfunction formulation, the forcing
192+ is actually the curl of the velocity field, which
193+ is a*(sin(2*pi*k*(x+y)) + cos(2*pi*k*(x+y)))
194+ a=0.1, k=1 in [4]
195+
196+ References:
197+ [4] Li, Zongyi, et al. "Fourier Neural Operator for
198+ Parametric Partial Differential Equations."
199+ ICLR. 2020.
200+
201+ Args:
202+ grid: grid on which to simulate the flow
203+ scale: a in the equation above, amplitude of the forcing
204+ k: k in the equation above, wavenumber of the forcing
205+ """
206+
207+ def __init__ (
208+ self ,
209+ scale = 0.1 ,
210+ diam = 1.0 ,
211+ k = 1.0 ,
212+ offsets = ((0 , 0 ), (0 , 0 )),
213+ * args ,
214+ ** kwargs ,
215+ ):
216+ super ().__init__ (
217+ * args ,
218+ scale = scale ,
219+ diam = diam ,
220+ k = k ,
221+ offsets = offsets ,
222+ ** kwargs ,
223+ )
224+
225+ @potential_template
226+ def potential (x : Array , y : Array , s : float , k : float ) -> Array :
227+ return s * (torch .sin (k * (x + y )) - torch .cos (k * (x + y )))
0 commit comments