Skip to content

Commit 0653ee9

Browse files
authored
0.2.1: fixed boundary values imposing bugs from Jax-cfd (#3)
* updated requirements and added a test workflow * fixed boundary values imposing bugs from Jax-cfd
1 parent 6b459f3 commit 0653ee9

File tree

11 files changed

+1020
-98
lines changed

11 files changed

+1020
-98
lines changed

.github/workflows/pytest.yml

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
name: Run Tests 🧪
2+
3+
on:
4+
push:
5+
tags:
6+
- '*'
7+
pull_request:
8+
9+
jobs:
10+
test:
11+
name: Run pytest
12+
runs-on: ubuntu-latest
13+
14+
steps:
15+
- uses: actions/checkout@v2
16+
17+
- name: Set up Python
18+
uses: actions/setup-python@v2
19+
with:
20+
python-version: "3.10"
21+
22+
- name: Install dependencies
23+
run: |
24+
python -m pip install --upgrade pip
25+
pip install -r requirements.txt
26+
27+
- name: Extract tag name
28+
id: tag
29+
run: echo ::set-output name=TAG_NAME::$(echo $GITHUB_REF | cut -d / -f 3)
30+
31+
- name: Update version in setup.py
32+
run: >-
33+
sed -i "s/{{VERSION_PLACEHOLDER}}/${{ steps.tag.outputs.TAG_NAME }}/g" setup.py
34+
35+
- name: Install build
36+
run: python -m pip install build
37+
38+
- name: Run pytest
39+
run: |
40+
pytest --pyargs torch_cfd --verbose

.github/workflows/python-publish.yml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,10 @@ jobs:
2525
with:
2626
python-version: "3.10"
2727

28-
- name: Install pip
29-
run: python -m pip install --upgrade pip
30-
3128
- name: Install dependencies
32-
run: pip install -r requirements.txt
29+
run: |
30+
python -m pip install --upgrade pip
31+
pip install -r requirements.txt
3332
3433
- name: Extract tag name
3534
id: tag

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,6 @@ If you like to use `torch-cfd` please use the following [paper](https://arxiv.or
8080
I am grateful for the support from [Long Chen (UC Irvine)](https://github.com/lyc102/ifem) and
8181
[Ludmil Zikatanov (Penn State)](https://github.com/HAZmathTeam/hazmath) over the years, and their efforts in open-sourcing scientific computing codes. I also appreciate the support from the National Science Foundation (NSF) to junior researchers. I want to thank the free A6000 credits at the SSE ML cluster from the University of Missouri.
8282

83-
(Added after `0.2.0`) I also want to acknowledge that University of Missouri's OpenAI Enterprise API key. After from version `0.1.0`, I begin prompt existing codes in VSCode Copilot (using the OpenAI Enterprise API), which arguably significantly improve the efficiency on "porting->debugging->refactoring" cycle, e.g., Copilot helps design unittests and `.vscode/launch.json` for debugging. For details of how Copilot's suggestions on code refactoring, please see [README.md](./torch_cfd/README.md) in `torch_cfd` folder.
83+
(Added after `0.2.0`) I also want to acknowledge that University of Missouri's OpenAI Enterprise API key. After version `0.1.0`, I began prompt in VSCode Copilot with existing codes (using the OpenAI Enterprise API), which arguably significantly improve the efficiency on "porting->debugging->refactoring" cycle, e.g., Copilot helps design unittests and `.vscode/launch.json` for debugging. For details of how Copilot's suggestions on code refactoring, please see [README.md](./torch_cfd/README.md) in `torch_cfd` folder.
8484

8585
For individual paper's acknowledgment please see [here](./fno/README.md).

requirements.txt

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
1-
numpy>=2.2.0
2-
torch>=2.5.0
3-
xarray>=2025.3.1
4-
tqdm>=4.62.0
5-
einops>=0.8.0
6-
dill>=0.4.0
7-
matplotlib>=3.5.0
8-
seaborn>=0.13.0
1+
absl_py==2.2.2
2+
dill==0.4.0
3+
einops==0.8.1
4+
h5py==3.13.0
5+
matplotlib==3.10.3
6+
numpy==2.2.6
7+
plotly==6.0.1
8+
psutil==7.0.0
9+
pytest==8.3.5
10+
scipy==1.15.3
11+
seaborn==0.13.2
12+
tensordict==0.7.2
13+
torch==2.6.0
14+
tqdm==4.67.1
15+
xarray==2025.3.1

torch_cfd/README.md

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,37 @@
1-
## TODO
1+
# TODO
22

3-
- [x] add native PyTorch implementation for applying `torch.linalg` and `torch.fft` function directly on `GridArray` and `GridVariable` (added 0.1.0).
3+
- [x] add native PyTorch implementation for applying `torch.linalg` and `torch.fft` function directly on `GridArray` and `GridVariable` (added 0.1.).
44
- [x] add discrete Helmholtz decomposition (pressure projection) in both spatial and spectral domains (added 0.0.1).
55
- [x] adjust the functions and routines to act on `(batch, time, *spatial)` tensor, currently only `(*spatial)` is supported (added for key routines in 0.0.1).
66
- [x] add native FFT-based vorticity computation, instead of taking finite differences for pseudo-spectral (added in 0.0.4).
77
- [ ] add no-slip boundary.
88

9-
## Changelog
9+
# Changelog
1010

1111
### 0.2.0
1212

1313
After version `0.1.0`, I began prompt with existing codes in VSCode Copilot (using the OpenAI Enterprise API kindedly provided by UM), which arguably significantly improve the "porting->debugging->refactoring" cycle. I recorded some several good refactoring suggestions by GPT o4-mini and some by ***Claude Sonnet 3.7*** here. There were definitely over-complicated "poor" refactoring suggestions, which have been stashed after benchmarking. I found that Sonnet 3.7 is exceptionally good at providing templates for me to filling the details, when it is properly prompted with details of the functionality of current codes. Another highlight is that, based on the error or exception raised in the unittests, Sonnet 3.7 directly added configurations in `.vscode/launch.json`, saving me quite some time of copy-paste boilerplates then change by hand.
1414

1515
#### Major change: batch dimension for FVM
1616
The finite volume solver now accepts the batch dimension, some key updates include
17-
- Re-implemented flux computations of $(\boldsymbol{u}\cdot\nabla)\boldsymbol{u}$ as `nn.Module`. I originally implemented a tensor only version but did not quite work. Sonnet 3.7 provided a very good refactoring template after being given both the original code and my implementation.
17+
- Re-implemented flux computations of $(\boldsymbol{u}\cdot\nabla)\boldsymbol{u}$ as `nn.Module`. I originally implemented a tensor only version but did not quite work by pre-assigning `target_offsets`, which was buggy for the second component of the velocity. Sonnet 3.7 provided a very good refactoring template after being given both the original code and my implementation, after which I pretty much just fill in the blanks in [`advection.py`](./advection.py). Later I found out the bug was pretty stupid on my side, from
18+
```python
19+
for i in range(2):
20+
u = v[i]
21+
for j in range(2):
22+
v[i] = flux_interpolation(u, v[j]) # offset is updated here
23+
```
24+
to
25+
```python
26+
# this is gonna be buggy of course because the offset alignment will go wrong
27+
# the target_offsets are looped inside flux_interpolation of AdvectionVanLeer
28+
for offset in target_offsets:
29+
u = v[i]; u.offset = offset
30+
for j in range(2):
31+
v[i] = flux_interpolation(u, v[j])
32+
v[i].offset = offset
33+
```
34+
The fixed version that loops in `__call__` of [`AdvectionVanLeer` class is here](./advection.py#L451).
1835
- Implemented a `_constant_pad_tensor` function to improve the behavior of `F.pad`, to help imposing non-homogeneous boundary conditions. It uses naturally ordered `pad` args (like Jax, unlike `F.pad`), while taking the batch dimension into consideration.
1936
- Changed the behavior of `u.shift` taking into consideration of batch dimension. In general these methods within the `bc` class or `GridVariable` starts from the last dimension instead of the first, e.g., `for dim in range(u.grid.ndim): ...` changes to `for dim in range(-u.grid.ndim, 0): ...`.
2037

torch_cfd/boundaries.py

Lines changed: 34 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@
3434

3535

3636
class Padding:
37-
MIRROR = "mirror"
38-
EXTEND = "extend"
37+
MIRROR = "reflect"
38+
EXTEND = "replicate"
3939

4040

4141
@dataclasses.dataclass(init=False, frozen=True)
@@ -71,30 +71,17 @@ def __init__(
7171
def shift(
7272
self,
7373
u: GridVariable,
74-
offset: float,
74+
offset: int,
7575
dim: int,
7676
) -> GridVariable:
7777
"""
7878
A fallback function to make the implementation back-compatible
7979
see grids.shift
80+
bc.shift(u, offset, dim) overrides u.bc
81+
grids.shift(u, offset, dim) keeps u.bc
8082
"""
81-
return grids.shift(u, offset, dim)
83+
return grids.shift(u, offset, dim, self)
8284

83-
def _count_bc_components(self) -> int:
84-
"""Counts the number of components in the boundary conditions.
85-
86-
Returns:
87-
The number of components in the boundary conditions.
88-
"""
89-
count = 0
90-
ndim = len(self.types)
91-
for axis in range(ndim): # ndim
92-
if len(self.types[axis]) != 2:
93-
raise ValueError(
94-
f"Boundary conditions for axis {axis} must have two values got {len(self.types[axis])}."
95-
)
96-
count += len(self.types[axis])
97-
return count
9885

9986
def _is_aligned(self, u: GridVariable, dim: int) -> bool:
10087
"""Checks if array u contains all interior domain information.
@@ -240,15 +227,15 @@ def pad_and_impose_bc(
240227
"""
241228
if offset_to_pad_to is None:
242229
offset_to_pad_to = u.offset
243-
for axis in range(u.grid.ndim):
244-
_ = self._is_aligned(u, axis)
245-
if self.types[axis][0] == BCType.DIRICHLET and math.isclose(
246-
u.offset[axis], 1.0
247-
):
248-
if math.isclose(offset_to_pad_to[axis], 1.0):
249-
u = grids.pad(u, 1, axis, mode=mode)
250-
elif math.isclose(offset_to_pad_to[axis], 0.0):
251-
u = grids.pad(u, -1, axis, mode=mode)
230+
for axis in range(-u.grid.ndim, 0):
231+
_ = self._is_aligned(u, axis)
232+
if self.types[axis][0] == BCType.DIRICHLET and math.isclose(
233+
u.offset[axis], 1.0
234+
):
235+
if math.isclose(offset_to_pad_to[axis], 1.0):
236+
u = grids.pad(u, 1, axis, self)
237+
elif math.isclose(offset_to_pad_to[axis], 0.0):
238+
u = grids.pad(u, -1, axis, self)
252239
return GridVariable(u.data, u.offset, u.grid, self)
253240

254241
def impose_bc(self, u: GridVariable) -> GridVariable:
@@ -265,7 +252,8 @@ def impose_bc(self, u: GridVariable) -> GridVariable:
265252
"""
266253
offset = u.offset
267254
u = self.trim_boundary(u)
268-
return self.pad_and_impose_bc(u, offset)
255+
u = self.pad_and_impose_bc(u, offset)
256+
return u
269257

270258

271259
class HomogeneousBoundaryConditions(ConstantBoundaryConditions):
@@ -383,6 +371,22 @@ def periodic_and_neumann_boundary_conditions(
383371
)
384372

385373

374+
def _count_bc_components(bc: BoundaryConditions) -> int:
375+
"""Counts the number of components in the boundary conditions.
376+
377+
Returns:
378+
The number of components in the boundary conditions.
379+
"""
380+
count = 0
381+
ndim = len(bc.types)
382+
for axis in range(ndim): # ndim
383+
if len(bc.types[axis]) != 2:
384+
raise ValueError(
385+
f"Boundary conditions for axis {axis} must have two values got {len(bc.types[axis])}."
386+
)
387+
count += len(bc.types[axis])
388+
return count
389+
386390
def consistent_boundary_conditions_grid(
387391
grid, *arrays: GridVariable
388392
) -> Tuple[GridVariable, ...]:
@@ -391,7 +395,7 @@ def consistent_boundary_conditions_grid(
391395
"""
392396
bc_counts = []
393397
for array in arrays:
394-
bc_counts.append(array.bc._count_bc_components())
398+
bc_counts.append(_count_bc_components(array.bc))
395399
bc_count = bc_counts[0]
396400
if any(bc_counts[i] != bc_count for i in range(1, len(bc_counts))):
397401
raise Exception("Boundary condition counts are inconsistent")

torch_cfd/fvm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,14 @@ class ProjectionExplicitODE(nn.Module):
4545
u <- pressure_projection(u)
4646
"""
4747

48-
def explicit_terms(self, *, u):
48+
def explicit_terms(self, *args, **kwargs) -> GridVariableVector:
4949
"""
5050
Explicit forcing term as du/dt.
51-
* allows extra arguments to be passed.
5251
"""
5352
raise NotImplementedError
5453

55-
def pressure_projection(self, *, u):
54+
def pressure_projection(self, *args, **kwargs) -> Tuple[GridVariableVector, GridVariable]:
55+
"""Pressure projection step."""
5656
raise NotImplementedError
5757

5858
def forward(self, u: GridVariableVector, dt: float) -> GridVariableVector:

torch_cfd/grids.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -856,7 +856,7 @@ def clone(self, *args, **kwargs):
856856
return super().clone(*args, **kwargs)
857857

858858

859-
def shift(u: GridVariable, offset: int, dim: int) -> GridVariable:
859+
def shift(u: GridVariable, offset: int, dim: int, bc: Optional[BoundaryConditions] = None) -> GridVariable:
860860
"""Shift a GridVariable by `offset`.
861861
862862
Args:
@@ -869,12 +869,12 @@ def shift(u: GridVariable, offset: int, dim: int) -> GridVariable:
869869
`u.offset + offset`.
870870
"""
871871
dim = -u.grid.ndim + dim if dim >= 0 else dim
872-
padded = pad(u, offset, dim)
872+
padded = pad(u, offset, dim, bc)
873873
trimmed = trim(padded, -offset, dim)
874874
return trimmed
875875

876876

877-
def pad(u: GridVariable, width: int, dim: int) -> GridVariable:
877+
def pad(u: GridVariable, width: int, dim: int, bc: Optional[BoundaryConditions] = None) -> GridVariable:
878878
"""Pad a GridVariable by `padding`.
879879
880880
Important: the original _pad in jax_cfd makes no sense past 1 ghost cell for nonperiodic boundaries.
@@ -892,11 +892,14 @@ def pad(u: GridVariable, width: int, dim: int) -> GridVariable:
892892
Note:
893893
the padding removes the boundary conditions, so u.bc is set to None.
894894
"""
895+
assert not (u.bc is None and bc is None), "Both u.bc and bc cannot be None"
896+
bc = bc if bc is not None else u.bc
897+
895898
if width < 0: # pad lower boundary
896-
bc_type = u.bc.types[dim][0]
899+
bc_type = bc.types[dim][0]
897900
padding = (-width, 0)
898901
else: # pad upper boundary
899-
bc_type = u.bc.types[dim][1]
902+
bc_type = bc.types[dim][1]
900903
padding = (0, width)
901904

902905
full_padding = [(0, 0)] * u.grid.ndim
@@ -913,20 +916,20 @@ def pad(u: GridVariable, width: int, dim: int) -> GridVariable:
913916
if bc_type == BCType.PERIODIC:
914917
# self.values are ignored here
915918
data = expand_dims_pad(u.data, full_padding, mode="circular")
916-
return GridVariable(data, tuple(new_offset), u.grid)
919+
return GridVariable(data, tuple(new_offset), u.grid, bc)
917920
elif bc_type == BCType.DIRICHLET:
918921
if math.isclose(u.offset[dim] % 1, 0.5): # cell center
919922
# make the linearly interpolated value equal to the boundary by setting
920923
# the padded values to the negative symmetric values
921924
data = 2 * expand_dims_pad(
922-
u.data, full_padding, mode="constant", constant_values=u.bc._values
923-
) - expand_dims_pad(u.data, full_padding, mode="reflect")
924-
return GridVariable(data, tuple(new_offset), u.grid)
925+
u.data, full_padding, mode="constant", constant_values=bc._values
926+
) - expand_dims_pad(u.data, full_padding, mode="replicate")
927+
return GridVariable(data, tuple(new_offset), u.grid, bc)
925928
elif math.isclose(u.offset[dim] % 1, 0): # cell edge
926929
data = expand_dims_pad(
927-
u.data, full_padding, mode="constant", constant_values=u.bc._values
930+
u.data, full_padding, mode="constant", constant_values=bc._values
928931
)
929-
return GridVariable(data, tuple(new_offset), u.grid)
932+
return GridVariable(data, tuple(new_offset), u.grid, bc)
930933
else:
931934
raise ValueError(
932935
"expected the new offset to be an edge or cell center, got "
@@ -943,19 +946,19 @@ def pad(u: GridVariable, width: int, dim: int) -> GridVariable:
943946
else:
944947
# When the data is cell-centered, computes the backward difference.
945948
# When the data is on cell edges, boundary is set such that
946-
# (u_last_interior - u_boundary)/grid_step = neumann_bc_value.
949+
# (u_boundary - u_last_interior)/grid_step = neumann_bc_value (fixed from Jax-cfd).
950+
# note: Jax-cfd implementation was wrong, Neumann BC is \nabla u \cdot exterior normal, the order is reversed
947951
data = expand_dims_pad(
948952
u.data, full_padding, mode="replicate"
949953
) + u.grid.step[dim] * (
950-
expand_dims_pad(u.data, full_padding, mode="constant")
951-
- expand_dims_pad(
954+
expand_dims_pad(
952955
u.data,
953956
full_padding,
954957
mode="constant",
955-
constant_values=u.bc._values,
956-
)
958+
constant_values=bc._values,
959+
) - expand_dims_pad(u.data, full_padding, mode="constant")
957960
)
958-
return GridVariable(data, tuple(new_offset), u.grid)
961+
return GridVariable(data, tuple(new_offset), u.grid, bc)
959962

960963
else:
961964
raise ValueError("invalid boundary type")

0 commit comments

Comments
 (0)