Skip to content

Commit 6b459f3

Browse files
committed
Update README for 0.2.0
1 parent ffd636b commit 6b459f3

File tree

2 files changed

+97
-8
lines changed

2 files changed

+97
-8
lines changed

README.md

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@ This repository contains mainly two parts:
88

99
### Part I: a native PyTorch port of [Google's Computational Fluid Dynamics package in Jax](https://github.com/google/jax-cfd)
1010
The main changes are documented in the `README.md` under the [`torch_cfd` directory](./torch_cfd/). The most significant changes in all routines include:
11-
- Routines that rely on the functional programming of Jax have been rewritten to be the PyTorch's tensor-in-tensor-out style, which is arguably more user-friendly to debugging as one can view immediate tensors using Data Wrangler in VS Code.
11+
- Routines that rely on the functional programming of Jax have been rewritten to be the PyTorch's tensor-in-tensor-out style, which is arguably more user-friendly to debugging as one can view intermediate tensors in VS Code.
1212
- Functions and operators are in general implemented as `nn.Module` like a factory template.
1313
- Jax-cfd's `funcutils.trajectory` function supports tracking only one field variable (vorticity or velocity). For this port, extra fields computation and tracking are made more accessible, such as time derivatives $\partial_t\boldsymbol{v}$ and PDE residual $R(\boldsymbol{v}):=\boldsymbol{f}-\partial_t \boldsymbol{v}-(\boldsymbol{v}\cdot\nabla)\boldsymbol{v} + \nu \Delta \boldsymbol{v}$.
14-
- All ops take into consideration the batch dimension of tensors `(b, *, n, m)` regardless of `*` dimension, for example, `(b, T, C, n, n)`, which is similar to PyTorch behavior, not a single trajectory like Google's original Jax-CFD package.
14+
- All ops take into consideration the batch dimension of tensors `(b, *, n, m)` regardless of `*` dimension, for example, `(b, T, C, n, m)`, which is similar to PyTorch behavior, not a single trajectory like Google's original Jax-CFD package. The stencil operations generally starts from the last dimension, following `torch.nn.functional.pad`'s behavior.
1515

