Skip to content

Commit 3724f9f

Browse files
authored
Merge pull request #6 from scaomath/0.2.4-dev
Added multigrid solvers for pressure projection and tests
2 parents d6968f2 + 4ffc74f commit 3724f9f

File tree

12 files changed

+2061
-596
lines changed

12 files changed

+2061
-596
lines changed

torch_cfd/advection.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,9 @@ def forward(self, cs: GridVariableVector, v: GridVariableVector) -> GridVariable
282282
flux = GridVariableVector(tuple(c * u for c, u in zip(cs, v)))
283283

284284
# wrap flux with boundary conditions to flux if not periodic
285+
# flux = GridVariableVector(
286+
# tuple(bc.impose_bc(f) for f, bc in zip(flux, self.flux_bcs))
287+
# )
285288
flux = GridVariableVector(tuple(GridVariable(f.data, offset, f.grid, bc) for f, offset, bc in zip(flux, self.offsets, self.flux_bcs)))
286289

287290

torch_cfd/boundaries.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ def pad_and_impose_bc(
384384
)
385385
return GridVariable(u.data, u.offset, u.grid, self)
386386

387-
def impose_bc(self, u: GridVariable, mode: str="") -> GridVariable:
387+
def impose_bc(self, u: GridVariable, mode: str = "") -> GridVariable:
388388
"""Returns GridVariable with correct boundary condition.
389389
390390
Some grid points of GridVariable might coincide with boundary. This ensures
@@ -435,12 +435,32 @@ def is_bc_periodic_boundary_conditions(bc: BoundaryConditions, dim: int) -> bool
435435
)
436436
return True
437437

438+
def is_bc_all_periodic_boundary_conditions(bc: BoundaryConditions) -> bool:
439+
"""Returns true if scalar has periodic bc along all axes."""
440+
for dim in range(bc.ndim):
441+
if not is_bc_periodic_boundary_conditions(bc, dim):
442+
return False
443+
return True
444+
438445

439446
def is_periodic_boundary_conditions(c: GridVariable, dim: int) -> bool:
440447
"""Returns true if scalar has periodic bc along axis."""
441448
return is_bc_periodic_boundary_conditions(c.bc, dim)
442449

443450

451+
def is_bc_pure_neumann_boundary_conditions(bc: BoundaryConditions) -> bool:
452+
"""Returns true if scalar has pure Neumann bc along all axes."""
453+
for dim in range(bc.ndim):
454+
if bc.types[dim][0] != BCType.NEUMANN or bc.types[dim][1] != BCType.NEUMANN:
455+
return False
456+
return True
457+
458+
459+
def is_pure_neumann_boundary_conditions(c: GridVariable) -> bool:
460+
"""Returns true if scalar has pure Neumann bc along all axes."""
461+
return is_bc_pure_neumann_boundary_conditions(c.bc)
462+
463+
444464
# Convenience utilities to ease updating of BoundaryConditions implementation
445465
def periodic_boundary_conditions(ndim: int) -> BoundaryConditions:
446466
"""Returns periodic BCs for a variable with `ndim` spatial dimension."""

torch_cfd/finite_differences.py

Lines changed: 57 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
import torch
4343
from torch_cfd import boundaries, grids
4444

45-
ArrayVector = Sequence[torch.Tensor]
45+
ArrayVector = List[torch.Tensor]
4646
GridVariable = grids.GridVariable
4747
GridTensor = grids.GridTensor
4848
GridVariableVector = Union[grids.GridVariableVector, Sequence[grids.GridVariable]]
@@ -159,20 +159,22 @@ def set_laplacian_matrix(
159159
grid: grids.Grid,
160160
bc: boundaries.BoundaryConditions,
161161
device: Optional[torch.device] = None,
162+
dtype: torch.dtype = torch.float32,
162163
) -> ArrayVector:
163164
"""Initialize the Laplacian operators."""
164165

165166
offset = grid.cell_center
166-
return laplacian_matrix_w_boundaries(grid, offset=offset, bc=bc, device=device)
167+
return laplacian_matrix_w_boundaries(grid, offset=offset, bc=bc, device=device, dtype=dtype)
167168

168169

169-
def laplacian_matrix(n: int, step: float, sparse: bool = False) -> torch.Tensor:
170+
def laplacian_matrix(n: int, step: float, sparse: bool = False, dtype=torch.float32) -> torch.Tensor:
170171
"""
171172
Create 1D Laplacian operator matrix, with periodic BC.
172-
modified the scipy.linalg.circulant implementation to native torch
173+
The matrix is a tri-diagonal matrix with [1, -2, 1]/h**2
174+
Modified the scipy.linalg.circulant implementation to native torch
173175
"""
174176
if sparse:
175-
values = torch.tensor([1.0, -2.0, 1.0]) / step**2
177+
values = torch.tensor([1.0, -2.0, 1.0], dtype=dtype) / step**2
176178
idx_row = torch.arange(n).repeat(3)
177179
idx_col = torch.cat(
178180
[
@@ -188,33 +190,45 @@ def laplacian_matrix(n: int, step: float, sparse: bool = False) -> torch.Tensor:
188190
)
189191
return torch.sparse_coo_tensor(indices, data, size=(n, n))
190192
else:
191-
column = torch.zeros(n)
193+
column = torch.zeros(n, dtype=dtype)
192194
column[0] = -2 / step**2
193195
column[1] = column[-1] = 1 / step**2
194196
idx = (n - torch.arange(n)[None].T + torch.arange(n)[None]) % n
195197
return torch.gather(column[None, ...].expand(n, -1), 1, idx)
196198

197199

198200
def _laplacian_boundary_dirichlet_cell_centered(
199-
laplacians: ArrayVector, grid: grids.Grid, axis: int, side: str
201+
laplacians: ArrayVector, grid: grids.Grid, dim: int, side: str
200202
) -> None:
201203
"""Converts 1d laplacian matrix to satisfy dirichlet homogeneous bc.
202204
203205
laplacians[i] contains a 3 point stencil matrix L that approximates
204206
d^2/dx_i^2.
205207
For detailed documentation on laplacians input type see
206-
array_utils.laplacian_matrix.
207-
The default return of array_utils.laplacian_matrix makes a matrix for
208-
periodic boundary. For dirichlet boundary, the correct equation is
209-
L(u_interior) = rhs_interior and BL_boundary = u_fixed_boundary. So
208+
fdm.laplacian_matrix.
209+
The default return of fdm.laplacian_matrix makes a matrix for
210+
periodic boundary. For (homogeneous) dirichlet boundary, the correct equation is
211+
L(u_interior) = rhs_interior
212+
BL_boundary = u_fixed_boundary.
213+
So
210214
laplacian_boundary_dirichlet restricts the matrix L to
211-
interior points only.
215+
interior points only.
216+
217+
Denote the node in the 3-pt stencil as
218+
u[ghost], u[boundary], u[interior] = u[0], u[1], u[2].
219+
The original stencil on the boundary is
220+
[1, -2, 1] * [u[0], u[1], u[2]] = u[0] - 2*u[1] + u[2]
221+
In the homogeneous Dirichlet bc case if the offset
222+
is 0.5 away from the wall, the ghost cell value u[0] = -u[1]. So the
223+
3 point stencil [1 -2 1] * [u[0] u[1] u[2]] = -3 u[1] + u[2].
224+
The original diagonal of Laplacian Lap[0, 0] is -2/h**2, we need -3/h**2,
225+
thus 1/h**2 is subtracted from the diagonal, and the ghost cell dof is set to zero (Lap[0, -1])
212226
213227
This function assumes RHS has cell-centered offset.
214228
Args:
215229
laplacians: list of 1d laplacians
216230
grid: grid object
217-
axis: axis along which to impose dirichlet bc.
231+
dim: axis along which to impose dirichlet bc.
218232
side: lower or upper side to assign boundary to.
219233
220234
Returns:
@@ -223,52 +237,50 @@ def _laplacian_boundary_dirichlet_cell_centered(
223237
TODO:
224238
[ ]: this function is not implemented in the original Jax-CFD code.
225239
"""
226-
# This function assumes homogeneous boundary, in which case if the offset
227-
# is 0.5 away from the wall, the ghost cell value u[0] = -u[1]. So the
228-
# 3 point stencil [1 -2 1] * [u[0] u[1] u[2]] = -3 u[1] + u[2].
240+
229241
if side == "lower":
230-
laplacians[axis][0, 0] = laplacians[axis][0, 0] - 1 / grid.step[axis] ** 2
242+
laplacians[dim][0, 0] = laplacians[dim][0, 0] - 1 / grid.step[dim] ** 2
231243
else:
232-
laplacians[axis][-1, -1] = laplacians[axis][-1, -1] - 1 / grid.step[axis] ** 2
244+
laplacians[dim][-1, -1] = laplacians[dim][-1, -1] - 1 / grid.step[dim] ** 2
233245
# deletes corner dependencies on the "looped-around" part.
234246
# this should be done irrespective of which side, since one boundary cannot
235247
# be periodic while the other is.
236-
laplacians[axis][0, -1] = 0.0
237-
laplacians[axis][-1, 0] = 0.0
238-
return laplacians
248+
laplacians[dim][0, -1] = 0.0
249+
laplacians[dim][-1, 0] = 0.0
250+
return
239251

240252

241253
def _laplacian_boundary_neumann_cell_centered(
242-
laplacians: List[Any], grid: grids.Grid, axis: int, side: str
254+
laplacians: ArrayVector, grid: grids.Grid, dim: int, side: str
243255
) -> None:
244256
"""Converts 1d laplacian matrix to satisfy neumann homogeneous bc.
245257
246258
This function assumes the RHS will have a cell-centered offset.
247259
Neumann boundaries are not defined for edge-aligned offsets elsewhere in the
248-
code.
260+
code. For homogeneous Neumann BC (du/dn = 0), the ghost cell should equal the interior cell: u[ghost] = u[1]. The stencil becomes:
261+
[1, -2, 1] * [u[1], u[1], u[2]] = u[1] - 2*u[1] + u[2] = -u[1] + u[2]
262+
The original diagonal of Laplacian Lap[0, 0] is -2/h**2, we need -1/h**2,
263+
thus 1/h**2 is added to the diagonal, and the ghost cell dof is set to zero (Lap[0, -1]).
249264
250265
Args:
251266
laplacians: list of 1d laplacians
252267
grid: grid object
253-
axis: axis along which to impose dirichlet bc.
268+
dim: axis along which to impose dirichlet bc.
254269
side: which boundary side to convert to neumann homogeneous bc.
255270
256271
Returns:
257272
updated list of 1d laplacians.
258-
259-
TODO
260-
[ ]: this function is not implemented in the original Jax-CFD code.
261273
"""
262274
if side == "lower":
263-
laplacians[axis][0, 0] = laplacians[axis][0, 0] + 1 / grid.step[axis] ** 2
275+
laplacians[dim][0, 0] = laplacians[dim][0, 0] + 1 / grid.step[dim] ** 2
264276
else:
265-
laplacians[axis][-1, -1] = laplacians[axis][-1, -1] + 1 / grid.step[axis] ** 2
277+
laplacians[dim][-1, -1] = laplacians[dim][-1, -1] + 1 / grid.step[dim] ** 2
266278
# deletes corner dependencies on the "looped-around" part.
267279
# this should be done irrespective of which side, since one boundary cannot
268280
# be periodic while the other is.
269-
laplacians[axis][0, -1] = 0.0
270-
laplacians[axis][-1, 0] = 0.0
271-
return laplacians
281+
laplacians[dim][0, -1] = 0.0
282+
laplacians[dim][-1, 0] = 0.0
283+
return
272284

273285

274286
def laplacian_matrix_w_boundaries(
@@ -277,6 +289,7 @@ def laplacian_matrix_w_boundaries(
277289
bc: grids.BoundaryConditions,
278290
laplacians: Optional[ArrayVector] = None,
279291
device: Optional[torch.device] = None,
292+
dtype: torch.dtype = torch.float32,
280293
sparse: bool = False,
281294
) -> ArrayVector:
282295
"""Returns 1d laplacians that satisfy boundary conditions bc on grid.
@@ -323,11 +336,13 @@ def laplacian_matrix_w_boundaries(
323336
raise NotImplementedError(
324337
"edge-aligned Neumann boundaries are not implemented."
325338
)
326-
return list(lap.to(device) for lap in laplacians) if device else laplacians
339+
return list(lap.to(dtype).to(device) for lap in laplacians)
327340

328341

329342
def _linear_along_axis(c: GridVariable, offset: float, dim: int) -> GridVariable:
330-
"""Linear interpolation of `c` to `offset` along a single specified `axis`."""
343+
"""Linear interpolation of `c` to `offset` along a single specified `axis`.
344+
dim here is >= 0, the negative indexing for batched implementation is handled by grids.shift.
345+
"""
331346
offset_delta = offset - c.offset[dim]
332347

333348
# If offsets are the same, `c` is unchanged.
@@ -383,8 +398,8 @@ def linear(
383398
f"got {c.offset} and {offset}."
384399
)
385400
interpolated = c
386-
for a, o in enumerate(offset):
387-
interpolated = _linear_along_axis(interpolated, offset=o, dim=a)
401+
for dim, o in enumerate(offset):
402+
interpolated = _linear_along_axis(interpolated, offset=o, dim=dim)
388403
return interpolated
389404

390405

@@ -405,15 +420,15 @@ def gradient_tensor(v):
405420
if not isinstance(v, GridVariable):
406421
return GridTensor(torch.stack([gradient_tensor(u) for u in v], dim=-1))
407422
grad = []
408-
for axis in range(v.grid.ndim):
409-
offset = v.offset[axis]
423+
for dim in range(-v.grid.ndim, 0):
424+
offset = v.offset[dim]
410425
if offset == 0:
411-
derivative = forward_difference(v, axis)
426+
derivative = forward_difference(v, dim)
412427
elif offset == 1:
413-
derivative = backward_difference(v, axis)
428+
derivative = backward_difference(v, dim)
414429
elif offset == 0.5:
415430
v_centered = linear(v, v.grid.cell_center)
416-
derivative = central_difference(v_centered, axis)
431+
derivative = central_difference(v_centered, dim)
417432
else:
418433
raise ValueError(f"expected offset values in {{0, 0.5, 1}}, got {offset}")
419434
grad.append(derivative)
@@ -427,4 +442,4 @@ def curl_2d(v: GridVariableVector) -> GridVariable:
427442
grid = grids.consistent_grid_arrays(*v)
428443
if grid.ndim != 2:
429444
raise ValueError(f"Grid dimensionality is not 2: {grid.ndim}")
430-
return forward_difference(v[1], dim=0) - forward_difference(v[0], dim=1)
445+
return forward_difference(v[1], dim=-2) - forward_difference(v[0], dim=-1)

torch_cfd/forcings.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
Grid = grids.Grid
2727
GridVariable = grids.GridVariable
28+
GridVariableVector = grids.GridVariableVector
2829

2930

3031
def forcing_eval(eval_func):
@@ -79,7 +80,7 @@ class ForcingFn(nn.Module):
7980
def __init__(
8081
self,
8182
grid: Grid,
82-
scale: float = 1,
83+
scale: float = 1.0,
8384
wave_number: int = 1,
8485
diam: float = 1.0,
8586
swap_xy: bool = False,
@@ -100,20 +101,20 @@ def __init__(
100101

101102
@forcing_eval
102103
def velocity_eval(
103-
grid: Grid, velocity: Optional[Tuple[torch.Tensor, torch.Tensor]]
104-
) -> Tuple[torch.Tensor, torch.Tensor]:
104+
self, grid: Grid, velocity: Optional[Tuple[GridVariable, GridVariable]]
105+
) -> GridVariableVector:
105106
raise NotImplementedError
106107

107108
@forcing_eval
108-
def vorticity_eval(grid: Grid, vorticity: Optional[torch.Tensor]) -> torch.Tensor:
109+
def vorticity_eval(self, grid: Grid, vorticity: Optional[torch.Tensor]) -> GridVariable:
109110
raise NotImplementedError
110111

111112
def forward(
112113
self,
113114
grid: Optional[Union[Grid, Tuple[Grid, Grid]]] = None,
114115
velocity: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
115116
vorticity: Optional[torch.Tensor] = None,
116-
) -> Tuple[torch.Tensor, torch.Tensor]:
117+
) -> Union[GridVariable, GridVariableVector]:
117118
if not self.vorticity:
118119
return self.velocity_eval(grid, velocity)
119120
else:
@@ -166,7 +167,7 @@ def velocity_eval(
166167
self,
167168
grid: Optional[Grid],
168169
velocity: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
169-
) -> Tuple[torch.Tensor, torch.Tensor]:
170+
) -> GridVariableVector:
170171
offsets = self.offsets
171172
grid = self.grid if grid is None else grid
172173
domain_factor = 2 * torch.pi / self.diam
@@ -187,13 +188,13 @@ def velocity_eval(
187188
grid,
188189
)
189190
v = GridVariable(torch.zeros_like(u.data), (1 / 2, 1), grid)
190-
return tuple((u, v))
191+
return GridVariableVector(tuple((u, v)))
191192

192193
def vorticity_eval(
193194
self,
194195
grid: Optional[Grid],
195196
vorticity: Optional[torch.Tensor] = None,
196-
) -> torch.Tensor:
197+
) -> GridVariable:
197198
offsets = self.offsets
198199
grid = self.grid if grid is None else grid
199200
domain_factor = 2 * torch.pi / self.diam
@@ -243,9 +244,9 @@ class SimpleSolenoidalForcing(ForcingFn):
243244

