1818
1919import dataclasses
2020import math
21- from typing import Optional , Sequence , Tuple
21+ from typing import Optional , Sequence , Tuple , Union
2222
2323import torch
2424
3131
3232
3333BCType = grids .BCType ()
34+ Padding = grids .Padding ()
3435
3536
36- class Padding :
37- MIRROR = "reflect"
38- EXTEND = "replicate"
39-
40-
41- @dataclasses .dataclass (init = False , frozen = True )
42- class ConstantBoundaryConditions (grids .BoundaryConditions ):
37+ @dataclasses .dataclass (init = False , frozen = True , repr = False )
38+ class ConstantBoundaryConditions (BoundaryConditions ):
4339 """Boundary conditions for a PDE variable that are constant in space and time.
4440
4541 Attributes:
@@ -59,8 +55,8 @@ class ConstantBoundaryConditions(grids.BoundaryConditions):
5955
6056 """
6157
62- types : Tuple [Tuple [str , str ], ...]
63- _values : Tuple [Tuple [Optional [float ], Optional [float ]], ...]
58+ _types : Tuple [Tuple [str , str ], ...]
59+ bc_values : Tuple [Tuple [Optional [float ], Optional [float ]], ...]
6460
6561 def __init__ (
6662 self ,
@@ -69,10 +65,62 @@ def __init__(
6965 ):
7066 types = tuple (types )
7167 values = tuple (values )
72- object .__setattr__ (self , "types " , types )
73- object .__setattr__ (self , "_values " , values )
68+ object .__setattr__ (self , "_types " , types )
69+ object .__setattr__ (self , "bc_values " , values )
7470 object .__setattr__ (self , "ndim" , len (types ))
7571
72+ @property
73+ def types (self ) -> Tuple [Tuple [str , str ], ...]:
74+ """Returns the boundary condition types."""
75+ return self ._types
76+
77+ @types .setter
78+ def types (self , bc_types : Sequence [Tuple [str , str ]]) -> None :
79+ """Sets the boundary condition types and updates ndim accordingly."""
80+ bc_types = tuple (bc_types )
81+ assert self .ndim == len (
82+ bc_types
83+ ), f"Number of dimensions { self .ndim } does not match the number of types { bc_types } ."
84+ object .__setattr__ (self , "_types" , bc_types )
85+
86+ def __repr__ (self ) -> str :
87+ try :
88+ lines = [f"{ self .__class__ .__name__ } ({ self .ndim } D):" ]
89+
90+ for dim in range (self .ndim ):
91+ lower_type , upper_type = self .types [dim ]
92+ lower_val , upper_val = self .bc_values [dim ]
93+
94+ # Format values
95+ lower_val_str = "None" if lower_val is None else f"{ lower_val } "
96+ upper_val_str = "None" if upper_val is None else f"{ upper_val } "
97+
98+ lines .append (
99+ f" dim { dim } : [{ lower_type } ({ lower_val_str } ), { upper_type } ({ upper_val_str } )]"
100+ )
101+
102+ return "\n " .join (lines )
103+ except Exception as e :
104+ return f"{ self .__class__ .__name__ } not initialized: { e } "
105+
106+ def clone (
107+ self ,
108+ types : Optional [Sequence [Tuple [str , str ]]] = None ,
109+ values : Optional [Sequence [Tuple [Optional [float ], Optional [float ]]]] = None ,
110+ ) -> BoundaryConditions :
111+ """Creates a copy of this boundary condition, optionally with modified parameters.
112+
113+ Args:
114+ types: New boundary condition types. If None, uses current types.
115+ values: New boundary condition values. If None, uses current values.
116+
117+ Returns:
118+ A new ConstantBoundaryConditions instance.
119+ """
120+ new_types = types if types is not None else self .types
121+ new_values = values if values is not None else self .bc_values
122+ return ConstantBoundaryConditions (new_types , new_values )
123+
76124 def shift (
77125 self ,
78126 u : GridVariable ,
@@ -106,7 +154,9 @@ def _is_aligned(self, u: GridVariable, dim: int) -> bool:
106154 size_diff += 1
107155 if self .types [dim ][1 ] == BCType .DIRICHLET and math .isclose (u .offset [dim ], 1 ):
108156 size_diff += 1
109- if (self .types [dim ][0 ] == BCType .NEUMANN and math .isclose (u .offset [dim ], 1 )) or (self .types [dim ][1 ] == BCType .NEUMANN and math .isclose (u .offset [dim ], 0 )):
157+ if (
158+ self .types [dim ][0 ] == BCType .NEUMANN and math .isclose (u .offset [dim ], 1 )
159+ ) or (self .types [dim ][1 ] == BCType .NEUMANN and math .isclose (u .offset [dim ], 0 )):
110160 """
111161 if lower (or left for dim 0) is Neumann, and the offset is 1 (the variable is on the right edge of a cell), the Neumann bc is not defined; vice versa for upper (or right for dim 0) Neumann bc with offset 0 (the variable is on the left edge of a cell).
112162 """
@@ -117,6 +167,78 @@ def _is_aligned(self, u: GridVariable, dim: int) -> bool:
117167 )
118168 return True
119169
170+ def pad (
171+ self ,
172+ u : GridVariable ,
173+ width : Union [Tuple [int , int ], int ],
174+ dim : int ,
175+ mode : Optional [str ] = Padding .EXTEND ,
176+ ) -> GridVariable :
177+ """Wrapper for grids.pad with a specific bc.
178+
179+ Args:
180+ u: a `GridVariable` object.
181+ width: number of elements to pad along axis. If width is an int, use
182+ negative value for lower boundary or positive value for upper boundary.
183+ If a tuple, pads with width[0] on the left and width[1] on the right.
184+ dim: axis to pad along.
185+ mode: type of padding to use in non-periodic case.
186+ Mirror mirrors the array values across the boundary.
187+ Extend extends the last well-defined array value past the boundary.
188+
189+ Returns:
190+ Padded array, elongated along the indicated axis.
191+ the original u.bc will be replaced with self.
192+ """
193+ _ = self ._is_aligned (u , dim )
194+ if isinstance (width , tuple ) and (width [0 ] > 0 and width [1 ] > 0 ):
195+ need_trimming = "both"
196+ elif (isinstance (width , tuple ) and (width [0 ] > 0 and width [1 ] == 0 )) or (
197+ isinstance (width , int ) and width < 0
198+ ):
199+ need_trimming = "left"
200+ elif (isinstance (width , tuple ) and (width [0 ] == 0 and width [1 ] > 0 )) or (
201+ isinstance (width , int ) and width > 0
202+ ):
203+ need_trimming = "right"
204+ else :
205+ need_trimming = "none"
206+
207+ u , trimmed_padding = self ._trim_padding (u , dim , need_trimming )
208+
209+ if isinstance (width , int ):
210+ if width < 0 :
211+ width -= trimmed_padding [0 ]
212+ if width > 0 :
213+ width += trimmed_padding [1 ]
214+ elif isinstance (width , tuple ):
215+ width = (width [0 ] + trimmed_padding [0 ], width [1 ] + trimmed_padding [1 ])
216+
217+ u = grids .pad (u , width , dim , self , mode = mode )
218+ return u
219+
220+ def pad_all (
221+ self ,
222+ u : GridVariable ,
223+ width : Tuple [Tuple [int , int ], ...],
224+ mode : Optional [str ] = Padding .EXTEND ,
225+ ) -> GridVariable :
226+ """Pads along all axes with pad width specified by width tuple.
227+
228+ Args:
229+ u: a `GridArray` object.
230+ width: Tuple of padding width for each side for each axis.
231+ mode: type of padding to use in non-periodic case.
232+ Mirror mirrors the array values across the boundary.
233+ Extend extends the last well-defined array value past the boundary.
234+
235+ Returns:
236+ Padded array, elongated along all axes.
237+ """
238+ for dim in range (- u .grid .ndim , 0 ):
239+ u = self .pad (u , width [dim ], dim , mode = mode )
240+ return u
241+
120242 def values (
121243 self , dim : int , grid : Grid
122244 ) -> Tuple [Optional [torch .Tensor ], Optional [torch .Tensor ]]:
@@ -130,21 +252,22 @@ def values(
130252 A tuple of arrays of grid.ndim - 1 dimensions that specify values on the
131253 boundary. In case of periodic boundaries, returns a tuple(None,None).
132254 """
133- if None in self ._values [dim ]:
255+ if None in self .bc_values [dim ]:
134256 return (None , None )
135-
257+
136258 bc = []
137259 for i in [0 , 1 ]:
138- value = self ._values [dim ][- i ]
260+ value = self .bc_values [dim ][- i ]
139261 if value is None :
140262 bc .append (None )
141263 else :
142264 bc .append (torch .full (grid .shape [:dim ] + grid .shape [dim + 1 :], value ))
143-
265+
144266 return tuple (bc )
145-
146267
147- def _trim_padding (self , u : GridVariable , dim : int = - 1 , trim_side : str = "both" ):
268+ def _trim_padding (
269+ self , u : GridVariable , dim : int = - 1 , trim_side : str = "both"
270+ ) -> Tuple [GridVariable , Tuple [int , int ]]:
148271 """Trims padding from a GridVariable along axis and returns the array interior.
149272
150273 Args:
@@ -160,6 +283,9 @@ def _trim_padding(self, u: GridVariable, dim: int = -1, trim_side: str = "both")
160283 negative_trim = 0
161284 padding = (0 , 0 )
162285
286+ if trim_side not in ("both" , "left" , "right" ):
287+ return u .array , padding
288+
163289 if u .shape [dim ] >= u .grid .shape [dim ]:
164290 # number of cells that were padded on the left
165291 negative_trim = 0
@@ -201,7 +327,7 @@ def _trim_padding(self, u: GridVariable, dim: int = -1, trim_side: str = "both")
201327 def trim_boundary (self , u : GridVariable ) -> GridVariable :
202328 """Returns GridVariable without the grid points on the boundary.
203329
204- Some grid points of GridVariable might coincide with boundary. This trims those values.
330+ Some grid points of GridVariable might coincide with boundary. This trims those values.
205331 If the array was padded beforehand, removes the padding.
206332
207333 Args:
@@ -223,33 +349,29 @@ def pad_and_impose_bc(
223349 ) -> GridVariable :
224350 """Returns GridVariable with correct boundary values.
225351
226- Some grid points of GridVariable might coincide with boundary. This ensures
227- that the GridVariable.array agrees with GridVariable.bc.
352+ Some grid points of GridVariable might coincide with boundary, thus this function is only used with the trimmed GridVariable.
228353 Args:
229- u: a `GridVariable` object that specifies only scalar values on the internal
230- nodes.
231- offset_to_pad_to: a Tuple of desired offset to pad to. Note that if the
232- function is given just an interior array in dirichlet case, it can pad
233- to both 0 offset and 1 offset.
234- mode: type of padding to use in non-periodic case.
235- Mirror mirrors the flow across the boundary.
236- Extend extends the last well-defined value past the boundary.
354+ - u: a `GridVariable` object that specifies only scalar values on the internal nodes.
355+ - offset_to_pad_to: a Tuple of desired offset to pad to. Note that if the function is given just an interior array in dirichlet case, it can pad to both 0 offset and 1 offset.
356+ - mode: type of padding to use in non-periodic case. Mirror mirrors the flow across the boundary. Extend extends the last well-defined value past the boundary.
237357
238358 Returns:
239359 A GridVariable that has correct boundary values.
240360 """
361+ assert u .bc is None , "u must be trimmed before padding and imposing bc."
241362 if offset_to_pad_to is None :
242363 offset_to_pad_to = u .offset
243- for axis in range (- u .grid .ndim , 0 ):
244- _ = self ._is_aligned (u , axis )
245- if self .types [axis ][0 ] == BCType .DIRICHLET and math .isclose (
246- u .offset [axis ], 1.0
247- ):
248- if math .isclose (offset_to_pad_to [axis ], 1.0 ):
249- u = grids .pad (u , 1 , axis , self )
250- elif math .isclose (offset_to_pad_to [axis ], 0.0 ):
251- u = grids .pad (u , - 1 , axis , self )
252- return GridVariable (u .data , u .offset , u .grid , self )
364+ for dim in range (- u .grid .ndim , 0 ):
365+ _ = self ._is_aligned (u , dim )
366+ if self .types [dim ][0 ] != BCType .PERIODIC :
367+ # if the offset is either 0 or 1, u is aligned with the boundary and is defined on cell edges on one side of the boundary, if trim_boundary is called before this function.
368+ # u needs to be padded on both sides
369+ # if the offset is 0.5, one ghost cell is needed on each side.
370+ # it will be taken care by grids.pad function automatically.
371+ u = grids .pad (u , (1 , 1 ), dim , self )
372+ elif self .types [dim ][0 ] == BCType .PERIODIC :
373+ return GridVariable (u .data , u .offset , u .grid , self )
374+ return u
253375
254376 def impose_bc (self , u : GridVariable ) -> GridVariable :
255377 """Returns GridVariable with correct boundary condition.
@@ -260,8 +382,7 @@ def impose_bc(self, u: GridVariable) -> GridVariable:
260382 u: a `GridVariable` object.
261383
262384 Returns:
263- A GridVariable that has correct boundary values and is restricted to the
264- domain.
385+ A GridVariable that has correct boundary values with ghost cells added on the other side of DoFs living at cell center if the bc is Dirichlet or Neumann.
265386 """
266387 offset = u .offset
267388 u = self .trim_boundary (u )
@@ -317,7 +438,7 @@ def periodic_boundary_conditions(ndim: int) -> BoundaryConditions:
317438
318439def dirichlet_boundary_conditions (
319440 ndim : int ,
320- bc_vals : Optional [Sequence [Tuple [float , float ]]] = None ,
441+ bc_values : Optional [Sequence [Tuple [float , float ]]] = None ,
321442) -> BoundaryConditions :
322443 """Returns Dirichelt BCs for a variable with `ndim` spatial dimension.
323444
@@ -329,13 +450,13 @@ def dirichlet_boundary_conditions(
329450 Returns:
330451 BoundaryCondition subclass.
331452 """
332- if not bc_vals :
453+ if not bc_values :
333454 return HomogeneousBoundaryConditions (
334455 ((BCType .DIRICHLET , BCType .DIRICHLET ),) * ndim
335456 )
336457 else :
337458 return ConstantBoundaryConditions (
338- ((BCType .DIRICHLET , BCType .DIRICHLET ),) * ndim , bc_vals
459+ ((BCType .DIRICHLET , BCType .DIRICHLET ),) * ndim , bc_values
339460 )
340461
341462
@@ -556,7 +677,7 @@ def get_advection_flux_bc_from_velocity_and_scalar_bc(
556677 if isinstance (u_bc , HomogeneousBoundaryConditions ):
557678 u_values = tuple ((0.0 , 0.0 ) for _ in range (u_bc .ndim ))
558679 elif isinstance (u_bc , ConstantBoundaryConditions ):
559- u_values = u_bc ._values
680+ u_values = u_bc .bc_values
560681 else :
561682 raise NotImplementedError (
562683 f"Flux boundary condition is not implemented for velocity with { type (u_bc )} "
@@ -565,7 +686,7 @@ def get_advection_flux_bc_from_velocity_and_scalar_bc(
565686 if isinstance (c_bc , HomogeneousBoundaryConditions ):
566687 c_values = tuple ((0.0 , 0.0 ) for _ in range (c_bc .ndim ))
567688 elif isinstance (c_bc , ConstantBoundaryConditions ):
568- c_values = c_bc ._values
689+ c_values = c_bc .bc_values
569690 else :
570691 raise NotImplementedError (
571692 f"Flux boundary condition is not implemented for scalar with { type (c_bc )} "
0 commit comments