Skip to content

Commit d6968f2

Browse files
committed
Updated README and changelog.
1 parent edea2bc commit d6968f2

File tree

4 files changed

+368
-352
lines changed

4 files changed

+368
-352
lines changed

README.md

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@ This repository contains mainly two parts:
99

1010
### Part I: a native PyTorch port of [Google's Computational Fluid Dynamics package in Jax](https://github.com/google/jax-cfd)
1111
The main changes are documented in the `README.md` under the [`torch_cfd` directory](./torch_cfd/). The most significant changes in all routines include:
12-
- Routines that rely on the functional programming of Jax have been rewritten to be the PyTorch's tensor-in-tensor-out style, which is arguably more user-friendly to debugging as one can view intermediate tensors in VS Code debugger, set.
12+
- Supports for nonhomogenous boundary conditions, many routines in Jax-CFD only work with only periodic boundary.
13+
- Routines that rely on the functional programming of Jax have been rewritten to be the PyTorch's tensor-in-tensor-out style, which is arguably more user-friendly to debugging as one can view intermediate values in tensors in VS Code debugger, opposed to Jax's `JaxprTrace`.
14+
- All operations take into consideration the batch dimension of tensors `(b, *, n, m)` regardless of `*` dimension, for example, `(b, T, C, n, m)`, which is similar to PyTorch behavior. In the original Jax-CFD package, only a single trajectory is implemented. The stencil operators are changed to generally operate from the last dimension using negative indexing, following `torch.nn.functional.pad`'s behavior.
1315
- Functions and operators are in general implemented as `nn.Module` like a factory template.
14-
- Jax-cfd's `funcutils.trajectory` function supports tracking only one field variable (vorticity or velocity). For this port, extra fields computation and tracking are made more accessible, such as time derivatives $\partial_t\mathbf{u}_h$ and PDE residual $R(\mathbf{u}_h):=\mathbf{f}-\partial_t \mathbf{u}_h-(\mathbf{u}_h\cdot\nabla)\mathbf{u}_h + \nu \Delta \mathbf{u}_h$.
15-
- All ops take into consideration the batch dimension of tensors `(b, *, n, m)` regardless of `*` dimension, for example, `(b, T, C, n, m)`, which is similar to PyTorch behavior. In Google Research's original Jax-CFD package, only a single trajectory is implemented. The stencil operations generally starts from the last dimension using negative indexing, following `torch.nn.functional.pad`'s behavior.
16+
- Jax-CFD's `funcutils.trajectory` function supports tracking only one field variable (vorticity or velocity). in Torch-CFD, extra fields computation and tracking are more accessible and easier for user to add, such as time derivatives $\partial_t\mathbf{u}_h$ and PDE residual $R(\mathbf{u}_h):=\mathbf{f}-\partial_t \mathbf{u}_h-(\mathbf{u}_h\cdot\nabla)\mathbf{u}_h + \nu \Delta \mathbf{u}_h$.
17+
1618