244245
def __init__(
245246
self,
246-
scale=1,
247+
scale=1.0,
247248
diam=1.0,
248-
k=1.0,
249+
wave_number=1,
249250
offsets=((0, 0), (0, 0)),
250251
vorticity=True,
251252
*args,
@@ -255,7 +256,7 @@ def __init__(
255256
*args,
256257
scale=scale,
257258
diam=diam,
258-
wave_number=k,
259+
wave_number=wave_number,
259260
offsets=offsets,
260261
vorticity=vorticity,
261262
**kwargs,
@@ -273,7 +274,7 @@ def velocity_eval(
273274
self,
274275
grid: Optional[Grid],
275276
velocity: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
276-
) -> Tuple[torch.Tensor, torch.Tensor]:
277+
) -> GridVariableVector:
277278
offsets = self.offsets
278279
grid = self.grid if grid is None else grid
279280
domain_factor = 2 * torch.pi / self.diam
@@ -292,7 +293,7 @@ def velocity_eval(
292293
rot = self.potential(x, y, scale, k)
293294
u = GridVariable(rot, offsets[0], grid)
294295
v = GridVariable(-rot, (1 / 2, 1), grid)
295-
return tuple((u, v))
296+
return GridVariableVector(tuple((u, v)))
296297

297298
def vorticity_eval(
298299
self,
@@ -339,7 +340,7 @@ def __init__(
339340
self,
340341
scale=0.1,
341342
diam=1.0,
342-
k=1.0,
343+
wave_number=1,
343344
offsets=((0, 0), (0, 0)),
344345
*args,
345346
**kwargs,
@@ -348,7 +349,7 @@ def __init__(
348349
*args,
349350
scale=scale,
350351
diam=diam,
351-
k=k,
352+
wave_number=wave_number,
352353
offsets=offsets,
353354
**kwargs,
354355
)

0 commit comments

Comments
 (0)