Skip to content

Commit c5f842b

Browse files
committed
Update READMEs
1 parent d1cc89c commit c5f842b

File tree

2 files changed

+35
-6
lines changed

2 files changed

+35
-6
lines changed

README.md

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
# Torch CFD
1+
# Computational Fluid Dynamics in PyTorch
22

3-
This is a native PyTorch port of [Google's Computational Fluid Dynamics package in Jax](https://github.com/google/jax-cfd). The main changes are documented in the `README.md` under the [`torch_cfd` directory](torch_cfd/README.md). The biggest change is many routines that rely on the functional programming of Jax have been rewritten to be a more PyTorch-friendly tensor-in to tensor-out style.
3+
A native PyTorch port of [Google's Computational Fluid Dynamics package in Jax](https://github.com/google/jax-cfd). The main changes are documented in the `README.md` under the [`torch_cfd` directory](torch_cfd). Biggest changes in all routines:
4+
- Routines that rely on the functional programming of Jax have been rewritten to be a more debugger-friendly PyTorch tensor-in-tensor-out style.
5+
- Functions and operators are in general implemented as `nn.Module`.
46

57
## Installation
68

@@ -9,11 +11,14 @@ pip install torch-cfd
911
```
1012

1113
## Contributions
12-
PR welcome. Current the port only includes:
14+
PR welcome. Currently, the port only includes:
1315
- Pseudospectral methods for vorticity which use anti-aliasing filtering techniques for non-linear terms to maintain stability.
14-
- Temporal discretization: Currently only RK4 temporal discretization, using explicit time-stepping for advection and either implicit or explicit time-stepping for diffusion.
16+
- Temporal discretization: Currently only RK4 temporal discretization using explicit time-stepping for advection and either implicit or explicit time-stepping for diffusion.
1517
- Boundary conditions: only periodic boundary conditions.
1618

1719
## Examples
1820
- Demos of different simulation setups:
19-
- [2D simulation with a psuedo-spectral solver](example_Kolmogrov2d_rk4_cn_forced_turbulence.ipynb)
21+
- [2D simulation with a pseudo-spectral solver](example_Kolmogrov2d_rk4_cn_forced_turbulence.ipynb)
22+
23+
## Acknowledgments
24+
SC appreciates the support from the National Science Foundation DMS-2309778.

torch_cfd/README.md

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,31 @@
33
- [ ] add native PyTorch implementation for applying `torch.linalg` and `torch.fft` function directly on `GridArray`.
44
- [ ] add discrete Helmholtz decomposition in both spatial and spectral domains.
55
- [ ] adjust the function to act on `(batch, time, *spatial)` tensor, currently only `(*spatial)` is supported.
6+
- [x] add native vorticity computation, instead of taking FDM for pseudo-spectral.
67

78
## Changelog
9+
10+
### 0.0.4
11+
- The forcing functions are now implemented as `nn.Module` and utilize a wrapper decorator for the potential function.
12+
- Added some common time stepping schemes, additional ones that Jax-CFD did not have includes the commonly used Crank-Nicholson IMEX.
13+
- Combined the implementation for step size satisfying the CFL condition.
14+
15+
16+
### 0.0.1
817
- `grids.GridArray` is implemented as a subclass of `torch.Tensor`, not the original jax implentation uses the inheritance from `np.lib.mixins.NDArrayOperatorsMixin`. `__array_ufunc__()` is replaced by `__torch_function__()`.
9-
- The padding of `torch.nn.functional.pad()` is different from `jax.numpy.pad()`, PyTorch's pad starts from the last dimension, while Jax's pad starts from the first dimension. For example, `F.pad(x, (0, 0, 1, 0, 1, 1))` is equivalent to `jax.numpy.pad(x, ((1, 1), (1, 0), (0, 0)))` for an array of size `(*, t, h, w)`.
18+
- The padding of `torch.nn.functional.pad()` is different from `jax.numpy.pad()`, PyTorch's pad starts from the last dimension, while Jax's pad starts from the first dimension. For example, `F.pad(x, (0, 0, 1, 0, 1, 1))` is equivalent to `jax.numpy.pad(x, ((1, 1), (1, 0), (0, 0)))` for an array of size `(*, t, h, w)`.
19+
- A handy outer sum, which is usefully in getting the n-dimensional Laplacian in the frequency domain, is implemented as follows to replace `reduce(np.add.outer, eigenvalues)`
20+
```python
21+
def outer_sum(x: Union[List[Array], Tuple[Array]]) -> Array:
22+
"""
23+
Returns the outer sum of a list of one dimensional arrays
24+
Example:
25+
x = [a, b, c]
26+
out = a[..., None, None] + b[..., None] + c
27+
"""
28+
29+
def _sum(a, b):
30+
return a[..., None] + b
31+
32+
return reduce(_sum, x)
33+
```

0 commit comments

Comments
 (0)