4242import torch
4343from torch_cfd import boundaries , grids
4444
45- ArrayVector = Sequence [torch .Tensor ]
45+ ArrayVector = List [torch .Tensor ]
4646GridVariable = grids .GridVariable
4747GridTensor = grids .GridTensor
4848GridVariableVector = Union [grids .GridVariableVector , Sequence [grids .GridVariable ]]
@@ -159,20 +159,22 @@ def set_laplacian_matrix(
159159 grid : grids .Grid ,
160160 bc : boundaries .BoundaryConditions ,
161161 device : Optional [torch .device ] = None ,
162+ dtype : torch .dtype = torch .float32 ,
162163) -> ArrayVector :
163164 """Initialize the Laplacian operators."""
164165
165166 offset = grid .cell_center
166- return laplacian_matrix_w_boundaries (grid , offset = offset , bc = bc , device = device )
167+ return laplacian_matrix_w_boundaries (grid , offset = offset , bc = bc , device = device , dtype = dtype )
167168
168169
169- def laplacian_matrix (n : int , step : float , sparse : bool = False ) -> torch .Tensor :
170+ def laplacian_matrix (n : int , step : float , sparse : bool = False , dtype = torch . float32 ) -> torch .Tensor :
170171 """
171172 Create 1D Laplacian operator matrix, with periodic BC.
172- modified the scipy.linalg.circulant implementation to native torch
173+ The matrix is a tri-diagonal matrix with [1, -2, 1]/h**2
174+ Modified the scipy.linalg.circulant implementation to native torch
173175 """
174176 if sparse :
175- values = torch .tensor ([1.0 , - 2.0 , 1.0 ]) / step ** 2
177+ values = torch .tensor ([1.0 , - 2.0 , 1.0 ], dtype = dtype ) / step ** 2
176178 idx_row = torch .arange (n ).repeat (3 )
177179 idx_col = torch .cat (
178180 [
@@ -188,33 +190,45 @@ def laplacian_matrix(n: int, step: float, sparse: bool = False) -> torch.Tensor:
188190 )
189191 return torch .sparse_coo_tensor (indices , data , size = (n , n ))
190192 else :
191- column = torch .zeros (n )
193+ column = torch .zeros (n , dtype = dtype )
192194 column [0 ] = - 2 / step ** 2
193195 column [1 ] = column [- 1 ] = 1 / step ** 2
194196 idx = (n - torch .arange (n )[None ].T + torch .arange (n )[None ]) % n
195197 return torch .gather (column [None , ...].expand (n , - 1 ), 1 , idx )
196198
197199
198200def _laplacian_boundary_dirichlet_cell_centered (
199- laplacians : ArrayVector , grid : grids .Grid , axis : int , side : str
201+ laplacians : ArrayVector , grid : grids .Grid , dim : int , side : str
200202) -> None :
201203 """Converts 1d laplacian matrix to satisfy dirichlet homogeneous bc.
202204
203205 laplacians[i] contains a 3 point stencil matrix L that approximates
204206 d^2/dx_i^2.
205207 For detailed documentation on laplacians input type see
206- array_utils.laplacian_matrix.
207- The default return of array_utils.laplacian_matrix makes a matrix for
208- periodic boundary. For dirichlet boundary, the correct equation is
209- L(u_interior) = rhs_interior and BL_boundary = u_fixed_boundary. So
208+ fdm.laplacian_matrix.
209+ The default return of fdm.laplacian_matrix makes a matrix for
210+ periodic boundary. For (homogeneous) dirichlet boundary, the correct equation is
211+ L(u_interior) = rhs_interior
212+ BL_boundary = u_fixed_boundary.
213+ So
210214 laplacian_boundary_dirichlet restricts the matrix L to
211- interior points only.
215+ interior points only.
216+
217+ Denote the node in the 3-pt stencil as
218+ u[ghost], u[boundary], u[interior] = u[0], u[1], u[2].
219+ The original stencil on the boundary is
220+ [1, -2, 1] * [u[0], u[1], u[2]] = u[0] - 2*u[1] + u[2]
221+ In the homogeneous Dirichlet bc case if the offset
222+ is 0.5 away from the wall, the ghost cell value u[0] = -u[1]. So the
223+ 3 point stencil [1 -2 1] * [u[0] u[1] u[2]] = -3 u[1] + u[2].
224+ The original diagonal of Laplacian Lap[0, 0] is -2/h**2, we need -3/h**2,
225+ thus 1/h**2 is subtracted from the diagonal, and the ghost cell dof is set to zero (Lap[0, -1])
212226
213227 This function assumes RHS has cell-centered offset.
214228 Args:
215229 laplacians: list of 1d laplacians
216230 grid: grid object
217- axis : axis along which to impose dirichlet bc.
231+ dim : axis along which to impose dirichlet bc.
218232 side: lower or upper side to assign boundary to.
219233
220234 Returns:
@@ -223,52 +237,50 @@ def _laplacian_boundary_dirichlet_cell_centered(
223237 TODO:
224238 [ ]: this function is not implemented in the original Jax-CFD code.
225239 """
226- # This function assumes homogeneous boundary, in which case if the offset
227- # is 0.5 away from the wall, the ghost cell value u[0] = -u[1]. So the
228- # 3 point stencil [1 -2 1] * [u[0] u[1] u[2]] = -3 u[1] + u[2].
240+
229241 if side == "lower" :
230- laplacians [axis ][0 , 0 ] = laplacians [axis ][0 , 0 ] - 1 / grid .step [axis ] ** 2
242+ laplacians [dim ][0 , 0 ] = laplacians [dim ][0 , 0 ] - 1 / grid .step [dim ] ** 2
231243 else :
232- laplacians [axis ][- 1 , - 1 ] = laplacians [axis ][- 1 , - 1 ] - 1 / grid .step [axis ] ** 2
244+ laplacians [dim ][- 1 , - 1 ] = laplacians [dim ][- 1 , - 1 ] - 1 / grid .step [dim ] ** 2
233245 # deletes corner dependencies on the "looped-around" part.
234246 # this should be done irrespective of which side, since one boundary cannot
235247 # be periodic while the other is.
236- laplacians [axis ][0 , - 1 ] = 0.0
237- laplacians [axis ][- 1 , 0 ] = 0.0
238- return laplacians
248+ laplacians [dim ][0 , - 1 ] = 0.0
249+ laplacians [dim ][- 1 , 0 ] = 0.0
250+ return
239251
240252
241253def _laplacian_boundary_neumann_cell_centered (
242- laplacians : List [ Any ] , grid : grids .Grid , axis : int , side : str
254+ laplacians : ArrayVector , grid : grids .Grid , dim : int , side : str
243255) -> None :
244256 """Converts 1d laplacian matrix to satisfy neumann homogeneous bc.
245257
246258 This function assumes the RHS will have a cell-centered offset.
247259 Neumann boundaries are not defined for edge-aligned offsets elsewhere in the
248- code.
260+ code. For homogeneous Neumann BC (du/dn = 0), the ghost cell should equal the interior cell: u[ghost] = u[1]. The stencil becomes:
261+ [1, -2, 1] * [u[1], u[1], u[2]] = u[1] - 2*u[1] + u[2] = -u[1] + u[2]
262+ The original diagonal of Laplacian Lap[0, 0] is -2/h**2, we need -1/h**2,
263+ thus 1/h**2 is added to the diagonal, and the ghost cell dof is set to zero (Lap[0, -1]).
249264
250265 Args:
251266 laplacians: list of 1d laplacians
252267 grid: grid object
253- axis : axis along which to impose dirichlet bc.
268+ dim : axis along which to impose dirichlet bc.
254269 side: which boundary side to convert to neumann homogeneous bc.
255270
256271 Returns:
257272 updated list of 1d laplacians.
258-
259- TODO
260- [ ]: this function is not implemented in the original Jax-CFD code.
261273 """
262274 if side == "lower" :
263- laplacians [axis ][0 , 0 ] = laplacians [axis ][0 , 0 ] + 1 / grid .step [axis ] ** 2
275+ laplacians [dim ][0 , 0 ] = laplacians [dim ][0 , 0 ] + 1 / grid .step [dim ] ** 2
264276 else :
265- laplacians [axis ][- 1 , - 1 ] = laplacians [axis ][- 1 , - 1 ] + 1 / grid .step [axis ] ** 2
277+ laplacians [dim ][- 1 , - 1 ] = laplacians [dim ][- 1 , - 1 ] + 1 / grid .step [dim ] ** 2
266278 # deletes corner dependencies on the "looped-around" part.
267279 # this should be done irrespective of which side, since one boundary cannot
268280 # be periodic while the other is.
269- laplacians [axis ][0 , - 1 ] = 0.0
270- laplacians [axis ][- 1 , 0 ] = 0.0
271- return laplacians
281+ laplacians [dim ][0 , - 1 ] = 0.0
282+ laplacians [dim ][- 1 , 0 ] = 0.0
283+ return
272284
273285
274286def laplacian_matrix_w_boundaries (
@@ -277,6 +289,7 @@ def laplacian_matrix_w_boundaries(
277289 bc : grids .BoundaryConditions ,
278290 laplacians : Optional [ArrayVector ] = None ,
279291 device : Optional [torch .device ] = None ,
292+ dtype : torch .dtype = torch .float32 ,
280293 sparse : bool = False ,
281294) -> ArrayVector :
282295 """Returns 1d laplacians that satisfy boundary conditions bc on grid.
@@ -323,11 +336,13 @@ def laplacian_matrix_w_boundaries(
323336 raise NotImplementedError (
324337 "edge-aligned Neumann boundaries are not implemented."
325338 )
326- return list (lap .to (device ) for lap in laplacians ) if device else laplacians
339+ return list (lap .to (dtype ). to ( device ) for lap in laplacians )
327340
328341
329342def _linear_along_axis (c : GridVariable , offset : float , dim : int ) -> GridVariable :
330- """Linear interpolation of `c` to `offset` along a single specified `axis`."""
343+ """Linear interpolation of `c` to `offset` along a single specified `axis`.
344+ dim here is >= 0, the negative indexing for batched implementation is handled by grids.shift.
345+ """
331346 offset_delta = offset - c .offset [dim ]
332347
333348 # If offsets are the same, `c` is unchanged.
@@ -383,8 +398,8 @@ def linear(
383398 f"got { c .offset } and { offset } ."
384399 )
385400 interpolated = c
386- for a , o in enumerate (offset ):
387- interpolated = _linear_along_axis (interpolated , offset = o , dim = a )
401+ for dim , o in enumerate (offset ):
402+ interpolated = _linear_along_axis (interpolated , offset = o , dim = dim )
388403 return interpolated
389404
390405
@@ -405,15 +420,15 @@ def gradient_tensor(v):
405420 if not isinstance (v , GridVariable ):
406421 return GridTensor (torch .stack ([gradient_tensor (u ) for u in v ], dim = - 1 ))
407422 grad = []
408- for axis in range (v .grid .ndim ):
409- offset = v .offset [axis ]
423+ for dim in range (- v .grid .ndim , 0 ):
424+ offset = v .offset [dim ]
410425 if offset == 0 :
411- derivative = forward_difference (v , axis )
426+ derivative = forward_difference (v , dim )
412427 elif offset == 1 :
413- derivative = backward_difference (v , axis )
428+ derivative = backward_difference (v , dim )
414429 elif offset == 0.5 :
415430 v_centered = linear (v , v .grid .cell_center )
416- derivative = central_difference (v_centered , axis )
431+ derivative = central_difference (v_centered , dim )
417432 else :
418433 raise ValueError (f"expected offset values in {{0, 0.5, 1}}, got { offset } " )
419434 grad .append (derivative )
@@ -427,4 +442,4 @@ def curl_2d(v: GridVariableVector) -> GridVariable:
427442 grid = grids .consistent_grid_arrays (* v )
428443 if grid .ndim != 2 :
429444 raise ValueError (f"Grid dimensionality is not 2: { grid .ndim } " )
430- return forward_difference (v [1 ], dim = 0 ) - forward_difference (v [0 ], dim = 1 )
445+ return forward_difference (v [1 ], dim = - 2 ) - forward_difference (v [0 ], dim = - 1 )
0 commit comments