Skip to content

Commit 8b8a469

Browse files
committed
refactored grids.py, added trim_bc tests
1 parent 6abe65a commit 8b8a469

File tree

2 files changed

+293
-80
lines changed

2 files changed

+293
-80
lines changed

torch_cfd/grids.py

Lines changed: 28 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -646,32 +646,32 @@ def _interior_grid(self) -> Grid:
646646
grid = self.grid
647647
domain = list(grid.domain)
648648
shape = list(grid.shape)
649-
for axis in range(self.grid.ndim):
649+
for dim in range(self.grid.ndim):
650650
# nothing happens in periodic case
651-
if self.bc.types[axis][1] == "periodic":
651+
if self.bc.types[dim][1] == "periodic":
652652
continue
653653
# nothing happens if the offset is not 0.0 or 1.0
654654
# this will automatically set the grid to interior.
655-
if math.isclose(self.offset[axis], 1.0):
656-
shape[axis] -= 1
657-
domain[axis] = (domain[axis][0], domain[axis][1] - grid.step[axis])
658-
elif math.isclose(self.offset[axis], 0.0):
659-
shape[axis] -= 1
660-
domain[axis] = (domain[axis][0] + grid.step[axis], domain[axis][1])
655+
if math.isclose(self.offset[dim], 1.0):
656+
shape[dim] -= 1
657+
domain[dim] = (domain[dim][0], domain[dim][1] - grid.step[dim])
658+
elif math.isclose(self.offset[dim], 0.0):
659+
shape[dim] -= 1
660+
domain[dim] = (domain[dim][0] + grid.step[dim], domain[dim][1])
661661
return Grid(shape, domain=tuple(domain))
662662

663663
def _interior_array(self) -> torch.Tensor:
664664
"""Returns only the interior points of self.data."""
665665
data = self.data
666-
for axis in range(self.grid.ndim):
666+
for dim in range(self.grid.ndim):
667667
# nothing happens in periodic case
668-
if self.bc.types[axis][1] == "periodic":
668+
if self.bc.types[dim][1] == "periodic":
669669
continue
670670
# nothing happens if the offset is not 0.0 or 1.0
671-
if math.isclose(self.offset[axis], 1.0):
672-
data, _ = tensor_utils.split_along_axis(data, -1, axis)
673-
elif math.isclose(self.offset[axis], 0.0):
674-
_, data = tensor_utils.split_along_axis(data, 1, axis)
671+
if math.isclose(self.offset[dim], 1.0):
672+
data, _ = tensor_utils.split_along_axis(data, -1, dim)
673+
elif math.isclose(self.offset[dim], 0.0):
674+
_, data = tensor_utils.split_along_axis(data, 1, dim)
675675

676676
return data
677677

@@ -689,7 +689,10 @@ def interior(self) -> GridVariable:
689689
690690
In case of dirichlet with edge offset, the grid and array size is reduced,
691691
since one scalar lies exactly on the boundary. In all other cases,
692-
self.grid and self.array are returned.
692+
self.grid and self.data are returned.
693+
694+
Porting notes:
695+
This method actually does not check whether the boundary conditions are imposed or not, is purely determined by the offset.
693696
"""
694697
interior_array = self._interior_array()
695698
interior_grid = self._interior_grid()
@@ -703,22 +706,22 @@ def enforce_edge_bc(self, *args) -> GridVariable:
703706
enforce_edge_bc() changes these boundary values to match the prescribed BC.
704707
705708
Args:
706-
*args: any optional values passed into BoundaryConditions values method.
709+
*args: any optional values passed into BoundaryConditions values method.
707710
"""
708711
if self.grid.shape != self.data.shape:
709712
raise ValueError("Stored array and grid have mismatched sizes.")
710-
data = torch.as_tensor(self.data)
711-
for axis in range(self.grid.ndim):
712-
if "periodic" not in self.bc.types[axis]:
713-
values = self.bc.values(axis, self.grid, *args)
713+
data = torch.as_tensor(self.data).clone() # Clone to avoid modifying original
714+
for dim in range(self.grid.ndim):
715+
if "periodic" not in self.bc.types[dim]:
716+
values = self.bc.values(dim, self.grid, *args)
714717
for boundary_side in range(2):
715-
if torch.isclose(self.offset[axis], boundary_side):
718+
if math.isclose(self.offset[dim], boundary_side):
716719
# boundary data is set to match self.bc:
717720
all_slice = [
718721
slice(None, None, None),
719722
] * self.grid.ndim
720-
all_slice[axis] = -boundary_side
721-
data = data.at[tuple(all_slice)].set(values[boundary_side])
723+
all_slice[dim] = -boundary_side
724+
data[tuple(all_slice)] = values[boundary_side]
722725
return GridVariable(data, self.offset, self.grid, self.bc)
723726

