@@ -646,32 +646,32 @@ def _interior_grid(self) -> Grid:
646646 grid = self .grid
647647 domain = list (grid .domain )
648648 shape = list (grid .shape )
649- for axis in range (self .grid .ndim ):
649+ for dim in range (self .grid .ndim ):
650650 # nothing happens in periodic case
651- if self .bc .types [axis ][1 ] == "periodic" :
651+ if self .bc .types [dim ][1 ] == "periodic" :
652652 continue
653653 # nothing happens if the offset is not 0.0 or 1.0
654654 # this will automatically set the grid to interior.
655- if math .isclose (self .offset [axis ], 1.0 ):
656- shape [axis ] -= 1
657- domain [axis ] = (domain [axis ][0 ], domain [axis ][1 ] - grid .step [axis ])
658- elif math .isclose (self .offset [axis ], 0.0 ):
659- shape [axis ] -= 1
660- domain [axis ] = (domain [axis ][0 ] + grid .step [axis ], domain [axis ][1 ])
655+ if math .isclose (self .offset [dim ], 1.0 ):
656+ shape [dim ] -= 1
657+ domain [dim ] = (domain [dim ][0 ], domain [dim ][1 ] - grid .step [dim ])
658+ elif math .isclose (self .offset [dim ], 0.0 ):
659+ shape [dim ] -= 1
660+ domain [dim ] = (domain [dim ][0 ] + grid .step [dim ], domain [dim ][1 ])
661661 return Grid (shape , domain = tuple (domain ))
662662
663663 def _interior_array (self ) -> torch .Tensor :
664664 """Returns only the interior points of self.data."""
665665 data = self .data
666- for axis in range (self .grid .ndim ):
666+ for dim in range (self .grid .ndim ):
667667 # nothing happens in periodic case
668- if self .bc .types [axis ][1 ] == "periodic" :
668+ if self .bc .types [dim ][1 ] == "periodic" :
669669 continue
670670 # nothing happens if the offset is not 0.0 or 1.0
671- if math .isclose (self .offset [axis ], 1.0 ):
672- data , _ = tensor_utils .split_along_axis (data , - 1 , axis )
673- elif math .isclose (self .offset [axis ], 0.0 ):
674- _ , data = tensor_utils .split_along_axis (data , 1 , axis )
671+ if math .isclose (self .offset [dim ], 1.0 ):
672+ data , _ = tensor_utils .split_along_axis (data , - 1 , dim )
673+ elif math .isclose (self .offset [dim ], 0.0 ):
674+ _ , data = tensor_utils .split_along_axis (data , 1 , dim )
675675
676676 return data
677677
@@ -689,7 +689,10 @@ def interior(self) -> GridVariable:
689689
690690 In case of dirichlet with edge offset, the grid and array size is reduced,
691691 since one scalar lies exactly on the boundary. In all other cases,
692- self.grid and self.array are returned.
692+ self.grid and self.data are returned.
693+
694+ Porting notes:
695+ This method actually does not check whether the boundary conditions are imposed or not, is purely determined by the offset.
693696 """
694697 interior_array = self ._interior_array ()
695698 interior_grid = self ._interior_grid ()
@@ -703,22 +706,22 @@ def enforce_edge_bc(self, *args) -> GridVariable:
703706 enforce_edge_bc() changes these boundary values to match the prescribed BC.
704707
705708 Args:
706- *args: any optional values passed into BoundaryConditions values method.
709+ *args: any optional values passed into BoundaryConditions values method.
707710 """
708711 if self .grid .shape != self .data .shape :
709712 raise ValueError ("Stored array and grid have mismatched sizes." )
710- data = torch .as_tensor (self .data )
711- for axis in range (self .grid .ndim ):
712- if "periodic" not in self .bc .types [axis ]:
713- values = self .bc .values (axis , self .grid , * args )
713+ data = torch .as_tensor (self .data ). clone () # Clone to avoid modifying original
714+ for dim in range (self .grid .ndim ):
715+ if "periodic" not in self .bc .types [dim ]:
716+ values = self .bc .values (dim , self .grid , * args )
714717 for boundary_side in range (2 ):
715- if torch .isclose (self .offset [axis ], boundary_side ):
718+ if math .isclose (self .offset [dim ], boundary_side ):
716719 # boundary data is set to match self.bc:
717720 all_slice = [
718721 slice (None , None , None ),
719722 ] * self .grid .ndim
720- all_slice [axis ] = - boundary_side
721- data = data . at [tuple (all_slice )]. set ( values [boundary_side ])
723+ all_slice [dim ] = - boundary_side
724+ data [tuple (all_slice )] = values [boundary_side ]
722725 return GridVariable (data , self .offset , self .grid , self .bc )
723726
724727
@@ -945,15 +948,15 @@ def pad(
945948 else :
946949 raise ValueError (
947950 "expected the new offset to be an edge or cell center, got "
948- f"offset[axis ]={ u .offset [dim ]} "
951+ f"offset[dim ]={ u .offset [dim ]} "
949952 )
950953 elif bc_type == BCType .NEUMANN :
951954 if not (
952955 math .isclose (u .offset [dim ] % 1 , 0 ) or math .isclose (u .offset [dim ] % 1 , 0.5 )
953956 ):
954957 raise ValueError (
955958 "expected offset to be an edge or cell center, got "
956- f"offset[axis ]={ u .offset [dim ]} "
959+ f"offset[dim ]={ u .offset [dim ]} "
957960 )
958961 else :
959962 # When the data is cell-centered, computes the backward difference.
@@ -1129,61 +1132,6 @@ def _constant_pad_tensor(
11291132 return result
11301133
11311134
1132- # def _constant_pad(
1133- # inputs: torch.Tensor,
1134- # pad: Tuple[Tuple[int, int], ...],
1135- # constant_values: Tuple[Tuple[float, float], ...],
1136- # **kwargs,
1137- # ) -> torch.Tensor:
1138- # """
1139- # Corrected padding function that supports different constant values for each side.
1140- # Pads each dimension from first to last as per the user input, mapping correctly to
1141- # PyTorch's last-to-first padding order.
1142- # inputs was unsqueezed at dim 0 as a batch_dim, so actual data dims are shifted by +1
1143- # """
1144- # ndim = inputs.ndim - 1 #
1145- # original_shape = list(inputs.shape)
1146- # out_shape = [1] + [original_shape[i+1] + pad[i][0] + pad[i][1] for i in range(ndim)] # inputs was unsqueezed at dim 0, so actual data dims are shifted by +1
1147-
1148- # output = torch.empty(out_shape, dtype=inputs.dtype, device=inputs.device)
1149-
1150- # def get_vals(dim):
1151- # if len(constant_values) > dim:
1152- # vals = constant_values[dim]
1153- # if isinstance(vals, (tuple, list)) and len(vals) == 2:
1154- # return float(vals[0]), float(vals[1])
1155- # else:
1156- # val = float(vals)
1157- # return val, val
1158- # return 0.0, 0.0
1159-
1160- # # Fill with zeros initially
1161- # output.fill_(0.0)
1162-
1163- # # Main region
1164- # slices = (slice(None),) + tuple(slice(pad[i][0], pad[i][0] + original_shape[i+1]) for i in range(ndim))
1165- # output[slices] = inputs
1166-
1167- # # Apply left/right pad values per dim
1168- # for i in range(ndim):
1169- # lpad, rpad = pad[i]
1170- # lval, rval = get_vals(i)
1171-
1172- # if lpad > 0:
1173- # left_slices = [slice(None)] * ndim
1174- # left_slices[i] = slice(0, lpad)
1175- # left_slides = (slice(None), ) + tuple(left_slices)
1176- # output[left_slides] = lval
1177-
1178- # if rpad > 0:
1179- # right_slices = [slice(None)] * ndim
1180- # right_slices[i] = slice(-rpad, None)
1181- # right_slices = (slice(None), ) + tuple(right_slices)
1182- # output[right_slices] = rval
1183-
1184- # return output
1185-
1186-
11871135def averaged_offset (* offsets : List [Tuple [float , ...]]) -> Tuple [float , ...]:
11881136 """Returns the averaged offset of the given arrays."""
11891137 n = len (offsets )
0 commit comments