Skip to content

Commit 56c4118

Browse files
committed
updated padding behavior to have a fallback
1 parent cf1c857 commit 56c4118

File tree

4 files changed

+67
-47
lines changed

4 files changed

+67
-47
lines changed

torch_cfd/boundaries.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -345,15 +345,15 @@ def pad_and_impose_bc(
345345
self,
346346
u: GridVariable,
347347
offset_to_pad_to: Optional[Tuple[float, ...]] = None,
348-
mode: Optional[str] = "extend",
348+
mode: Optional[str] = "",
349349
) -> GridVariable:
350350
"""Returns GridVariable with correct boundary values.
351351
352352
Some grid points of GridVariable might coincide with boundary, thus this function is only used with the trimmed GridVariable.
353353
Args:
354354
- u: a `GridVariable` object that specifies only scalar values on the internal nodes.
355355
- offset_to_pad_to: a Tuple of desired offset to pad to. Note that if the function is given just an interior array in dirichlet case, it can pad to both 0 offset and 1 offset.
356-
- mode: type of padding to use in non-periodic case. Mirror mirrors the flow across the boundary. Extend extends the last well-defined value past the boundary.
356+
- mode: type of padding to use in non-periodic case. Mirror mirrors the flow across the boundary. Extend extends the last well-defined value past the boundary. None means no ghost cell padding.
357357
358358
Returns:
359359
A GridVariable that has correct boundary values.
@@ -364,16 +364,27 @@ def pad_and_impose_bc(
364364
for dim in range(-u.grid.ndim, 0):
365365
_ = self._is_aligned(u, dim)
366366
if self.types[dim][0] != BCType.PERIODIC:
367-
# if the offset is either 0 or 1, u is aligned with the boundary and is defined on cell edges on one side of the boundary, if trim_boundary is called before this function.
368-
# u needs to be padded on both sides
369-
# if the offset is 0.5, one ghost cell is needed on each side.
370-
# it will be taken care by grids.pad function automatically.
371-
u = grids.pad(u, (1, 1), dim, self)
372-
elif self.types[dim][0] == BCType.PERIODIC:
373-
return GridVariable(u.data, u.offset, u.grid, self)
374-
return u
367+
if mode:
368+
# if the offset is either 0 or 1, u is aligned with the boundary and is defined on cell edges on one side of the boundary, if trim_boundary is called before this function.
369+
# u needs to be padded on both sides
370+
# if the offset is 0.5, one ghost cell is needed on each side.
371+
# it will be taken care by grids.pad function automatically.
372+
u = grids.pad(u, (1, 1), dim, self, mode=mode)
373+
elif self.types[dim][0] == BCType.DIRICHLET and not mode:
374+
if math.isclose(offset_to_pad_to[dim], 1.0):
375+
u = grids.pad(u, 1, dim, self)
376+
elif math.isclose(offset_to_pad_to[dim], 0.0):
377+
u = grids.pad(u, -1, dim, self)
378+
elif self.types[dim][0] == BCType.NEUMANN and not mode:
379+
if not math.isclose(offset_to_pad_to[dim], 0.5):
380+
raise ValueError("Neumann bc is not defined on edges.")
381+
else:
382+
raise NotImplementedError(
383+
f"Padding for {self.types[dim][0]} boundary conditions is not implemented."
384+
)
385+
return GridVariable(u.data, u.offset, u.grid, self)
375386

376-
def impose_bc(self, u: GridVariable) -> GridVariable:
387+
def impose_bc(self, u: GridVariable, mode: str="") -> GridVariable:
377388
"""Returns GridVariable with correct boundary condition.
378389
379390
Some grid points of GridVariable might coincide with boundary. This ensures
@@ -382,11 +393,11 @@ def impose_bc(self, u: GridVariable) -> GridVariable:
382393
u: a `GridVariable` object.
383394
384395
Returns:
385-
A GridVariable that has correct boundary values with ghost cells added on the other side of DoFs living at cell center if the bc is Dirichlet or Neumann.
396+
A GridVariable that has correct boundary values. If ghost_cell == True, then ghost cells are added on the other side of DoFs living at cell center if the bc is Dirichlet or Neumann.
386397
"""
387398
offset = u.offset
388399
u = self.trim_boundary(u)
389-
u = self.pad_and_impose_bc(u, offset)
400+
u = self.pad_and_impose_bc(u, offset, mode)
390401
return u
391402

392403

torch_cfd/grids.py

Lines changed: 35 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,21 @@
3636
T = TypeVar("T") # for GridVariable vector
3737

3838

39+
class BCType:
40+
PERIODIC = "periodic"
41+
DIRICHLET = "dirichlet"
42+
NEUMANN = "neumann"
43+
ROBIN = "robin"
44+
NONE = None
45+
46+
47+
class Padding:
48+
MIRROR = "reflect"
49+
EXTEND = "replicate"
50+
SYMMETRIC = "symmetric"
51+
NONE = ""
52+
53+
3954
@dataclasses.dataclass(init=False, frozen=True)
4055
class Grid:
4156
"""
@@ -112,7 +127,7 @@ def __init__(
112127
def __repr__(self) -> str:
113128
lines = [f"Grid({self.ndim}D):"]
114129
lines.append(f" shape: {self.shape}")
115-
130+
116131
for i in range(self.ndim):
117132
lower, upper = self.domain[i]
118133
step = self.step[i]
@@ -241,20 +256,6 @@ def eval_on_mesh(
241256
return GridVariable(fn(*self.mesh(offset)), offset, self)
242257

243258

244-
class BCType:
245-
PERIODIC = "periodic"
246-
DIRICHLET = "dirichlet"
247-
NEUMANN = "neumann"
248-
ROBIN = "robin"
249-
NONE = None
250-
251-
252-
class Padding:
253-
MIRROR = "reflect"
254-
EXTEND = "replicate"
255-
SYMMETRIC = "symmetric"
256-
257-
258259
@dataclasses.dataclass(init=False, frozen=True)
259260
class BoundaryConditions:
260261
"""Base class for boundary conditions on a PDE variable.
@@ -322,12 +323,13 @@ def trim_boundary(
322323
def impose_bc(
323324
self,
324325
u: GridVariable,
326+
mode: Optional[str] = ""
325327
) -> GridVariable:
326328
"""Impose boundary conditions on the grid variable."""
327329
raise NotImplementedError(
328330
"impose_bc() not implemented in BoundaryConditions base class."
329331
)
330-
332+
331333
def pad_and_impose_bc(
332334
self,
333335
u: GridVariable,
@@ -778,15 +780,15 @@ def interior(self) -> GridVariable:
778780
interior_grid = self._interior_grid()
779781
return GridVariable(interior_array, self.offset, interior_grid)
780782

781-
def impose_bc(self) -> GridVariable:
783+
def impose_bc(self, mode: str="") -> GridVariable:
782784
"""Returns the GridVariable with edge BC enforced, if applicable.
783785
784786
For GridVariables having nonperiodic BC and offset 0 or 1, there are values
785787
in the array data that are dependent on the boundary condition.
786788
impose_bc() changes these boundary values to match the prescribed BC.
787789
"""
788790
assert self.bc is not None, "Boundary conditions must be set to impose BC."
789-
return self.bc.impose_bc(self)
791+
return self.bc.impose_bc(self, mode)
790792

791793
def enforce_edge_bc(self, *args) -> GridVariable:
792794
"""Returns the GridVariable with edge BC enforced, if applicable.
@@ -1030,9 +1032,11 @@ def pad(
10301032
assert not (
10311033
u.bc is None and bc is None and bc_types is None
10321034
), "u.bc, bc, and bc_types cannot be None at the same time"
1033-
assert mode in [Padding.MIRROR, Padding.EXTEND, Padding.SYMMETRIC], (
1034-
f"Padding mode must be one of ['{Padding.MIRROR}', '{Padding.EXTEND}', '{Padding.SYMMETRIC}'], got '{mode}'"
1035-
)
1035+
assert mode in [
1036+
Padding.MIRROR,
1037+
Padding.EXTEND,
1038+
Padding.SYMMETRIC,
1039+
], f"Padding mode must be one of ['{Padding.MIRROR}', '{Padding.EXTEND}', '{Padding.SYMMETRIC}'], got '{mode}'"
10361040
bc = bc if bc is not None else u.bc
10371041
bc_types = bc.types[dim] if bc_types is None else bc_types
10381042
values = bc.bc_values if values is None else values
@@ -1112,9 +1116,7 @@ def pad(
11121116
return GridVariable(data, tuple(new_offset), u.grid, bc)
11131117
elif mode == Padding.MIRROR:
11141118
bc_padding = [(0, 0)] * u.grid.ndim
1115-
bc_padding[dim] = tuple(
1116-
1 if pad > 0 else 0 for pad in padding
1117-
)
1119+
bc_padding[dim] = tuple(1 if pad > 0 else 0 for pad in padding)
11181120
# subtract the padded cell
11191121
full_padding_past_bc = [(0, 0)] * u.grid.ndim
11201122
full_padding_past_bc[dim] = tuple(
@@ -1125,13 +1127,17 @@ def pad(
11251127
u.data, bc_padding, mode="constant", constant_values=(0, 0)
11261128
)
11271129
padding_values = list(values)
1128-
padding_values[dim] = tuple([pad / 2 for pad in padding_values[dim]])
1130+
padding_values[dim] = tuple(
1131+
[pad / 2 for pad in padding_values[dim]]
1132+
)
11291133
data = 2 * expand_dims_pad(
11301134
u.data,
11311135
full_padding,
11321136
mode="constant",
11331137
constant_values=padding_values,
1134-
) - expand_dims_pad(expanded_data, full_padding_past_bc, mode="reflect")
1138+
) - expand_dims_pad(
1139+
expanded_data, full_padding_past_bc, mode="reflect"
1140+
)
11351141
return GridVariable(data, tuple(new_offset), u.grid, bc)
11361142
else:
11371143
raise ValueError(
@@ -1203,7 +1209,9 @@ def expand_dims_pad(
12031209
inputs: torch.Tensor,
12041210
pad: Sequence[Tuple[int, int]],
12051211
mode: str = "constant",
1206-
constant_values: Union[float, Tuple[float, float], Sequence[Tuple[float, float]]] = 0,
1212+
constant_values: Union[
1213+
float, Tuple[float, float], Sequence[Tuple[float, float]]
1214+
] = 0,
12071215
**kwargs,
12081216
) -> torch.Tensor:
12091217
"""

torch_cfd/tests/test_advection.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,10 @@ def test_burgers_analytical_dirichlet_convergence(
206206
offset,
207207
):
208208
def _step_func(v, dt, method):
209+
"""
210+
dt/2 is used because for Burgers equation
211+
the flux is u_t + (0.5*u^2)_x = 0
212+
"""
209213
dv_dt = method(c=v[0], v=v, dt=dt) / 2
210214
return (bc.impose_bc(v[0].data + dt * dv_dt),)
211215

@@ -225,10 +229,6 @@ def _velocity_implicit(grid, offset, u, t):
225229
advect = advect_van_leer(grid, offset)
226230

227231
for _ in range(num_steps):
228-
"""
229-
dt/2 is used because for Burgers equation
230-
the flux is u_t + (0.5*u^2)_x = 0
231-
"""
232232
v = _step_func(v, dt, method=advect)
233233

234234
expected = bc.impose_bc(

torch_cfd/tests/test_boundaries.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from torch_cfd import boundaries, grids, test_utils
2525

2626
BCType = boundaries.BCType
27+
Padding = boundaries.Padding
2728

2829
tensor = partial(torch.tensor, dtype=torch.float32)
2930

@@ -704,7 +705,7 @@ def test_pad_and_impose_bc_1d(
704705
grid = grids.Grid((grid_size,))
705706
array = grids.GridVariable(input_data, input_offset, grid)
706707
bc = boundaries.ConstantBoundaryConditions(bc_types[0], bc_types[1])
707-
actual = bc.pad_and_impose_bc(array, expected_offset)
708+
actual = bc.pad_and_impose_bc(array, expected_offset, mode=Padding.EXTEND)
708709
expected = grids.GridVariable(expected_data, expected_offset, grid)
709710
self.assertArrayEqual(actual, expected)
710711

@@ -765,7 +766,7 @@ def test_impose_bc_1d(
765766
grid = grids.Grid((grid_size,))
766767
bc = boundaries.ConstantBoundaryConditions(bc_types[0], bc_types[1])
767768
array = grids.GridVariable(input_data, input_offset, grid, bc)
768-
actual = bc.impose_bc(array)
769+
actual = bc.impose_bc(array, mode=Padding.EXTEND)
769770
expected = grids.GridVariable(expected_data, expected_offset, grid, bc)
770771
self.assertArrayEqual(actual, expected)
771772

@@ -818,7 +819,7 @@ def test_impose_bc_2d_constant_boundary(
818819
grid = grids.Grid(input_data.shape)
819820
bc = boundaries.dirichlet_boundary_conditions(grid.ndim, values)
820821
variable = grids.GridVariable(input_data, offset, grid, bc)
821-
variable = variable.impose_bc()
822+
variable = variable.impose_bc(mode=Padding.EXTEND)
822823
expected = grids.GridVariable(expected_data, expected_offset, grid)
823824
self.assertArrayEqual(variable, expected)
824825

0 commit comments

Comments
 (0)