Skip to content

Commit 0769349

Browse files
committed
update nonhomogeneous bcs
1 parent 5f19c26 commit 0769349

File tree

5 files changed

+958
-151
lines changed

5 files changed

+958
-151
lines changed

torch_cfd/advection.py

Lines changed: 78 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@
1717

1818

1919
import math
20-
from typing import Callable, Optional, Tuple
2120
from functools import partial
21+
from typing import Callable, Optional, Tuple
22+
2223
import torch
2324
import torch.nn as nn
2425

@@ -31,9 +32,11 @@
3132
GridVariableVector = grids.GridVariableVector
3233
FluxInterpFn = Callable[[GridVariable, GridVariableVector, float], GridVariable]
3334

35+
3436
def default(value, d):
3537
return d if value is None else value
3638

39+
3740
def safe_div(x, y, default_numerator=1):
3841
"""Safe division of `Array`'s."""
3942
return x / torch.where(y != 0, y, default_numerator)
@@ -47,7 +50,7 @@ def van_leer_limiter(r):
4750
class Upwind(nn.Module):
4851
"""Upwind interpolation module for scalar fields.
4952
50-
Upwind interpolation of a scalar field `c` to a
53+
Upwind interpolation of a scalar field `c` to a
5154
target offset based on the velocity field `v`. The interpolation is done axis-wise and uses the upwind scheme where values are taken from upstream cells based on the flow direction.
5255
5356
The module identifies the interpolation axis (must be a single axis) and selects values from the previous cell for positive velocity or the next cell for negative velocity along that axis.
@@ -73,7 +76,9 @@ def __init__(
7376
):
7477
super().__init__()
7578
self.grid = grid
76-
self.target_offset = target_offset # this is the offset to which we will interpolate c
79+
self.target_offset = (
80+
target_offset # this is the offset to which we will interpolate c
81+
)
7782

7883
def forward(
7984
self,
@@ -113,7 +118,9 @@ def forward(
113118
ceil = int(math.ceil(offset_delta))
114119
c_floor = c.shift(floor, dim).data
115120
c_ceil = c.shift(ceil, dim).data
116-
return GridVariable(torch.where(u.data > 0, c_floor, c_ceil), self.target_offset, c.grid, c.bc)
121+
return GridVariable(
122+
torch.where(u.data > 0, c_floor, c_ceil), self.target_offset, c.grid, c.bc
123+
)
117124

118125

119126
class LaxWendroff(nn.Module):
@@ -132,7 +139,7 @@ class LaxWendroff(nn.Module):
132139
Lax-Wendroff method can be used to form monotonic schemes when augmented with
133140
a flux limiter. See https://en.wikipedia.org/wiki/Flux_limiter
134141
135-
Args:
142+
Args:
136143
grid: The computational grid on which interpolation is performed, only used for step.
137144
offset: Target offset to which scalar fields will be interpolated during
138145
forward passes. Target offset have the same length as `c.offset` in forward() and differ in at most one entry.
@@ -147,8 +154,8 @@ def __init__(
147154
target_offset: Tuple[float, ...],
148155
):
149156
super().__init__()
150-
self.grid = grid
151-
self.target_offset = target_offset
157+
self.grid = grid
158+
self.target_offset = target_offset
152159

153160
def forward(
154161
self,
@@ -189,8 +196,10 @@ def forward(
189196
c_ceil = c.shift(ceil, dim).data
190197
pos = c_floor + 0.5 * (1 - courant) * (c_ceil - c_floor)
191198
neg = c_ceil - 0.5 * (1 + courant) * (c_ceil - c_floor)
192-
return GridVariable(torch.where(u.data > 0, pos, neg),
193-
self.target_offset, c.grid, c.bc)
199+
return GridVariable(
200+
torch.where(u.data > 0, pos, neg), self.target_offset, c.grid, c.bc
201+
)
202+
194203

195204
class AdvectAligned(nn.Module):
196205
"""
@@ -277,16 +286,20 @@ def forward(self, cs: GridVariableVector, v: GridVariableVector) -> GridVariable
277286
)
278287

279288
# Compute flux: cu
280-
# if cs and v have different boundary conditions,
289+
# if cs and v have different boundary conditions,
281290
# flux's bc will become None
282291
flux = GridVariableVector(tuple(c * u for c, u in zip(cs, v)))
283292

284293
# wrap flux with boundary conditions to flux if not periodic
285294
# flux = GridVariableVector(
286295
# tuple(bc.impose_bc(f) for f, bc in zip(flux, self.flux_bcs))
287296
# )
288-
flux = GridVariableVector(tuple(GridVariable(f.data, offset, f.grid, bc) for f, offset, bc in zip(flux, self.offsets, self.flux_bcs)))
289-
297+
flux = GridVariableVector(
298+
tuple(
299+
GridVariable(f.data, offset, f.grid, bc)
300+
for f, offset, bc in zip(flux, self.offsets, self.flux_bcs)
301+
)
302+
)
290303

