3131GridVariable = grids .GridVariable
3232GridVariableVector = grids .GridVariableVector
3333FluxInterpFn = Callable [[GridVariable , GridVariableVector , float ], GridVariable ]
34+ BoundaryConditions = boundaries .BoundaryConditions
3435
3536
3637def default (value , d ):
38+ """Returns `value` if it is not None, otherwise returns `d` which is the default value."""
3739 return d if value is None else value
3840
3941
@@ -360,22 +362,17 @@ def __init__(
360362 self ,
361363 grid : Grid ,
362364 target_offset : Tuple [float , ...],
363- low_interp : FluxInterpFn = None ,
364- high_interp : FluxInterpFn = None ,
365+ low_interp : Optional [ FluxInterpFn ] = None ,
366+ high_interp : Optional [ FluxInterpFn ] = None ,
365367 limiter : Callable = van_leer_limiter ,
366368 ):
367369 super ().__init__ ()
368370 self .grid = grid
369- self .low_interp = (
370- Upwind (grid , target_offset = target_offset )
371- if low_interp is None
372- else low_interp
373- )
374- self .high_interp = (
375- LaxWendroff (grid , target_offset = target_offset )
376- if high_interp is None
377- else high_interp
371+ self .low_interp = default (low_interp , Upwind (grid , target_offset = target_offset ))
372+ self .high_interp = default (
373+ high_interp , LaxWendroff (grid , target_offset = target_offset )
378374 )
375+
379376 self .limiter = limiter
380377 self .target_offset = target_offset
381378
@@ -394,28 +391,28 @@ def forward(
394391 Returns:
395392 Interpolated scalar field c to a target offset using Van Leer flux limiting, which uses a combination of high and low order methods to produce monotonic interpolation method.
396393 """
397- for axis , axis_offset in enumerate (self .target_offset ):
394+ for dim , offset in enumerate (self .target_offset ):
398395 interpolation_offset = tuple (
399396 [
400- c_offset if i != axis else axis_offset
397+ c_offset if i != dim else offset
401398 for i , c_offset in enumerate (c .offset )
402399 ]
403400 )
404401 if interpolation_offset != c .offset :
405- if interpolation_offset [axis ] - c .offset [axis ] != 0.5 :
402+ if interpolation_offset [dim ] - c .offset [dim ] != 0.5 :
406403 raise NotImplementedError (
407404 "Only forward interpolation to control volume faces is supported."
408405 )
409406 c_low = self .low_interp (c , v , dt )
410407 c_high = self .high_interp (c , v , dt )
411- c_left = c .shift (- 1 , axis ).data
412- c_right = c .shift (1 , axis ).data
413- c_next_right = c .shift (2 , axis ).data
408+ c_left = c .shift (- 1 , dim ).data
409+ c_right = c .shift (1 , dim ).data
410+ c_next_right = c .shift (2 , dim ).data
414411 pos_r = safe_div (c - c_left , c_right - c )
415412 neg_r = safe_div (c_next_right - c_right , c_right - c )
416413 pos_phi = self .limiter (pos_r ).data
417414 neg_phi = self .limiter (neg_r ).data
418- u = v [axis ]
415+ u = v [dim ]
419416 phi = torch .where (u > 0 , pos_phi , neg_phi )
420417 interpolated = c_low - (c_low - c_high ) * phi
421418 c = GridVariable (interpolated .data , interpolation_offset , c .grid , c .bc )
@@ -455,8 +452,8 @@ def __init__(
455452 self ,
456453 grid : Grid ,
457454 offset : Tuple [float , ...],
458- bc_c : boundaries . BoundaryConditions ,
459- bc_v : Tuple [boundaries .BoundaryConditions , ...],
455+ bc_c : Optional [ BoundaryConditions ] = None ,
456+ bc_v : Optional [ Tuple [boundaries .BoundaryConditions , ...]] = None ,
460457 limiter : Optional [Callable ] = None ,
461458 ) -> None :
462459 super ().__init__ ()
@@ -538,13 +535,8 @@ def __init__(
538535 self ,
539536 grid : Grid ,
540537 offset = (0.5 , 0.5 ),
541- bc_c : boundaries .BoundaryConditions = boundaries .periodic_boundary_conditions (
542- ndim = 2
543- ),
544- bc_v : Tuple [boundaries .BoundaryConditions , ...] = (
545- boundaries .periodic_boundary_conditions (ndim = 2 ),
546- boundaries .periodic_boundary_conditions (ndim = 2 ),
547- ),
538+ bc_c : Optional [BoundaryConditions ] = None ,
539+ bc_v : Optional [Tuple [boundaries .BoundaryConditions , ...]] = None ,
548540 ** kwargs ,
549541 ):
550542 super ().__init__ (grid , offset , bc_c , bc_v )
@@ -578,13 +570,8 @@ def __init__(
578570 self ,
579571 grid : Grid ,
580572 offset : Tuple [float , ...] = (0.5 , 0.5 ),
581- bc_c : boundaries .BoundaryConditions = boundaries .periodic_boundary_conditions (
582- ndim = 2
583- ),
584- bc_v : Tuple [boundaries .BoundaryConditions , ...] = (
585- boundaries .periodic_boundary_conditions (ndim = 2 ),
586- boundaries .periodic_boundary_conditions (ndim = 2 ),
587- ),
573+ bc_c : Optional [BoundaryConditions ] = None ,
574+ bc_v : Optional [Tuple [boundaries .BoundaryConditions , ...]] = None ,
588575 ** kwargs ,
589576 ):
590577 super ().__init__ (grid , offset , bc_c , bc_v )
@@ -617,13 +604,8 @@ def __init__(
617604 self ,
618605 grid : Grid ,
619606 offset : Tuple [float , ...] = (0.5 , 0.5 ),
620- bc_c : boundaries .BoundaryConditions = boundaries .periodic_boundary_conditions (
621- ndim = 2
622- ),
623- bc_v : Tuple [boundaries .BoundaryConditions , ...] = (
624- boundaries .periodic_boundary_conditions (ndim = 2 ),
625- boundaries .periodic_boundary_conditions (ndim = 2 ),
626- ),
607+ bc_c : Optional [BoundaryConditions ] = None ,
608+ bc_v : Optional [Tuple [boundaries .BoundaryConditions , ...]] = None ,
627609 limiter : Callable = van_leer_limiter ,
628610 ** kwargs ,
629611 ):
@@ -672,16 +654,21 @@ def __init__(
672654 self ,
673655 grid : Grid ,
674656 offsets : Tuple [Tuple [float , ...], ...] = ((1.0 , 0.5 ), (0.5 , 1.0 )),
675- bcs : Tuple [boundaries .BoundaryConditions , ...] = (
676- boundaries .periodic_boundary_conditions (ndim = 2 ),
677- boundaries .periodic_boundary_conditions (ndim = 2 ),
678- ),
657+ bcs : Optional [Tuple [boundaries .BoundaryConditions , ...]] = None ,
679658 advect : type [nn .Module ] = AdvectionVanLeer ,
680659 limiter : Callable = van_leer_limiter ,
681660 ** kwargs ,
682661 ):
683662 super ().__init__ ()
684-
663+ self .grid = grid
664+ self .offsets = offsets
665+ bcs = default (
666+ bcs ,
667+ tuple (
668+ boundaries .periodic_boundary_conditions (ndim = grid .ndim )
669+ for _ in range (grid .ndim )
670+ ),
671+ )
685672 self .advect = nn .ModuleList (
686673 advect (
687674 grid = grid ,
0 commit comments