Skip to content

Commit cf1c857

Browse files
committed
changed the behavior of pad_and_impose_bc to include ghost cells
1 parent 4063d25 commit cf1c857

File tree

4 files changed

+708
-218
lines changed

4 files changed

+708
-218
lines changed

torch_cfd/boundaries.py

Lines changed: 168 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import dataclasses
2020
import math
21-
from typing import Optional, Sequence, Tuple
21+
from typing import Optional, Sequence, Tuple, Union
2222

2323
import torch
2424

@@ -31,15 +31,11 @@
3131

3232

3333
BCType = grids.BCType()
34+
Padding = grids.Padding()
3435

3536

36-
class Padding:
37-
MIRROR = "reflect"
38-
EXTEND = "replicate"
39-
40-
41-
@dataclasses.dataclass(init=False, frozen=True)
42-
class ConstantBoundaryConditions(grids.BoundaryConditions):
37+
@dataclasses.dataclass(init=False, frozen=True, repr=False)
38+
class ConstantBoundaryConditions(BoundaryConditions):
4339
"""Boundary conditions for a PDE variable that are constant in space and time.
4440
4541
Attributes:
@@ -59,8 +55,8 @@ class ConstantBoundaryConditions(grids.BoundaryConditions):
5955
6056
"""
6157

62-
types: Tuple[Tuple[str, str], ...]
63-
_values: Tuple[Tuple[Optional[float], Optional[float]], ...]
58+
_types: Tuple[Tuple[str, str], ...]
59+
bc_values: Tuple[Tuple[Optional[float], Optional[float]], ...]
6460