291304
# Return negative divergence of flux
292305
# after taking divergence the bc becomes None
@@ -335,12 +348,12 @@ class TVDInterpolation(nn.Module):
335348
http://www.ita.uni-heidelberg.de/~dullemond/lectures/num_fluid_2012/Chapter_4.pdf
336349
337350
Args:
338-
target_offset: offset to which we will interpolate `c`.
351+
target_offset: offset to which we will interpolate `c`.
339352
Must have the same length as `c.offset` and differ in at most one entry. This offset should interface as other interpolation methods (take `c`, `v` and `dt` arguments and return value of `c` at offset `offset`).
340-
limiter: flux limiter function that evaluates the portion of the correction (high_accuracy - low_accuracy) to add to low_accuracy solution based on the ratio of the consecutive gradients.
353+
limiter: flux limiter function that evaluates the portion of the correction (high_accuracy - low_accuracy) to add to low_accuracy solution based on the ratio of the consecutive gradients.
341354
Takes array as input and return array of weights. For more details see:
342355
https://en.wikipedia.org/wiki/Flux_limiter
343-
356+
344357
"""
345358

346359
def __init__(
@@ -353,8 +366,16 @@ def __init__(
353366
):
354367
super().__init__()
355368
self.grid = grid
356-
self.low_interp = Upwind(grid, target_offset=target_offset) if low_interp is None else low_interp
357-
self.high_interp = LaxWendroff(grid, target_offset=target_offset) if high_interp is None else high_interp
369+
self.low_interp = (
370+
Upwind(grid, target_offset=target_offset)
371+
if low_interp is None
372+
else low_interp
373+
)
374+
self.high_interp = (
375+
LaxWendroff(grid, target_offset=target_offset)
376+
if high_interp is None
377+
else high_interp
378+
)
358379
self.limiter = limiter
359380
self.target_offset = target_offset
360381

@@ -370,8 +391,9 @@ def forward(
370391
v: GridVariableVector representing the velocity field.
371392
dt: Time step size (not used in this interpolation).
372393
373-
Returns:
374-
Interpolated scalar field c to a target offset using Van Leer flux limiting, which uses a combination of high and low order methods to produce monotonic interpolation method."""
394+
Returns:
395+
Interpolated scalar field c to a target offset using Van Leer flux limiting, which uses a combination of high and low order methods to produce monotonic interpolation method.
396+
"""
375397
for axis, axis_offset in enumerate(self.target_offset):
376398
interpolation_offset = tuple(
377399
[
@@ -420,7 +442,7 @@ class AdvectionBase(nn.Module):
420442
4. Set the boundary condition on flux, which is inhereited from `c`.
421443
5. Return the negative divergence of the flux.
422444
423-
Args:
445+
Args:
424446
grid: Grid.
425447
offset: the current scalar field `c` to be advected.
426448
bc_c: Boundary conditions for the scalar field `c`.
@@ -429,13 +451,14 @@ class AdvectionBase(nn.Module):
429451
430452
"""
431453

432-
def __init__(self,
433-
grid: Grid,
434-
offset: Tuple[float, ...],
435-
bc_c: boundaries.BoundaryConditions,
436-
bc_v: Tuple[boundaries.BoundaryConditions, ...],
437-
limiter: Optional[Callable] = None,
438-
) -> None:
454+
def __init__(
455+
self,
456+
grid: Grid,
457+
offset: Tuple[float, ...],
458+
bc_c: boundaries.BoundaryConditions,
459+
bc_v: Tuple[boundaries.BoundaryConditions, ...],
460+
limiter: Optional[Callable] = None,
461+
) -> None:
439462
super().__init__()
440463
self.grid = grid
441464
self.offset = offset if offset is not None else (0.5,) * grid.ndim
@@ -450,19 +473,24 @@ def __init__(self,
450473
),
451474
)
452475
self.advect_aligned = AdvectAligned(
453-
grid=grid, bcs_c=(bc_c, bc_c), bcs_v=bc_v, offsets=self.target_offsets)
454-
self._flux_interp = nn.ModuleList() # placeholder
455-
self._velocity_interp = nn.ModuleList() # placeholder
476+
grid=grid, bcs_c=(bc_c, bc_c), bcs_v=bc_v, offsets=self.target_offsets
477+
)
478+
self._flux_interp = nn.ModuleList() # placeholder
479+
self._velocity_interp = nn.ModuleList() # placeholder
456480

457481
def __post_init__(self):
458482
assert len(self._flux_interp) == len(self.target_offsets)
459483
assert len(self._velocity_interp) == len(self.target_offsets)
460484

