Skip to content

Commit 808f874

Browse files
authored
Merge pull request #21 from tumaer/upgrade_to_jax_062
Upgrade to jax>=0.6.2
2 parents 4c6f978 + 090bdea commit 808f874

23 files changed

+2541
-1498
lines changed

.github/workflows/publish.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@ jobs:
1010
runs-on: ubuntu-latest
1111

1212
steps:
13-
- uses: actions/checkout@v4
13+
- uses: actions/checkout@v5
1414
- name: Set up Python
15-
uses: actions/setup-python@v4
15+
uses: actions/setup-python@v6
1616
with:
17-
python-version: '3.10'
17+
python-version: '3.12'
1818
- name: Install Poetry
1919
run: |
2020
python -m pip install --upgrade pip

.github/workflows/ruff.yml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,7 @@ jobs:
44
ruff:
55
runs-on: ubuntu-latest
66
steps:
7-
- uses: actions/checkout@v3
8-
- uses: chartboost/ruff-action@v1
7+
- uses: actions/checkout@v5
8+
- uses: astral-sh/ruff-action@v3
9+
- run: ruff check --fix
10+
- run: ruff format

.github/workflows/tests.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@ jobs:
1515
runs-on: ubuntu-latest
1616
strategy:
1717
matrix:
18-
python-version: ["3.10"] # Add more Python versions here, e.g. "3.9", "3.11"
18+
python-version: ["3.10", "3.12"]
1919

2020
steps:
21-
- uses: actions/checkout@v4
21+
- uses: actions/checkout@v5
2222
- name: Set up Python ${{ matrix.python-version }}
23-
uses: actions/setup-python@v4
23+
uses: actions/setup-python@v6
2424
with:
2525
python-version: ${{ matrix.python-version }}
2626
- name: Install Poetry
@@ -35,7 +35,7 @@ jobs:
3535
run: |
3636
.venv/bin/pytest --cov-report=xml
3737
- name: Upload coverage report to Codecov
38-
uses: codecov/codecov-action@v4
38+
uses: codecov/codecov-action@v5
3939
with:
4040
token: ${{ secrets.CODECOV_TOKEN }}
4141
file: ./coverage.xml

.pre-commit-config.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ exclude: |
88
)$
99
repos:
1010
- repo: https://github.com/pre-commit/pre-commit-hooks
11-
rev: v4.4.0
11+
rev: v6.0.0
1212
hooks:
1313
- id: check-merge-conflict
1414
- id: check-added-large-files
@@ -18,8 +18,8 @@ repos:
1818
- id: check-yaml
1919
- id: requirements-txt-fixer
2020
- repo: https://github.com/astral-sh/ruff-pre-commit
21-
rev: 'v0.1.8'
21+
rev: 'v0.14.0'
2222
hooks:
23-
- id: ruff
23+
- id: ruff-check
2424
args: [ --fix ]
2525
- id: ruff-format

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,15 @@
3131
Install the `jax-sph` library from PyPI as
3232

3333
```bash
34-
python3.10 -m venv venv
34+
python3.12 -m venv venv
3535
source venv/bin/activate
3636
pip install jax-sph
3737
```
3838

3939
By default `jax-sph` is installed without GPU support. If you have a CUDA-capable GPU, follow the instructions in the [GPU support](#gpu-support) section.
4040

4141
### Clone
42-
We recommend using a `poetry` or `python3-venv` environment.
42+
We recommend using a `poetry` or `python3-venv` environment. Last tested with python3.10 + jax[cpu]==0.6.2 and python3.12 + jax[cpu]==0.7.2.
4343

4444
**Using Poetry**
4545
```bash
@@ -51,7 +51,7 @@ Later, you just need to `source .venv/bin/activate` to activate the environment.
5151

5252
**Using `python3-venv`**
5353
```bash
54-
python3 -m venv venv
54+
python3.12 -m venv venv
5555
source venv/bin/activate
5656
pip install -r requirements.txt
5757
pip install -e . # to install jax_sph in interactive mode
@@ -62,7 +62,7 @@ Later, you just need to `source venv/bin/activate` to activate the environment.
6262
If you want to use a CUDA GPU, you first need a running Nvidia driver. And then just follow the instructions [here](https://jax.readthedocs.io/en/latest/installation.html). The whole process could look like this:
6363
```bash
6464
source .venv/bin/activate
65-
pip install -U "jax[cuda12]==0.4.29"
65+
pip install -U "jax[cuda12]" # specify version as "jax[cuda12]==0.6.2"
6666
```
6767

6868
## Getting Started

cases/ht.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"Heat Transfer over a flat plate setup"
22

3-
43
import jax.numpy as jnp
54
import numpy as np
65
from omegaconf import DictConfig

cases/ldc.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Lid-driven cavity case setup"""
22

3-
43
import jax.numpy as jnp
54
import numpy as np
65
from omegaconf import DictConfig

cases/tgv.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Taylor-Green case setup"""
22

3-
43
import jax.numpy as jnp
54
import numpy as np
65
from omegaconf import DictConfig

docs/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
absl-py>=2.1.0
22
h5py
3-
jax[cpu]==0.4.29
3+
jax[cpu]==0.6.2
44
jraph>=0.0.6.dev0
55
omegaconf
66
pandas

jax_sph/case_setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def initialize(self):
191191
assert state[k][mask].shape == _state[k][_mask].shape, ValueError(
192192
f"Shape mismatch for key {k} in state0 file."
193193
)
194-
state[k][mask] = _state[k][_mask]
194+
state[k] = state[k].at[mask].set(_state[k][_mask])
195195

196196
# the following arguments are needed for dataset generation
197197
cfg.case.c_ref, cfg.case.p_ref, cfg.case.p_bg = c_ref, p_ref, p_bg

0 commit comments

Comments
 (0)