1616# ported Google's Jax-CFD functional template to PyTorch's tensor ops
1717
1818"""Prepare initial conditions for simulations."""
19- from typing import Callable , Optional , Sequence
2019import math
20+ from typing import Callable , Optional , Sequence
21+
2122import torch
2223import torch .fft as fft
23- from . import grids
24- from . import finite_differences as fd
25- from . import fast_diagonalization as solver
24+
25+ from . import grids , pressure
2626
2727Array = torch .Tensor
2828GridArray = grids .GridArray
@@ -45,6 +45,7 @@ def wrap_velocities(
4545 for u , offset , bc in zip (v , grid .cell_faces , bcs )
4646 )
4747
48+
4849def wrap_vorticity (
4950 w : Array ,
5051 grid : grids .Grid ,
@@ -57,7 +58,11 @@ def wrap_vorticity(
5758
5859
5960def _log_normal_density (k , mode : float , variance = 0.25 ):
60- """Unscaled PDF for a log normal given `mode` and log variance 1."""
61+ """
62+ Unscaled PDF for a log normal given `mode` and log variance 1.
63+
64+
65+ """
6166 mean = math .log (mode ) + variance
6267 logk = torch .log (k )
6368 return torch .exp (- ((mean - logk ) ** 2 ) / 2 / variance - logk )
@@ -74,6 +79,7 @@ def McWilliams_density(k, mode: float, tau: float = 1.0):
7479 """
7580 return (k * (tau ** 2 + (k / mode ) ** 4 )) ** (- 1 )
7681
82+
7783def _angular_frequency_magnitude (grid : grids .Grid ) -> Array :
7884 frequencies = [
7985 2 * torch .pi * fft .fftfreq (size , step )
@@ -95,103 +101,19 @@ def spectral_filter(
95101 # real, because our spectral density only depends on norm(k).
96102 return fft .ifftn (fft .fftn (v ) * filters ).real
97103
104+
98105def streamfunc_normalize (k , psi ):
99- # only half the spectrum for real ffts, needs spectral normalisation
100106 nx , ny = psi .shape
101107 psih = fft .fft2 (psi )
102- uh = k * psih
103- kinetic_energy = (2 * uh .abs () ** 2 / (nx * ny ) ** 2 ).sum ()
108+ uh_mag = k * psih
109+ kinetic_energy = (2 * uh_mag .abs () ** 2 / (nx * ny ) ** 2 ).sum ()
104110 return psi / kinetic_energy .sqrt ()
105111
106- def _rhs_transform (
107- u : GridArray ,
108- bc : BoundaryConditions ,
109- ) -> Array :
110- """Transform the RHS of pressure projection equation for stability.
111-
112- In case of poisson equation, the kernel is subtracted from RHS for stability.
113-
114- Args:
115- u: a GridArray that solves ∇²x = u.
116- bc: specifies boundary of x.
117-
118- Returns:
119- u' s.t. u = u' + kernel of the laplacian.
120- """
121- u_data = u .data
122- for axis in range (u .grid .ndim ):
123- if (
124- bc .types [axis ][0 ] == grids .BCType .NEUMANN
125- and bc .types [axis ][1 ] == grids .BCType .NEUMANN
126- ):
127- # if all sides are neumann, poisson solution has a kernel of constant
128- # functions. We substact the mean to ensure consistency.
129- u_data = u_data - torch .mean (u_data )
130- return u_data
131-
132-
133- def solve_fast_diag (
134- v : GridVariableVector ,
135- q0 : Optional [GridVariable ] = None ,
136- pressure_bc : Optional [grids .ConstantBoundaryConditions ] = None ,
137- implementation : Optional [str ] = None ,
138- ) -> GridArray :
139- """Solve for pressure using the fast diagonalization approach."""
140- del q0 # unused
141- if pressure_bc is None :
142- pressure_bc = grids .get_pressure_bc_from_velocity (v )
143- if grids .has_all_periodic_boundary_conditions (* v ):
144- circulant = True
145- else :
146- circulant = False
147- # only matmul implementation supports non-circulant matrices
148- implementation = "matmul"
149- grid = grids .consistent_grid (* v )
150- rhs = fd .divergence (v )
151- laplacians = list (map (fd .laplacian_matrix , grid .shape , grid .step ))
152- laplacians = [lap .to (grid .device ) for lap in laplacians ]
153- rhs_transformed = _rhs_transform (rhs , pressure_bc )
154- pinv = solver .pseudoinverse (
155- rhs_transformed ,
156- laplacians ,
157- rhs_transformed .dtype ,
158- hermitian = True ,
159- circulant = circulant ,
160- implementation = implementation ,
161- )
162- # return applied(pinv)(rhs_transformed)
163- return GridArray (pinv , rhs .offset , rhs .grid )
164-
165-
166- def projection (
167- v : GridVariableVector ,
168- solve : Callable = solve_fast_diag ,
169- ) -> GridVariableVector :
170- """
171- Apply pressure projection (a discrete Helmholtz decomposition)
172- to make a velocity field divergence free.
173-
174- Note by S.Cao: this was originally implemented by the jax-cfd team
175- but using FDM results having a non-negligible error in fp32.
176- One resolution is to use fp64 then cast back to fp32.
177- """
178- grid = grids .consistent_grid (* v )
179- pressure_bc = grids .get_pressure_bc_from_velocity (v )
180-
181- q0 = GridArray (torch .zeros (grid .shape ), grid .cell_center , grid )
182- q0 = pressure_bc .impose_bc (q0 )
183-
184- q = solve (v , q0 , pressure_bc )
185- q = pressure_bc .impose_bc (q )
186- q_grad = fd .forward_difference (q )
187- v_projected = tuple (u .bc .impose_bc (u .array - q_g ) for u , q_g in zip (v , q_grad ))
188- return v_projected
189-
190112
191113def project_and_normalize (
192114 v : GridVariableVector , maximum_velocity : float = 1
193115) -> GridVariableVector :
194- v = projection (v )
116+ v = pressure . projection (v )
195117 vmax = torch .linalg .norm (torch .stack ([u .data for u in v ]), dim = 0 ).max ()
196118 v = tuple (GridVariable (maximum_velocity * u .array / vmax , u .bc ) for u in v )
197119 return v
@@ -256,7 +178,6 @@ def vorticity_field(
256178 Args:
257179 rng_key: key for seeding the random initial vorticity field.
258180 grid: the grid on which the vorticity field is defined.
259- maximum_velocity: the maximum speed in the velocity field.
260181 peak_wavenumber: the velocity field will be filtered so that the largest
261182 magnitudes are associated with this wavenumber.
262183
@@ -277,4 +198,4 @@ def spectral_density(k):
277198 boundary_condition = grids .periodic_boundary_conditions (grid .ndim )
278199 vorticity = wrap_vorticity (vorticity , grid , boundary_condition )
279200
280- return vorticity
201+ return vorticity
0 commit comments