Skip to content

Commit 4063d25

Browse files
committed
stencil operations reset bc to None
1 parent 8b8a469 commit 4063d25

File tree

1 file changed

+34
-19
lines changed

1 file changed

+34
-19
lines changed

torch_cfd/boundaries.py

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -42,16 +42,21 @@ class Padding:
4242
class ConstantBoundaryConditions(grids.BoundaryConditions):
4343
"""Boundary conditions for a PDE variable that are constant in space and time.
4444
45+
Attributes:
46+
types: a tuple of tuples, where `types[i]` is a tuple specifying the lower and upper BC types for dimension `i`. The types can be one of the following:
47+
BCType.PERIODIC, BCType.DIRICHLET, BCType.NEUMANN.
48+
values: a tuple of tuples, where `values[i]` is a tuple specifying the lower and upper boundary values for dimension `i`. If None, the boundary condition is homogeneous (zero).
49+
4550
Example usage:
4651
grid = Grid((10, 10))
4752
bc = ConstantBoundaryConditions(((BCType.PERIODIC, BCType.PERIODIC),
4853
(BCType.DIRICHLET, BCType.DIRICHLET)),
4954
((0.0, 10.0),(1.0, 0.0)))
55+
# in dimension 0 is periodic, (0, 10) on left and right (un-used)
56+
# in dimension 1 is dirichlet, (1, 0) on bottom and top.
5057
v = GridVariable(torch.zeros((10, 10)), offset=(0.5, 0.5), grid, bc)
58+
# v.offset is (0.5, 0.5) which is the cell center, so the boundary conditions have no effect in this case
5159
52-
Attributes:
53-
types: `types[i]` is a tuple specifying the lower and upper BC types for
54-
dimension `i`.
5560
"""
5661

5762
types: Tuple[Tuple[str, str], ...]
@@ -90,7 +95,7 @@ def _is_aligned(self, u: GridVariable, dim: int) -> bool:
9095
Neumann edge aligned boundary is not defined.
9196
9297
Args:
93-
u: torch.Tensor that should contain interior data
98+
u: GridVariable that should contain interior data
9499
dim: axis along which to check
95100
96101
Returns:
@@ -101,8 +106,11 @@ def _is_aligned(self, u: GridVariable, dim: int) -> bool:
101106
size_diff += 1
102107
if self.types[dim][1] == BCType.DIRICHLET and math.isclose(u.offset[dim], 1):
103108
size_diff += 1
104-
if self.types[dim][0] == BCType.NEUMANN and math.isclose(u.offset[dim] % 1, 0):
105-
raise NotImplementedError("Edge-aligned neumann BC are not implemented.")
109+
if (self.types[dim][0] == BCType.NEUMANN and math.isclose(u.offset[dim], 1)) or (self.types[dim][1] == BCType.NEUMANN and math.isclose(u.offset[dim], 0)):
110+
"""
111+
if lower (or left for dim 0) is Neumann, and the offset is 1 (the variable is on the right edge of a cell), the Neumann bc is not defined; vice versa for upper (or right for dim 0) Neumann bc with offset 0 (the variable is on the left edge of a cell).
112+
"""
113+
raise ValueError("Variable not aligned with Neumann BC")
106114
if size_diff < 0:
107115
raise ValueError(
108116
"the GridVariable does not contain all interior grid values."
@@ -124,11 +132,17 @@ def values(
124132
"""
125133
if None in self._values[dim]:
126134
return (None, None)
127-
bc = tuple(
128-
torch.full(grid.shape[:dim] + grid.shape[dim + 1 :], self._values[dim][-i])
129-
for i in [0, 1]
130-
)
131-
return bc
135+
136+
bc = []
137+
for i in [0, 1]:
138+
value = self._values[dim][-i]
139+
if value is None:
140+
bc.append(None)
141+
else:
142+
bc.append(torch.full(grid.shape[:dim] + grid.shape[dim + 1 :], value))
143+
144+
return tuple(bc)
145+
132146

133147
def _trim_padding(self, u: GridVariable, dim: int = -1, trim_side: str = "both"):
134148
"""Trims padding from a GridVariable along axis and returns the array interior.
@@ -140,7 +154,7 @@ def _trim_padding(self, u: GridVariable, dim: int = -1, trim_side: str = "both")
140154
If 'left', the left side.
141155
142156
Returns:
143-
Trimmed array, shrunk along the indicated axis side.
157+
Trimmed array, shrunk along the indicated axis side. bc is updated to None
144158
"""
145159
positive_trim = 0
146160
negative_trim = 0
@@ -187,19 +201,19 @@ def _trim_padding(self, u: GridVariable, dim: int = -1, trim_side: str = "both")
187201
def trim_boundary(self, u: GridVariable) -> GridVariable:
188202
"""Returns GridVariable without the grid points on the boundary.
189203
190-
Some grid points of GridVariable might coincide with boundary. This trims those
191-
values. If the array was padded beforehand, removes the padding.
204+
Some grid points of GridVariable might coincide with boundary. This trims those values.
205+
If the array was padded beforehand, removes the padding.
192206
193207
Args:
194208
u: a `GridVariable` object.
195209
196210
Returns:
197211
A GridVariable shrunk along certain dimensions.
198212
"""
199-
for axis in range(-u.grid.ndim, 0):
200-
_ = self._is_aligned(u, axis)
201-
u, _ = self._trim_padding(u, axis)
202-
return u
213+
for dim in range(-u.grid.ndim, 0):
214+
_ = self._is_aligned(u, dim)
215+
u, _ = self._trim_padding(u, dim)
216+
return GridVariable(u.data, u.offset, u.grid)
203217

204218
def pad_and_impose_bc(
205219
self,
@@ -538,6 +552,7 @@ def get_advection_flux_bc_from_velocity_and_scalar_bc(
538552
flux_bc_values = []
539553

540554
# Handle both homogeneous and non-homogeneous boundary conditions
555+
# cannot handle mixed boundary conditions yet.
541556
if isinstance(u_bc, HomogeneousBoundaryConditions):
542557
u_values = tuple((0.0, 0.0) for _ in range(u_bc.ndim))
543558
elif isinstance(u_bc, ConstantBoundaryConditions):
@@ -557,7 +572,7 @@ def get_advection_flux_bc_from_velocity_and_scalar_bc(
557572
)
558573

559574
for axis in range(c_bc.ndim):
560-
if u_bc.types[axis][0] == "periodic":
575+
if u_bc.types[axis][0] == BCType.PERIODIC:
561576
flux_bc_types.append((BCType.PERIODIC, BCType.PERIODIC))
562577
flux_bc_values.append((None, None))
563578
elif flux_direction != axis:

0 commit comments

Comments
 (0)