1616
### Part II: Spectral-Refiner: Neural Operator-Assisted Navier-Stokes Equations simulator.
1717
- The **Spatiotemporal Fourier Neural Operator** (SFNO) is a spacetime tensor-to-tensor learner (or trajectory-to-trajectory), available in the [`fno` directory](./fno). Different components of FNO have been re-implemented keeping the conciseness of the original implementation while allowing modern expansions. We draw inspiration from the [3D FNO in Nvidia's Neural Operator repo](https://github.com/neuraloperator/neuraloperator), [Transformers-based neural operators](https://github.com/thuml/Neural-Solver-Library), as well as Temam's book on functional analysis for the NSE.
@@ -41,7 +41,7 @@ Data generation instructions are available in the [SFNO folder](./fno).
4141

4242
## Examples
4343
- Demos of different simulation setups:
44-
- [2D simulation with a pseudo-spectral solver](./examples/Kolmogrov2d_rk4_cn_forced_turbulence.ipynb)
44+
- [2D simulation with a pseudo-spectral solver](./examples/Kolmogrov2d_rk4_spectral_forced_turbulence.ipynb)
4545
- [2D simulation with a finite volume solver](./examples/Kolmogrov2d_rk4_fvm_forced_turbulence.ipynb)
4646
- Demos of Spatiotemporal FNO's training and evaluation using the neural operator-assisted fluid simulation pipelines
4747
- [Training of SFNO for only 15 epochs for the isotropic turbulence example](./examples/ex2_SFNO_train.ipynb)
@@ -78,6 +78,8 @@ If you like to use `torch-cfd` please use the following [paper](https://arxiv.or
7878

7979
## Acknowledgments
8080
I am grateful for the support from [Long Chen (UC Irvine)](https://github.com/lyc102/ifem) and
81-
[Ludmil Zikatanov (Penn State)](https://github.com/HAZmathTeam/hazmath) over the years, and their efforts in open-sourcing scientific computing codes. I also appreciate the support from the National Science Foundation (NSF) to junior researchers. I also want to thank the free A6000 credits at the SSE ML cluster from the University of Missouri.
81+
[Ludmil Zikatanov (Penn State)](https://github.com/HAZmathTeam/hazmath) over the years, and their efforts in open-sourcing scientific computing codes. I also appreciate the support from the National Science Foundation (NSF) to junior researchers. I want to thank the free A6000 credits at the SSE ML cluster from the University of Missouri.
82+
83+
(Added after `0.2.0`) I also want to acknowledge that University of Missouri's OpenAI Enterprise API key. After from version `0.1.0`, I begin prompt existing codes in VSCode Copilot (using the OpenAI Enterprise API), which arguably significantly improve the efficiency on "porting->debugging->refactoring" cycle, e.g., Copilot helps design unittests and `.vscode/launch.json` for debugging. For details of how Copilot's suggestions on code refactoring, please see [README.md](./torch_cfd/README.md) in `torch_cfd` folder.
8284

8385
For individual paper's acknowledgment please see [here](./fno/README.md).

torch_cfd/README.md

Lines changed: 91 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,99 @@
11
## TODO
22

3-
- [x] add native PyTorch implementation for applying `torch.linalg` and `torch.fft` function directly on `GridArray`.
4-
- [x] add discrete Helmholtz decomposition in both spatial and spectral domains.
5-
- [x] adjust the function to act on `(batch, time, *spatial)` tensor, currently only `(*spatial)` is supported.
6-
- [x] add native vorticity computation, instead of taking FDM for pseudo-spectral.
3+
- [x] add native PyTorch implementation for applying `torch.linalg` and `torch.fft` function directly on `GridArray` and `GridVariable` (added 0.1.0).
4+
- [x] add discrete Helmholtz decomposition (pressure projection) in both spatial and spectral domains (added 0.0.1).
5+
- [x] adjust the functions and routines to act on `(batch, time, *spatial)` tensor, currently only `(*spatial)` is supported (added for key routines in 0.0.1).
6+
- [x] add native FFT-based vorticity computation, instead of taking finite differences for pseudo-spectral (added in 0.0.4).
7+
- [ ] add no-slip boundary.
78

89
## Changelog
910

11+
### 0.2.0
12+
13+
After version `0.1.0`, I began prompt with existing codes in VSCode Copilot (using the OpenAI Enterprise API kindedly provided by UM), which arguably significantly improve the "porting->debugging->refactoring" cycle. I recorded some several good refactoring suggestions by GPT o4-mini and some by ***Claude Sonnet 3.7*** here. There were definitely over-complicated "poor" refactoring suggestions, which have been stashed after benchmarking. I found that Sonnet 3.7 is exceptionally good at providing templates for me to filling the details, when it is properly prompted with details of the functionality of current codes. Another highlight is that, based on the error or exception raised in the unittests, Sonnet 3.7 directly added configurations in `.vscode/launch.json`, saving me quite some time of copy-paste boilerplates then change by hand.
14+
15+
#### Major change: batch dimension for FVM
16+
The finite volume solver now accepts the batch dimension, some key updates include
17+
- Re-implemented flux computations of $(\boldsymbol{u}\cdot\nabla)\boldsymbol{u}$ as `nn.Module`. I originally implemented a tensor only version but did not quite work. Sonnet 3.7 provided a very good refactoring template after being given both the original code and my implementation.
18+
- Implemented a `_constant_pad_tensor` function to improve the behavior of `F.pad`, to help imposing non-homogeneous boundary conditions. It uses naturally ordered `pad` args (like Jax, unlike `F.pad`), while taking the batch dimension into consideration.
19+
- Changed the behavior of `u.shift` taking into consideration of batch dimension. In general these methods within the `bc` class or `GridVariable` starts from the last dimension instead of the first, e.g., `for dim in range(u.grid.ndim): ...` changes to `for dim in range(-u.grid.ndim, 0): ...`.
20+
21+
22+
#### Retaining only `GridVariable` class
23+
This refactoring is suggested by ***Claude Sonnet 3.7***. In [`grids.py`](./grids.py#442), following `numpy`'s practice (see updates notes in [0.0.1](#001)) in `np.lib.mixins.NDArrayOperatorsMixin`, I originally implemented two mixin classes, [`GridArrayOperatorsMixin`](https://github.com/scaomath/torch-cfd/blob/475c7385549225570b61d8c3dcf1d415d8977f19/torch_cfd/grids.py#L304) and [`GridVariableOperatorsMixin`](https://github.com/scaomath/torch-cfd/blob/475c7385549225570b61d8c3dcf1d415d8977f19/torch_cfd/grids.py#L616) using the same boilerplate to enable operations such as `v + u` or computing the upwinding flux for two `GridArray` instances:
24+
```python
25+
def _binary_method(name, op):
26+
def method(self, other):
27+
...
28+
method.__name__ = f"__{name}__"
29+
return method
30+
31+
class GridArrayOperatorsMixin:
32+
__slots__ = ()
33+
__lt__ = _binary_method("lt", operator.lt)
34+
...
35+
36+
@dataclasses.dataclass
37+
class GridArray(GridArrayOperatorsMixin):
38+
39+
@classmethod
40+
def __torch_function__(self, ufunc, types, args=(), kwargs=None):
41+
...
42+
```
43+
`GridVariable` is implemented largely the same recycling the codes. Note that `GridVariable` is only a container for `GridArray` that wraps boundary conditions of a field in it. Whereas`GridArray`, arguably being more vital in the whole scheme, determines an array `v`'s location by its `offset` (cell center or faces, or nodes) by Jax-CFD's original design. After a detailed prompt introducing each class's functions, after reading my workspace, **Sonnet 3.7** suggested introducing only a single `GridVariable`, while performing binary methods of two fields with the same offsets, the boundary conditions will be set to `None` if they don't share the same bc. This is already the case for some flux computations in the original `Jax-CFD` but implemented in a more hard-coded way. Now the implementation is much more concise and the boundary condition for flux computation is handled in automatically.
44+
45+
#### Adding a GridVectorBase class
46+
Yet again, ***Claude Sonnet 3.7*** gave an awesome refactoring advice here. In `0.1.0`'s `grids.py`, the vector field's wrappers recycles lots of [boilerplate codes I learned from numpy back in 0.0.1](https://github.com/scaomath/torch-cfd/blob/475c7385549225570b61d8c3dcf1d415d8977f19/torch_cfd/grids.py#L801). There codes are largely the same for `GridArray` and `GridVariable` to define their behaviors when performing `__add__` and `__mul__` with a scalar, etc:
47+
```python
48+
class GridArrayVector(tuple):
49+
def __new__(cls, arrays):
50+
...
51+
52+
def __add__(self, other):
53+
...
54+
55+
__radd__ = __add__
56+
57+
class GridVariableVector(tuple):
58+
def __new__(cls, variables):
59+
...
60+
61+
def __add__(self, other):
62+
# largely the same
63+
...
64+
__radd__ = __add__
65+
```
66+
The refactored code by Sonnet 3.7 is just amazing by cleverly exploiting the `classmethod` decorator and `super()`:
67+
```python
68+
from typing import TypeVar
69+
70+
class GridVectorBase(tuple, Generic[TypeVar("T")]):
71+
72+
def __new__(cls, v: Sequence[T]):
73+
if not all(isinstance(x, cls._element_type()) for x in v):
74+
raise TypeError
75+
return super().__new__(cls, v)
76+
77+
@classmethod
78+
def _element_type(cls):
79+
raise NotImplementedError
80+
81+
def __add__(self, other):
82+
...
83+
__radd__ = __add__
84+
85+
86+
class GridVariableVector(GridVectorBase[GridVariable]):
87+
@classmethod
88+
def _element_type(cls):
89+
return GridVariable
90+
91+
```
92+
93+
#### Unittests
94+
Another great feat by ***Sonnet 3.7*** is coming up with unittests using `absl.testing`'s parametrized testing. Based on [`test_grids.py`](tests/test_grids.py) I ported and tweaked by-hand example-wise, Sonnet 3.7 generated [corresponding tests using finite differences](tests/test_finite_differences.py). Even though "reasoning" regarding numerical PDE is sometimes wrong, for example, coming up with what would be shape after trimming the boundary for MAC grids variables, most are correctly formulated and helped figure out several bugs regarding the batch implementation for finite volume method.
95+
96+
1097
### 0.1.0
1198
- Implemented the FVM method on a staggered MAC grid (pressure on cell centers).
1299
- Added native PyTorch implementation for applying `torch.linalg` and `torch.fft` functions directly on `GridArray` and `GridVariable`.

0 commit comments

Comments
 (0)