@@ -234,8 +234,8 @@ class AdvectAligned(nn.Module):
234234 def __init__ (
235235 self ,
236236 grid : Grid ,
237- bcs_c : Tuple [boundaries . BoundaryConditions , ...],
238- bcs_v : Tuple [boundaries . BoundaryConditions , ...],
237+ bcs_c : Tuple [BoundaryConditions , ...],
238+ bcs_v : Tuple [BoundaryConditions , ...],
239239 offsets : Tuple [Tuple [float , ...], ...] = ((1.0 , 0.5 ), (0.5 , 1.0 )),
240240 ** kwargs ,
241241 ):
@@ -320,7 +320,7 @@ def __init__(
320320 self ,
321321 grid : Grid ,
322322 target_offset : Tuple [float , ...] = (0.5 , 0.5 ),
323- bc : Optional [boundaries . BoundaryConditions ] = None ,
323+ bc : Optional [BoundaryConditions ] = None ,
324324 ** kwargs ,
325325 ):
326326 super ().__init__ ()
@@ -453,7 +453,7 @@ def __init__(
453453 grid : Grid ,
454454 offset : Tuple [float , ...],
455455 bc_c : Optional [BoundaryConditions ] = None ,
456- bc_v : Optional [Tuple [boundaries . BoundaryConditions , ...]] = None ,
456+ bc_v : Optional [Tuple [BoundaryConditions , ...]] = None ,
457457 limiter : Optional [Callable ] = None ,
458458 ) -> None :
459459 super ().__init__ ()
@@ -469,8 +469,12 @@ def __init__(
469469 for _ in range (grid .ndim )
470470 ),
471471 )
472+ self .bc_v = bc_v
472473 self .advect_aligned = AdvectAligned (
473- grid = grid , bcs_c = (bc_c , bc_c ), bcs_v = bc_v , offsets = self .target_offsets
474+ grid = grid ,
475+ bcs_c = (bc_c ,) * grid .ndim ,
476+ bcs_v = bc_v ,
477+ offsets = self .target_offsets ,
474478 )
475479 self ._flux_interp = nn .ModuleList () # placeholder
476480 self ._velocity_interp = nn .ModuleList () # placeholder
@@ -536,7 +540,7 @@ def __init__(
536540 grid : Grid ,
537541 offset = (0.5 , 0.5 ),
538542 bc_c : Optional [BoundaryConditions ] = None ,
539- bc_v : Optional [Tuple [boundaries . BoundaryConditions , ...]] = None ,
543+ bc_v : Optional [Tuple [BoundaryConditions , ...]] = None ,
540544 ** kwargs ,
541545 ):
542546 super ().__init__ (grid , offset , bc_c , bc_v )
@@ -571,7 +575,7 @@ def __init__(
571575 grid : Grid ,
572576 offset : Tuple [float , ...] = (0.5 , 0.5 ),
573577 bc_c : Optional [BoundaryConditions ] = None ,
574- bc_v : Optional [Tuple [boundaries . BoundaryConditions , ...]] = None ,
578+ bc_v : Optional [Tuple [BoundaryConditions , ...]] = None ,
575579 ** kwargs ,
576580 ):
577581 super ().__init__ (grid , offset , bc_c , bc_v )
@@ -605,7 +609,7 @@ def __init__(
605609 grid : Grid ,
606610 offset : Tuple [float , ...] = (0.5 , 0.5 ),
607611 bc_c : Optional [BoundaryConditions ] = None ,
608- bc_v : Optional [Tuple [boundaries . BoundaryConditions , ...]] = None ,
612+ bc_v : Optional [Tuple [BoundaryConditions , ...]] = None ,
609613 limiter : Callable = van_leer_limiter ,
610614 ** kwargs ,
611615 ):
@@ -621,7 +625,7 @@ def __init__(
621625
622626 self ._velocity_interp = nn .ModuleList (
623627 LinearInterpolation (grid , target_offset = offset , bc = bc )
624- for offset , bc in zip (self .target_offsets , bc_v )
628+ for offset , bc in zip (self .target_offsets , self . bc_v )
625629 )
626630
627631
@@ -654,7 +658,7 @@ def __init__(
654658 self ,
655659 grid : Grid ,
656660 offsets : Tuple [Tuple [float , ...], ...] = ((1.0 , 0.5 ), (0.5 , 1.0 )),
657- bcs : Optional [Tuple [boundaries . BoundaryConditions , ...]] = None ,
661+ bcs : Optional [Tuple [BoundaryConditions , ...]] = None ,
658662 advect : type [nn .Module ] = AdvectionVanLeer ,
659663 limiter : Callable = van_leer_limiter ,
660664 ** kwargs ,
0 commit comments