3131GridVariable = grids .GridVariable
3232GridVariableVector = grids .GridVariableVector
3333FluxInterpFn = Callable [[GridVariable , GridVariableVector , float ], GridVariable ]
34+ BoundaryConditions = boundaries .BoundaryConditions
3435
3536
3637def default (value , d ):
38+ """Returns `value` if it is not None, otherwise returns `d` which is the default value."""
3739 return d if value is None else value
3840
3941
@@ -232,8 +234,8 @@ class AdvectAligned(nn.Module):
232234 def __init__ (
233235 self ,
234236 grid : Grid ,
235- bcs_c : Tuple [boundaries . BoundaryConditions , ...],
236- bcs_v : Tuple [boundaries . BoundaryConditions , ...],
237+ bcs_c : Tuple [BoundaryConditions , ...],
238+ bcs_v : Tuple [BoundaryConditions , ...],
237239 offsets : Tuple [Tuple [float , ...], ...] = ((1.0 , 0.5 ), (0.5 , 1.0 )),
238240 ** kwargs ,
239241 ):
@@ -318,7 +320,7 @@ def __init__(
318320 self ,
319321 grid : Grid ,
320322 target_offset : Tuple [float , ...] = (0.5 , 0.5 ),
321- bc : Optional [boundaries . BoundaryConditions ] = None ,
323+ bc : Optional [BoundaryConditions ] = None ,
322324 ** kwargs ,
323325 ):
324326 super ().__init__ ()
@@ -360,22 +362,17 @@ def __init__(
360362 self ,
361363 grid : Grid ,
362364 target_offset : Tuple [float , ...],
363- low_interp : FluxInterpFn = None ,
364- high_interp : FluxInterpFn = None ,
365+ low_interp : Optional [ FluxInterpFn ] = None ,
366+ high_interp : Optional [ FluxInterpFn ] = None ,
365367 limiter : Callable = van_leer_limiter ,
366368 ):
367369 super ().__init__ ()
368370 self .grid = grid
369- self .low_interp = (
370- Upwind (grid , target_offset = target_offset )
371- if low_interp is None
372- else low_interp
373- )
374- self .high_interp = (
375- LaxWendroff (grid , target_offset = target_offset )
376- if high_interp is None
377- else high_interp
371+ self .low_interp = default (low_interp , Upwind (grid , target_offset = target_offset ))
372+ self .high_interp = default (
373+ high_interp , LaxWendroff (grid , target_offset = target_offset )
378374 )
375+
379376 self .limiter = limiter
380377 self .target_offset = target_offset
381378
@@ -394,28 +391,28 @@ def forward(
394391 Returns:
395392 Interpolated scalar field c to a target offset using Van Leer flux limiting, which uses a combination of high and low order methods to produce monotonic interpolation method.
396393 """
397- for axis , axis_offset in enumerate (self .target_offset ):
394+ for dim , offset in enumerate (self .target_offset ):
398395 interpolation_offset = tuple (
399396 [
400- c_offset if i != axis else axis_offset
397+ c_offset if i != dim else offset
401398 for i , c_offset in enumerate (c .offset )
402399 ]
403400 )
404401 if interpolation_offset != c .offset :
405- if interpolation_offset [axis ] - c .offset [axis ] != 0.5 :
402+ if interpolation_offset [dim ] - c .offset [dim ] != 0.5 :
406403 raise NotImplementedError (
407404 "Only forward interpolation to control volume faces is supported."
408405 )
409406 c_low = self .low_interp (c , v , dt )
410407 c_high = self .high_interp (c , v , dt )
411- c_left = c .shift (- 1 , axis ).data
412- c_right = c .shift (1 , axis ).data
413- c_next_right = c .shift (2 , axis ).data
408+ c_left = c .shift (- 1 , dim ).data
409+ c_right = c .shift (1 , dim ).data
410+ c_next_right = c .shift (2 , dim ).data
414411 pos_r = safe_div (c - c_left , c_right - c )
415412 neg_r = safe_div (c_next_right - c_right , c_right - c )
416413 pos_phi = self .limiter (pos_r ).data
417414 neg_phi = self .limiter (neg_r ).data
418- u = v [axis ]
415+ u = v [dim ]
419416 phi = torch .where (u > 0 , pos_phi , neg_phi )
420417 interpolated = c_low - (c_low - c_high ) * phi
421418 c = GridVariable (interpolated .data , interpolation_offset , c .grid , c .bc )
@@ -455,8 +452,8 @@ def __init__(
455452 self ,
456453 grid : Grid ,
457454 offset : Tuple [float , ...],
458- bc_c : boundaries . BoundaryConditions ,
459- bc_v : Tuple [boundaries . BoundaryConditions , ...],
455+ bc_c : Optional [ BoundaryConditions ] = None ,
456+ bc_v : Optional [ Tuple [BoundaryConditions , ...]] = None ,
460457 limiter : Optional [Callable ] = None ,
461458 ) -> None :
462459 super ().__init__ ()
@@ -472,8 +469,12 @@ def __init__(
472469 for _ in range (grid .ndim )
473470 ),
474471 )
472+ self .bc_v = bc_v
475473 self .advect_aligned = AdvectAligned (
476- grid = grid , bcs_c = (bc_c , bc_c ), bcs_v = bc_v , offsets = self .target_offsets
474+ grid = grid ,
475+ bcs_c = (bc_c ,) * grid .ndim ,
476+ bcs_v = bc_v ,
477+ offsets = self .target_offsets ,
477478 )
478479 self ._flux_interp = nn .ModuleList () # placeholder
479480 self ._velocity_interp = nn .ModuleList () # placeholder
@@ -538,13 +539,8 @@ def __init__(
538539 self ,
539540 grid : Grid ,
540541 offset = (0.5 , 0.5 ),
541- bc_c : boundaries .BoundaryConditions = boundaries .periodic_boundary_conditions (
542- ndim = 2
543- ),
544- bc_v : Tuple [boundaries .BoundaryConditions , ...] = (
545- boundaries .periodic_boundary_conditions (ndim = 2 ),
546- boundaries .periodic_boundary_conditions (ndim = 2 ),
547- ),
542+ bc_c : Optional [BoundaryConditions ] = None ,
543+ bc_v : Optional [Tuple [BoundaryConditions , ...]] = None ,
548544 ** kwargs ,
549545 ):
550546 super ().__init__ (grid , offset , bc_c , bc_v )
@@ -578,13 +574,8 @@ def __init__(
578574 self ,
579575 grid : Grid ,
580576 offset : Tuple [float , ...] = (0.5 , 0.5 ),
581- bc_c : boundaries .BoundaryConditions = boundaries .periodic_boundary_conditions (
582- ndim = 2
583- ),
584- bc_v : Tuple [boundaries .BoundaryConditions , ...] = (
585- boundaries .periodic_boundary_conditions (ndim = 2 ),
586- boundaries .periodic_boundary_conditions (ndim = 2 ),
587- ),
577+ bc_c : Optional [BoundaryConditions ] = None ,
578+ bc_v : Optional [Tuple [BoundaryConditions , ...]] = None ,
588579 ** kwargs ,
589580 ):
590581 super ().__init__ (grid , offset , bc_c , bc_v )
@@ -617,13 +608,8 @@ def __init__(
617608 self ,
618609 grid : Grid ,
619610 offset : Tuple [float , ...] = (0.5 , 0.5 ),
620- bc_c : boundaries .BoundaryConditions = boundaries .periodic_boundary_conditions (
621- ndim = 2
622- ),
623- bc_v : Tuple [boundaries .BoundaryConditions , ...] = (
624- boundaries .periodic_boundary_conditions (ndim = 2 ),
625- boundaries .periodic_boundary_conditions (ndim = 2 ),
626- ),
611+ bc_c : Optional [BoundaryConditions ] = None ,
612+ bc_v : Optional [Tuple [BoundaryConditions , ...]] = None ,
627613 limiter : Callable = van_leer_limiter ,
628614 ** kwargs ,
629615 ):
@@ -639,7 +625,7 @@ def __init__(
639625
640626 self ._velocity_interp = nn .ModuleList (
641627 LinearInterpolation (grid , target_offset = offset , bc = bc )
642- for offset , bc in zip (self .target_offsets , bc_v )
628+ for offset , bc in zip (self .target_offsets , self . bc_v )
643629 )
644630
645631
@@ -672,16 +658,21 @@ def __init__(
672658 self ,
673659 grid : Grid ,
674660 offsets : Tuple [Tuple [float , ...], ...] = ((1.0 , 0.5 ), (0.5 , 1.0 )),
675- bcs : Tuple [boundaries .BoundaryConditions , ...] = (
676- boundaries .periodic_boundary_conditions (ndim = 2 ),
677- boundaries .periodic_boundary_conditions (ndim = 2 ),
678- ),
661+ bcs : Optional [Tuple [BoundaryConditions , ...]] = None ,
679662 advect : type [nn .Module ] = AdvectionVanLeer ,
680663 limiter : Callable = van_leer_limiter ,
681664 ** kwargs ,
682665 ):
683666 super ().__init__ ()
684-
667+ self .grid = grid
668+ self .offsets = offsets
669+ bcs = default (
670+ bcs ,
671+ tuple (
672+ boundaries .periodic_boundary_conditions (ndim = grid .ndim )
673+ for _ in range (grid .ndim )
674+ ),
675+ )
685676 self .advect = nn .ModuleList (
686677 advect (
687678 grid = grid ,
0 commit comments