6561
def __init__(
6662
self,
@@ -69,10 +65,62 @@ def __init__(
6965
):
7066
types = tuple(types)
7167
values = tuple(values)
72-
object.__setattr__(self, "types", types)
73-
object.__setattr__(self, "_values", values)
68+
object.__setattr__(self, "_types", types)
69+
object.__setattr__(self, "bc_values", values)
7470
object.__setattr__(self, "ndim", len(types))
7571

72+
@property
73+
def types(self) -> Tuple[Tuple[str, str], ...]:
74+
"""Returns the boundary condition types."""
75+
return self._types
76+
77+
@types.setter
78+
def types(self, bc_types: Sequence[Tuple[str, str]]) -> None:
79+
"""Sets the boundary condition types and updates ndim accordingly."""
80+
bc_types = tuple(bc_types)
81+
assert self.ndim == len(
82+
bc_types
83+
), f"Number of dimensions {self.ndim} does not match the number of types {bc_types}."
84+
object.__setattr__(self, "_types", bc_types)
85+
86+
def __repr__(self) -> str:
87+
try:
88+
lines = [f"{self.__class__.__name__}({self.ndim}D):"]
89+
90+
for dim in range(self.ndim):
91+
lower_type, upper_type = self.types[dim]
92+
lower_val, upper_val = self.bc_values[dim]
93+
94+
# Format values
95+
lower_val_str = "None" if lower_val is None else f"{lower_val}"
96+
upper_val_str = "None" if upper_val is None else f"{upper_val}"
97+
98+
lines.append(
99+
f" dim {dim}: [{lower_type}({lower_val_str}), {upper_type}({upper_val_str})]"
100+
)
101+
102+
return "\n".join(lines)
103+
except Exception as e:
104+
return f"{self.__class__.__name__} not initialized: {e}"
105+
106+
def clone(
107+
self,
108+
types: Optional[Sequence[Tuple[str, str]]] = None,
109+
values: Optional[Sequence[Tuple[Optional[float], Optional[float]]]] = None,
110+
) -> BoundaryConditions:
111+
"""Creates a copy of this boundary condition, optionally with modified parameters.
112+
113+
Args:
114+
types: New boundary condition types. If None, uses current types.
115+
values: New boundary condition values. If None, uses current values.
116+
117+
Returns:
118+
A new ConstantBoundaryConditions instance.
119+
"""
120+
new_types = types if types is not None else self.types
121+
new_values = values if values is not None else self.bc_values
122+
return ConstantBoundaryConditions(new_types, new_values)
123+
76124
def shift(
77125
self,
78126
u: GridVariable,
@@ -106,7 +154,9 @@ def _is_aligned(self, u: GridVariable, dim: int) -> bool:
106154
size_diff += 1
107155
if self.types[dim][1] == BCType.DIRICHLET and math.isclose(u.offset[dim], 1):
108156
size_diff += 1
109-
if (self.types[dim][0] == BCType.NEUMANN and math.isclose(u.offset[dim], 1)) or (self.types[dim][1] == BCType.NEUMANN and math.isclose(u.offset[dim], 0)):
157+
if (
158+
self.types[dim][0] == BCType.NEUMANN and math.isclose(u.offset[dim], 1)
159+
) or (self.types[dim][1] == BCType.NEUMANN and math.isclose(u.offset[dim], 0)):
110160
"""
111161
if lower (or left for dim 0) is Neumann, and the offset is 1 (the variable is on the right edge of a cell), the Neumann bc is not defined; vice versa for upper (or right for dim 0) Neumann bc with offset 0 (the variable is on the left edge of a cell).
112162
"""
@@ -117,6 +167,78 @@ def _is_aligned(self, u: GridVariable, dim: int) -> bool:
117167
)
118168
return True
119169

170+
def pad(
171+
self,
172+
u: GridVariable,
173+
width: Union[Tuple[int, int], int],
174+
dim: int,
175+
mode: Optional[str] = Padding.EXTEND,
176+
) -> GridVariable:
177+
"""Wrapper for grids.pad with a specific bc.
178+
179+
Args:
180+
u: a `GridVariable` object.
181+
width: number of elements to pad along axis. If width is an int, use
182+
negative value for lower boundary or positive value for upper boundary.
183+
If a tuple, pads with width[0] on the left and width[1] on the right.
184+
dim: axis to pad along.
185+
mode: type of padding to use in non-periodic case.
186+
Mirror mirrors the array values across the boundary.
187+
Extend extends the last well-defined array value past the boundary.
188+
189+
Returns:
190+
Padded array, elongated along the indicated axis.
191+
the original u.bc will be replaced with self.
192+
"""
193+
_ = self._is_aligned(u, dim)
194+
if isinstance(width, tuple) and (width[0] > 0 and width[1] > 0):
195+
need_trimming = "both"
196+
elif (isinstance(width, tuple) and (width[0] > 0 and width[1] == 0)) or (
197+
isinstance(width, int) and width < 0
198+
):
199+
need_trimming = "left"
200+
elif (isinstance(width, tuple) and (width[0] == 0 and width[1] > 0)) or (
201+
isinstance(width, int) and width > 0
202+
):
203+
need_trimming = "right"
204+
else:
205+
need_trimming = "none"
206+
207+
u, trimmed_padding = self._trim_padding(u, dim, need_trimming)
208+
209+
if isinstance(width, int):
210+
if width < 0:
211+
width -= trimmed_padding[0]
212+
if width > 0:
213+
width += trimmed_padding[1]
214+
elif isinstance(width, tuple):
215+
width = (width[0] + trimmed_padding[0], width[1] + trimmed_padding[1])
216+
217+
u = grids.pad(u, width, dim, self, mode=mode)
218+
return u
219+
220+
def pad_all(
221+
self,
222+
u: GridVariable,
223+
width: Tuple[Tuple[int, int], ...],
224+
mode: Optional[str] = Padding.EXTEND,
225+
) -> GridVariable:
226+
"""Pads along all axes with pad width specified by width tuple.
227+
228+
Args:
229+
u: a `GridArray` object.
230+
width: Tuple of padding width for each side for each axis.
231+
mode: type of padding to use in non-periodic case.
232+
Mirror mirrors the array values across the boundary.
233+
Extend extends the last well-defined array value past the boundary.
234+
235+
Returns:
236+
Padded array, elongated along all axes.
237+
"""
238+
for dim in range(-u.grid.ndim, 0):
239+
u = self.pad(u, width[dim], dim, mode=mode)
240+
return u
241+
120242
def values(
121243
self, dim: int, grid: Grid
122244
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
@@ -130,21 +252,22 @@ def values(
130252
A tuple of arrays of grid.ndim - 1 dimensions that specify values on the
131253
boundary. In case of periodic boundaries, returns a tuple(None,None).
132254
"""
133-
if None in self._values[dim]:
255+
if None in self.bc_values[dim]:
134256
return (None, None)
135-
257+
136258
bc = []
137259
for i in [0, 1]:
138-
value = self._values[dim][-i]
260+
value = self.bc_values[dim][-i]
139261
if value is None:
140262
bc.append(None)
141263
else:
142264
bc.append(torch.full(grid.shape[:dim] + grid.shape[dim + 1 :], value))
143-
265+
144266
return tuple(bc)
145-
146267

147-
def _trim_padding(self, u: GridVariable, dim: int = -1, trim_side: str = "both"):
268+
def _trim_padding(
269+
self, u: GridVariable, dim: int = -1, trim_side: str = "both"
270+
) -> Tuple[GridVariable, Tuple[int, int]]:
148271
"""Trims padding from a GridVariable along axis and returns the array interior.
149272
150273
Args:
@@ -160,6 +283,9 @@ def _trim_padding(self, u: GridVariable, dim: int = -1, trim_side: str = "both")
160283
negative_trim = 0
161284
padding = (0, 0)
162285

286+
if trim_side not in ("both", "left", "right"):
287+
return u.array, padding
288+
163289
if u.shape[dim] >= u.grid.shape[dim]:
164290
# number of cells that were padded on the left
165291
negative_trim = 0
@@ -201,7 +327,7 @@ def _trim_padding(self, u: GridVariable, dim: int = -1, trim_side: str = "both")
201327
def trim_boundary(self, u: GridVariable) -> GridVariable:
202328
"""Returns GridVariable without the grid points on the boundary.
203329
204-
Some grid points of GridVariable might coincide with boundary. This trims those values.
330+
Some grid points of GridVariable might coincide with boundary. This trims those values.
205331
If the array was padded beforehand, removes the padding.
206332
207333
Args:
@@ -223,33 +349,29 @@ def pad_and_impose_bc(
223349
) -> GridVariable:
224350
"""Returns GridVariable with correct boundary values.
225351
226-
Some grid points of GridVariable might coincide with boundary. This ensures
227-
that the GridVariable.array agrees with GridVariable.bc.
352+
Some grid points of GridVariable might coincide with boundary, thus this function is only used with the trimmed GridVariable.
228353
Args:
229-
u: a `GridVariable` object that specifies only scalar values on the internal
230-
nodes.
231-
offset_to_pad_to: a Tuple of desired offset to pad to. Note that if the
232-
function is given just an interior array in dirichlet case, it can pad
233-
to both 0 offset and 1 offset.
234-
mode: type of padding to use in non-periodic case.
235-
Mirror mirrors the flow across the boundary.
236-
Extend extends the last well-defined value past the boundary.
354+
- u: a `GridVariable` object that specifies only scalar values on the internal nodes.
355+
- offset_to_pad_to: a Tuple of desired offset to pad to. Note that if the function is given just an interior array in dirichlet case, it can pad to both 0 offset and 1 offset.
356+
- mode: type of padding to use in non-periodic case. Mirror mirrors the flow across the boundary. Extend extends the last well-defined value past the boundary.
237357
238358
Returns:
239359
A GridVariable that has correct boundary values.
240360
"""
361+
assert u.bc is None, "u must be trimmed before padding and imposing bc."
241362
if offset_to_pad_to is None:
242363
offset_to_pad_to = u.offset
243-
for axis in range(-u.grid.ndim, 0):
244-
_ = self._is_aligned(u, axis)
245-
if self.types[axis][0] == BCType.DIRICHLET and math.isclose(
246-
u.offset[axis], 1.0
247-
):
248-
if math.isclose(offset_to_pad_to[axis], 1.0):
249-
u = grids.pad(u, 1, axis, self)
250-
elif math.isclose(offset_to_pad_to[axis], 0.0):
251-
u = grids.pad(u, -1, axis, self)
252-
return GridVariable(u.data, u.offset, u.grid, self)
364+
for dim in range(-u.grid.ndim, 0):
365+
_ = self._is_aligned(u, dim)
366+
if self.types[dim][0] != BCType.PERIODIC:
367+
# if the offset is either 0 or 1, u is aligned with the boundary and is defined on cell edges on one side of the boundary, if trim_boundary is called before this function.
368+
# u needs to be padded on both sides
369+
# if the offset is 0.5, one ghost cell is needed on each side.
370+
# it will be taken care by grids.pad function automatically.
371+
u = grids.pad(u, (1, 1), dim, self)
372+
elif self.types[dim][0] == BCType.PERIODIC:
373+
return GridVariable(u.data, u.offset, u.grid, self)
374+
return u
253375

254376
def impose_bc(self, u: GridVariable) -> GridVariable:
255377
"""Returns GridVariable with correct boundary condition.
@@ -260,8 +382,7 @@ def impose_bc(self, u: GridVariable) -> GridVariable:
260382
u: a `GridVariable` object.
261383
262384
Returns:
263-
A GridVariable that has correct boundary values and is restricted to the
264-
domain.
385+
A GridVariable that has correct boundary values with ghost cells added on the other side of DoFs living at cell center if the bc is Dirichlet or Neumann.
265386
"""
266387
offset = u.offset
267388
u = self.trim_boundary(u)
@@ -317,7 +438,7 @@ def periodic_boundary_conditions(ndim: int) -> BoundaryConditions:
317438

318439
def dirichlet_boundary_conditions(
319440
ndim: int,
320-
bc_vals: Optional[Sequence[Tuple[float, float]]] = None,
441+
bc_values: Optional[Sequence[Tuple[float, float]]] = None,
321442
) -> BoundaryConditions:
322443
"""Returns Dirichelt BCs for a variable with `ndim` spatial dimension.
323444
@@ -329,13 +450,13 @@ def dirichlet_boundary_conditions(
329450
Returns:
330451
BoundaryCondition subclass.
331452
"""
332-
if not bc_vals:
453+
if not bc_values:
333454
return HomogeneousBoundaryConditions(
334455
((BCType.DIRICHLET, BCType.DIRICHLET),) * ndim
335456
)
336457
else:
337458
return ConstantBoundaryConditions(
338-
((BCType.DIRICHLET, BCType.DIRICHLET),) * ndim, bc_vals
459+
((BCType.DIRICHLET, BCType.DIRICHLET),) * ndim, bc_values
339460
)
340461

341462

@@ -556,7 +677,7 @@ def get_advection_flux_bc_from_velocity_and_scalar_bc(
556677
if isinstance(u_bc, HomogeneousBoundaryConditions):
557678
u_values = tuple((0.0, 0.0) for _ in range(u_bc.ndim))
558679
elif isinstance(u_bc, ConstantBoundaryConditions):
559-
u_values = u_bc._values
680+
u_values = u_bc.bc_values
560681
else:
561682
raise NotImplementedError(
562683
f"Flux boundary condition is not implemented for velocity with {type(u_bc)}"
@@ -565,7 +686,7 @@ def get_advection_flux_bc_from_velocity_and_scalar_bc(
565686
if isinstance(c_bc, HomogeneousBoundaryConditions):
566687
c_values = tuple((0.0, 0.0) for _ in range(c_bc.ndim))
567688
elif isinstance(c_bc, ConstantBoundaryConditions):
568-
c_values = c_bc._values
689+
c_values = c_bc.bc_values
569690
else:
570691
raise NotImplementedError(
571692
f"Flux boundary condition is not implemented for scalar with {type(c_bc)}"

0 commit comments

Comments
 (0)