1719
### Part II: Spectral-Refiner: Neural Operator-Assisted Navier-Stokes Equations simulator.
1820
- The **Spatiotemporal Fourier Neural Operator** (SFNO) is a spacetime tensor-to-tensor learner (or trajectory-to-trajectory), available in the [`fno` directory](./fno). Different components of FNO have been re-implemented keeping the conciseness of the original implementation while allowing modern expansions. We draw inspiration from the [3D FNO in Nvidia's Neural Operator repo](https://github.com/neuraloperator/neuraloperator), [Transformers-based neural operators](https://github.com/thuml/Neural-Solver-Library), as well as Temam's book on functional analysis for the NSE.
@@ -43,8 +45,9 @@ Data generation instructions are available in the [SFNO folder](./fno).
4345

4446
## Examples
4547
- Demos of different simulation setups:
46-
- [2D simulation with a pseudo-spectral solver](./examples/Kolmogrov2d_rk4_spectral_forced_turbulence.ipynb)
47-
- [2D simulation with a finite volume solver](./examples/Kolmogrov2d_rk4_fvm_forced_turbulence.ipynb)
48+
- [2D Lid-driven cavity with a random field perturbation using finite volume](./examples/Lid-driven_cavity_rk4_fvm.ipynb)
49+
- [2D decaying isotropic turbulence using the pseudo-spectral method](./examples/Kolmogrov2d_rk4_spectral_forced_turbulence.ipynb)
50+
- [2D Kolmogorov flow using finite volume method](./examples/Kolmogrov2d_rk4_fvm_forced_turbulence.ipynb)
4851
- Demos of Spatiotemporal FNO's training and evaluation using the neural operator-assisted fluid simulation pipelines
4952
- [Training of SFNO for only 15 epochs for the isotropic turbulence example](./examples/ex2_SFNO_train.ipynb)
5053
- [Training of SFNO for only ***10*** epochs with 1k samples and reach `1e-2` level of relative error](./examples/ex2_SFNO_train_fnodata.ipynb) using the data in the FNO paper, which to our best knowledge no operator learner can do this in <100 epochs in the small data regime.
@@ -60,8 +63,8 @@ The Apache 2.0 License in the root folder applies to the `torch-cfd` folder of t
6063
PR welcome. Currently, the port of `torch-cfd` currently includes:
6164
- The pseudospectral method for vorticity uses anti-aliasing filtering techniques for nonlinear terms to maintain stability.
6265
- The finite volume method on a MAC grid for velocity, and using the projection scheme to impose the divergence free condition.
63-
- Temporal discretization: Currently only RK4 temporal discretization uses explicit time-stepping for advection and either implicit or explicit time-stepping for diffusion.
64-
- Boundary conditions: only periodic boundary conditions.
66+
- Temporal discretization: Currently only RK4-family of marching schemes uses explicit time-stepping for advection, either implicit or explicit time-stepping for diffusion.
67+
- Boundary conditions: only periodic and Dirichlet boundary conditions.
6568

6669
## Reference
6770

examples/Kolmogrov2d_rk4_spectral_forced_turbulence.ipynb

Lines changed: 0 additions & 343 deletions
This file was deleted.

examples/McWilliams2d_rk4_spectral_istropic_turbulence.ipynb

Lines changed: 343 additions & 0 deletions
Large diffs are not rendered by default.

torch_cfd/README.md

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,26 @@
11
# TODO
22

3-
- [x] add native PyTorch implementation for applying `torch.linalg` and `torch.fft` function directly on `GridArray` and `GridVariable` (added 0.1.).
43
- [x] add discrete Helmholtz decomposition (pressure projection) in both spatial and spectral domains (added 0.0.1).
54
- [x] adjust the functions and routines to act on `(batch, time, *spatial)` tensor, currently only `(*spatial)` is supported (added for key routines in 0.0.1).
65
- [x] add native FFT-based vorticity computation, instead of taking finite differences for pseudo-spectral (added in 0.0.4).
7-
- [ ] add no-slip boundary.
6+
- [x] add native PyTorch implementation for applying `torch.linalg` and `torch.fft` function directly on `GridArray` and `GridVariable` (added 0.1.0).
7+
- [x] add no-slip boundary (added in 0.2.2).
8+
- [ ] support for function-valued boundary conditions.
89

910
# Changelog
1011

12+
### 0.2.3
13+
- Major fixes: Jax-CFD routines that does not work for non-homogeneous boundary conditions are rewritten:
14+
- removed wrapping $\partial v/\partial t$ with $v$'s boundary condition (`explicit_terms_with_same_bcs` routine, which is wrong for nonhomogeneous bcs).
15+
- changed `pad` function to work with tuple padding inputs.
16+
- fixed `pad` behavior on `offset==1.5` functions.
17+
- fixed `bc.pad_all` behavior.`
18+
- added a [`_symmetric_pad_tensor`](./grids.py#1348) function to match the behavior of `np.pad` with mode `symmetric` (symmetric padding across the boundary, not just mirror).
19+
- changd the behavior of `pad_and_impose_bc` in `BoundaryCondition` class to correctly impose bc when ghost cells have to be presented, and added some tests.
20+
- advection module is completely refactored to as `nn.Module`, tests added for advection.
21+
- added `.norm` property and `__getitem__` for a `GridVariable`.
22+
- added `__repr__` for `Grid` and `GridVariable` for neater format when being printed.
23+
1124
### 0.2.0
1225

1326
After version `0.1.0`, I began prompt with existing codes in VSCode Copilot (using the OpenAI Enterprise API kindedly provided by UM), which arguably significantly improve the "porting->debugging->refactoring" cycle. I recorded some several good refactoring suggestions by GPT o4-mini and some by ***Claude Sonnet 3.7*** here. There were definitely over-complicated "poor" refactoring suggestions, which have been stashed after benchmarking. I found that Sonnet 3.7 is exceptionally good at providing templates for me to filling the details, when it is properly prompted with details of the functionality of current codes. Another highlight is that, based on the error or exception raised in the unittests, Sonnet 3.7 directly added configurations in `.vscode/launch.json`, saving me quite some time of copy-paste boilerplates then change by hand.

0 commit comments

Comments
 (0)