724727

@@ -945,15 +948,15 @@ def pad(
945948
else:
946949
raise ValueError(
947950
"expected the new offset to be an edge or cell center, got "
948-
f"offset[axis]={u.offset[dim]}"
951+
f"offset[dim]={u.offset[dim]}"
949952
)
950953
elif bc_type == BCType.NEUMANN:
951954
if not (
952955
math.isclose(u.offset[dim] % 1, 0) or math.isclose(u.offset[dim] % 1, 0.5)
953956
):
954957
raise ValueError(
955958
"expected offset to be an edge or cell center, got "
956-
f"offset[axis]={u.offset[dim]}"
959+
f"offset[dim]={u.offset[dim]}"
957960
)
958961
else:
959962
# When the data is cell-centered, computes the backward difference.
@@ -1129,61 +1132,6 @@ def _constant_pad_tensor(
11291132
return result
11301133

11311134

1132-
# def _constant_pad(
1133-
# inputs: torch.Tensor,
1134-
# pad: Tuple[Tuple[int, int], ...],
1135-
# constant_values: Tuple[Tuple[float, float], ...],
1136-
# **kwargs,
1137-
# ) -> torch.Tensor:
1138-
# """
1139-
# Corrected padding function that supports different constant values for each side.
1140-
# Pads each dimension from first to last as per the user input, mapping correctly to
1141-
# PyTorch's last-to-first padding order.
1142-
# inputs was unsqueezed at dim 0 as a batch_dim, so actual data dims are shifted by +1
1143-
# """
1144-
# ndim = inputs.ndim - 1 #
1145-
# original_shape = list(inputs.shape)
1146-
# out_shape = [1] + [original_shape[i+1] + pad[i][0] + pad[i][1] for i in range(ndim)] # inputs was unsqueezed at dim 0, so actual data dims are shifted by +1
1147-
1148-
# output = torch.empty(out_shape, dtype=inputs.dtype, device=inputs.device)
1149-
1150-
# def get_vals(dim):
1151-
# if len(constant_values) > dim:
1152-
# vals = constant_values[dim]
1153-
# if isinstance(vals, (tuple, list)) and len(vals) == 2:
1154-
# return float(vals[0]), float(vals[1])
1155-
# else:
1156-
# val = float(vals)
1157-
# return val, val
1158-
# return 0.0, 0.0
1159-
1160-
# # Fill with zeros initially
1161-
# output.fill_(0.0)
1162-
1163-
# # Main region
1164-
# slices = (slice(None),) + tuple(slice(pad[i][0], pad[i][0] + original_shape[i+1]) for i in range(ndim))
1165-
# output[slices] = inputs
1166-
1167-
# # Apply left/right pad values per dim
1168-
# for i in range(ndim):
1169-
# lpad, rpad = pad[i]
1170-
# lval, rval = get_vals(i)
1171-
1172-
# if lpad > 0:
1173-
# left_slices = [slice(None)] * ndim
1174-
# left_slices[i] = slice(0, lpad)
1175-
# left_slides = (slice(None), ) + tuple(left_slices)
1176-
# output[left_slides] = lval
1177-
1178-
# if rpad > 0:
1179-
# right_slices = [slice(None)] * ndim
1180-
# right_slices[i] = slice(-rpad, None)
1181-
# right_slices = (slice(None), ) + tuple(right_slices)
1182-
# output[right_slices] = rval
1183-
1184-
# return output
1185-
1186-
11871135
def averaged_offset(*offsets: List[Tuple[float, ...]]) -> Tuple[float, ...]:
11881136
"""Returns the averaged offset of the given arrays."""
11891137
n = len(offsets)

0 commit comments

Comments
 (0)