461485
for dim, interp in enumerate(self._flux_interp):
462-
assert interp.target_offset == self.target_offsets[dim], f"Expected flux interpolation for dimension {dim} to have target offset {self.target_offsets[dim]}, but got {interp.target_offset}."
463-
486+
assert (
487+
interp.target_offset == self.target_offsets[dim]
488+
), f"Expected flux interpolation for dimension {dim} to have target offset {self.target_offsets[dim]}, but got {interp.target_offset}."
489+
464490
for dim, interp in enumerate(self._velocity_interp):
465-
assert interp.target_offset == self.target_offsets[dim], f"Expected velocity interpolation for dimension {dim} to have target offset {self.target_offsets[dim]}, but got {interp.target_offset}."
491+
assert (
492+
interp.target_offset == self.target_offsets[dim]
493+
), f"Expected velocity interpolation for dimension {dim} to have target offset {self.target_offsets[dim]}, but got {interp.target_offset}."
466494

467495
def flux_interp(
468496
self,
@@ -478,7 +506,9 @@ def velocity_interp(
478506
self, v: GridVariableVector, *args, **kwargs
479507
) -> GridVariableVector:
480508
"""Interpolate the velocity field `v` to the target offsets."""
481-
return GridVariableVector(tuple(interp(u) for interp, u in zip(self._velocity_interp, v)))
509+
return GridVariableVector(
510+
tuple(interp(u) for interp, u in zip(self._velocity_interp, v))
511+
)
482512

483513
def forward(
484514
self,
@@ -490,10 +520,10 @@ def forward(
490520
Args:
491521
c: the scalar field to be advected.
492522
v: representing the velocity field.
493-
523+
494524
Returns:
495525
An GridVariable containing the time derivative of `c` due to advection by `v`.
496-
526+
497527
"""
498528
aligned_v = self.velocity_interp(v)
499529

@@ -507,7 +537,7 @@ class AdvectionLinear(AdvectionBase):
507537
def __init__(
508538
self,
509539
grid: Grid,
510-
offset = (0.5, 0.5),
540+
offset=(0.5, 0.5),
511541
bc_c: boundaries.BoundaryConditions = boundaries.periodic_boundary_conditions(
512542
ndim=2
513543
),
@@ -520,12 +550,14 @@ def __init__(
520550
super().__init__(grid, offset, bc_c, bc_v)
521551
self._flux_interp = nn.ModuleList(
522552
LinearInterpolation(grid, target_offset=offset)
523-
for offset in self.target_offsets)
553+
for offset in self.target_offsets
554+
)
524555

525556
self._velocity_interp = nn.ModuleList(
526557
LinearInterpolation(grid, target_offset=offset)
527-
for offset in self.target_offsets)
528-
558+
for offset in self.target_offsets
559+
)
560+
529561

530562
class AdvectionUpwind(AdvectionBase):
531563
"""
@@ -535,9 +567,9 @@ class AdvectionUpwind(AdvectionBase):
535567
- flux_interp: a Upwind interpolation for each component of the velocity field `v`.
536568
- velocity_interp: a LinearInterpolation for each component of the velocity field `v`.
537569
538-
Args:
570+
Args:
539571
- offset: current offset of the scalar field `c` to be advected.
540-
572+
541573
Returns:
542574
Aligned advection of the scalar field `c` by the velocity field `v` using the target offsets on the control volume faces.
543575
"""
@@ -557,8 +589,7 @@ def __init__(
557589
):
558590
super().__init__(grid, offset, bc_c, bc_v)
559591
self._flux_interp = nn.ModuleList(
560-
Upwind(grid, target_offset=offset)
561-
for offset in self.target_offsets
592+
Upwind(grid, target_offset=offset) for offset in self.target_offsets
562593
)
563594

564595
self._velocity_interp = nn.ModuleList(
@@ -575,9 +606,9 @@ class AdvectionVanLeer(AdvectionBase):
575606
- flux_interp: a TVDInterpolation with Upwind and LaxWendroff methods
576607
- velocity_interp: a LinearInterpolation for each component of the velocity field `v`.
577608
578-
Args:
609+
Args:
579610
- offset: current offset of the scalar field `c` to be advected.
580-
611+
581612
Returns:
582613
Aligned advection of the scalar field `c` by the velocity field `v` using the target offsets on the control volume faces.
583614
"""
@@ -611,6 +642,7 @@ def __init__(
611642
for offset, bc in zip(self.target_offsets, bc_v)
612643
)
613644

645+
614646
class ConvectionVector(nn.Module):
615647
"""Computes convection of a vector field `v` by the velocity field `u`.
616648

0 commit comments

Comments
 (0)