Skip to content

Commit 056ff6d

Browse files
committed
fixed advection init error for 1d
1 parent 8c88e1f commit 056ff6d

File tree

1 file changed

+14
-10
lines changed

1 file changed

+14
-10
lines changed

torch_cfd/advection.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -234,8 +234,8 @@ class AdvectAligned(nn.Module):
234234
def __init__(
235235
self,
236236
grid: Grid,
237-
bcs_c: Tuple[boundaries.BoundaryConditions, ...],
238-
bcs_v: Tuple[boundaries.BoundaryConditions, ...],
237+
bcs_c: Tuple[BoundaryConditions, ...],
238+
bcs_v: Tuple[BoundaryConditions, ...],
239239
offsets: Tuple[Tuple[float, ...], ...] = ((1.0, 0.5), (0.5, 1.0)),
240240
**kwargs,
241241
):
@@ -320,7 +320,7 @@ def __init__(
320320
self,
321321
grid: Grid,
322322
target_offset: Tuple[float, ...] = (0.5, 0.5),
323-
bc: Optional[boundaries.BoundaryConditions] = None,
323+
bc: Optional[BoundaryConditions] = None,
324324
**kwargs,
325325
):
326326
super().__init__()
@@ -453,7 +453,7 @@ def __init__(
453453
grid: Grid,
454454
offset: Tuple[float, ...],
455455
bc_c: Optional[BoundaryConditions] = None,
456-
bc_v: Optional[Tuple[boundaries.BoundaryConditions, ...]] = None,
456+
bc_v: Optional[Tuple[BoundaryConditions, ...]] = None,
457457
limiter: Optional[Callable] = None,
458458
) -> None:
459459
super().__init__()
@@ -469,8 +469,12 @@ def __init__(
469469
for _ in range(grid.ndim)
470470
),
471471
)
472+
self.bc_v = bc_v
472473
self.advect_aligned = AdvectAligned(
473-
grid=grid, bcs_c=(bc_c, bc_c), bcs_v=bc_v, offsets=self.target_offsets
474+
grid=grid,
475+
bcs_c=(bc_c,) * grid.ndim,
476+
bcs_v=bc_v,
477+
offsets=self.target_offsets,
474478
)
475479
self._flux_interp = nn.ModuleList() # placeholder
476480
self._velocity_interp = nn.ModuleList() # placeholder
@@ -536,7 +540,7 @@ def __init__(
536540
grid: Grid,
537541
offset=(0.5, 0.5),
538542
bc_c: Optional[BoundaryConditions] = None,
539-
bc_v: Optional[Tuple[boundaries.BoundaryConditions, ...]] = None,
543+
bc_v: Optional[Tuple[BoundaryConditions, ...]] = None,
540544
**kwargs,
541545
):
542546
super().__init__(grid, offset, bc_c, bc_v)
@@ -571,7 +575,7 @@ def __init__(
571575
grid: Grid,
572576
offset: Tuple[float, ...] = (0.5, 0.5),
573577
bc_c: Optional[BoundaryConditions] = None,
574-
bc_v: Optional[Tuple[boundaries.BoundaryConditions, ...]] = None,
578+
bc_v: Optional[Tuple[BoundaryConditions, ...]] = None,
575579
**kwargs,
576580
):
577581
super().__init__(grid, offset, bc_c, bc_v)
@@ -605,7 +609,7 @@ def __init__(
605609
grid: Grid,
606610
offset: Tuple[float, ...] = (0.5, 0.5),
607611
bc_c: Optional[BoundaryConditions] = None,
608-
bc_v: Optional[Tuple[boundaries.BoundaryConditions, ...]] = None,
612+
bc_v: Optional[Tuple[BoundaryConditions, ...]] = None,
609613
limiter: Callable = van_leer_limiter,
610614
**kwargs,
611615
):
@@ -621,7 +625,7 @@ def __init__(
621625

622626
self._velocity_interp = nn.ModuleList(
623627
LinearInterpolation(grid, target_offset=offset, bc=bc)
624-
for offset, bc in zip(self.target_offsets, bc_v)
628+
for offset, bc in zip(self.target_offsets, self.bc_v)
625629
)
626630

627631

@@ -654,7 +658,7 @@ def __init__(
654658
self,
655659
grid: Grid,
656660
offsets: Tuple[Tuple[float, ...], ...] = ((1.0, 0.5), (0.5, 1.0)),
657-
bcs: Optional[Tuple[boundaries.BoundaryConditions, ...]] = None,
661+
bcs: Optional[Tuple[BoundaryConditions, ...]] = None,
658662
advect: type[nn.Module] = AdvectionVanLeer,
659663
limiter: Callable = van_leer_limiter,
660664
**kwargs,

0 commit comments

Comments
 (0)