Skip to content

Commit 8c88e1f

Browse files
committed
added an uniform forcing for channel flow
1 parent 282be19 commit 8c88e1f

File tree

3 files changed

+209
-82
lines changed

3 files changed

+209
-82
lines changed

torch_cfd/advection.py

Lines changed: 33 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,11 @@
3131
GridVariable = grids.GridVariable
3232
GridVariableVector = grids.GridVariableVector
3333
FluxInterpFn = Callable[[GridVariable, GridVariableVector, float], GridVariable]
34+
BoundaryConditions = boundaries.BoundaryConditions
3435

3536

3637
def default(value, d):
38+
"""Returns `value` if it is not None, otherwise returns `d` which is the default value."""
3739
return d if value is None else value
3840

3941

@@ -360,22 +362,17 @@ def __init__(
360362
self,
361363
grid: Grid,
362364
target_offset: Tuple[float, ...],
363-
low_interp: FluxInterpFn = None,
364-
high_interp: FluxInterpFn = None,
365+
low_interp: Optional[FluxInterpFn] = None,
366+
high_interp: Optional[FluxInterpFn] = None,
365367
limiter: Callable = van_leer_limiter,
366368
):
367369
super().__init__()
368370
self.grid = grid
369-
self.low_interp = (
370-
Upwind(grid, target_offset=target_offset)
371-
if low_interp is None
372-
else low_interp
373-
)
374-
self.high_interp = (
375-
LaxWendroff(grid, target_offset=target_offset)
376-
if high_interp is None
377-
else high_interp
371+
self.low_interp = default(low_interp, Upwind(grid, target_offset=target_offset))
372+
self.high_interp = default(
373+
high_interp, LaxWendroff(grid, target_offset=target_offset)
378374
)
375+
379376
self.limiter = limiter
380377
self.target_offset = target_offset
381378

@@ -394,28 +391,28 @@ def forward(
394391
Returns:
395392
Interpolated scalar field c to a target offset using Van Leer flux limiting, which uses a combination of high and low order methods to produce monotonic interpolation method.
396393
"""
397-
for axis, axis_offset in enumerate(self.target_offset):
394+
for dim, offset in enumerate(self.target_offset):
398395
interpolation_offset = tuple(
399396
[
400-
c_offset if i != axis else axis_offset
397+
c_offset if i != dim else offset
401398
for i, c_offset in enumerate(c.offset)
402399
]
403400
)
404401
if interpolation_offset != c.offset:
405-
if interpolation_offset[axis] - c.offset[axis] != 0.5:
402+
if interpolation_offset[dim] - c.offset[dim] != 0.5:
406403
raise NotImplementedError(
407404
"Only forward interpolation to control volume faces is supported."
408405
)
409406
c_low = self.low_interp(c, v, dt)
410407
c_high = self.high_interp(c, v, dt)
411-
c_left = c.shift(-1, axis).data
412-
c_right = c.shift(1, axis).data
413-
c_next_right = c.shift(2, axis).data
408+
c_left = c.shift(-1, dim).data
409+
c_right = c.shift(1, dim).data
410+
c_next_right = c.shift(2, dim).data
414411
pos_r = safe_div(c - c_left, c_right - c)
415412
neg_r = safe_div(c_next_right - c_right, c_right - c)
416413
pos_phi = self.limiter(pos_r).data
417414
neg_phi = self.limiter(neg_r).data
418-
u = v[axis]
415+
u = v[dim]
419416
phi = torch.where(u > 0, pos_phi, neg_phi)
420417
interpolated = c_low - (c_low - c_high) * phi
421418
c = GridVariable(interpolated.data, interpolation_offset, c.grid, c.bc)
@@ -455,8 +452,8 @@ def __init__(
455452
self,
456453
grid: Grid,
457454
offset: Tuple[float, ...],
458-
bc_c: boundaries.BoundaryConditions,
459-
bc_v: Tuple[boundaries.BoundaryConditions, ...],
455+
bc_c: Optional[BoundaryConditions] = None,
456+
bc_v: Optional[Tuple[boundaries.BoundaryConditions, ...]] = None,
460457
limiter: Optional[Callable] = None,
461458
) -> None:
462459
super().__init__()
@@ -538,13 +535,8 @@ def __init__(
538535
self,
539536
grid: Grid,
540537
offset=(0.5, 0.5),
541-
bc_c: boundaries.BoundaryConditions = boundaries.periodic_boundary_conditions(
542-
ndim=2
543-
),
544-
bc_v: Tuple[boundaries.BoundaryConditions, ...] = (
545-
boundaries.periodic_boundary_conditions(ndim=2),
546-
boundaries.periodic_boundary_conditions(ndim=2),
547-
),
538+
bc_c: Optional[BoundaryConditions] = None,
539+
bc_v: Optional[Tuple[boundaries.BoundaryConditions, ...]] = None,
548540
**kwargs,
549541
):
550542
super().__init__(grid, offset, bc_c, bc_v)
@@ -578,13 +570,8 @@ def __init__(
578570
self,
579571
grid: Grid,
580572
offset: Tuple[float, ...] = (0.5, 0.5),
581-
bc_c: boundaries.BoundaryConditions = boundaries.periodic_boundary_conditions(
582-
ndim=2
583-
),
584-
bc_v: Tuple[boundaries.BoundaryConditions, ...] = (
585-
boundaries.periodic_boundary_conditions(ndim=2),
586-
boundaries.periodic_boundary_conditions(ndim=2),
587-
),
573+
bc_c: Optional[BoundaryConditions] = None,
574+
bc_v: Optional[Tuple[boundaries.BoundaryConditions, ...]] = None,
588575
**kwargs,
589576
):
590577
super().__init__(grid, offset, bc_c, bc_v)
@@ -617,13 +604,8 @@ def __init__(
617604
self,
618605
grid: Grid,
619606
offset: Tuple[float, ...] = (0.5, 0.5),
620-
bc_c: boundaries.BoundaryConditions = boundaries.periodic_boundary_conditions(
621-
ndim=2
622-
),
623-
bc_v: Tuple[boundaries.BoundaryConditions, ...] = (
624-
boundaries.periodic_boundary_conditions(ndim=2),
625-
boundaries.periodic_boundary_conditions(ndim=2),
626-
),
607+
bc_c: Optional[BoundaryConditions] = None,
608+
bc_v: Optional[Tuple[boundaries.BoundaryConditions, ...]] = None,
627609
limiter: Callable = van_leer_limiter,
628610
**kwargs,
629611
):
@@ -672,16 +654,21 @@ def __init__(
672654
self,
673655
grid: Grid,
674656
offsets: Tuple[Tuple[float, ...], ...] = ((1.0, 0.5), (0.5, 1.0)),
675-
bcs: Tuple[boundaries.BoundaryConditions, ...] = (
676-
boundaries.periodic_boundary_conditions(ndim=2),
677-
boundaries.periodic_boundary_conditions(ndim=2),
678-
),
657+
bcs: Optional[Tuple[boundaries.BoundaryConditions, ...]] = None,
679658
advect: type[nn.Module] = AdvectionVanLeer,
680659
limiter: Callable = van_leer_limiter,
681660
**kwargs,
682661
):
683662
super().__init__()
684-
663+
self.grid = grid
664+
self.offsets = offsets
665+
bcs = default(
666+
bcs,
667+
tuple(
668+
boundaries.periodic_boundary_conditions(ndim=grid.ndim)
669+
for _ in range(grid.ndim)
670+
),
671+
)
685672
self.advect = nn.ModuleList(
686673
advect(
687674
grid=grid,

0 commit comments

Comments
 (0)