Skip to content

Commit ffd636b

Browse files
authored
Merge pull request #2 from scaomath/0.2.0-dev
* Refactor the `grids` and `finite_difference` for batched computation. * Refactored the FVM to have batched computation.
2 parents 475c738 + cc36c16 commit ffd636b

20 files changed

+3493
-3043
lines changed

examples/Kolmogrov2d_rk4_fvm_forced_turbulence.ipynb

Lines changed: 78 additions & 68 deletions
Large diffs are not rendered by default.

examples/Kolmogrov2d_rk4_cn_forced_turbulence.ipynb renamed to examples/Kolmogrov2d_rk4_spectral_forced_turbulence.ipynb

Lines changed: 42 additions & 47 deletions
Large diffs are not rendered by default.

fno/data_gen/data_gen_Kolmogorov2d.py

Lines changed: 52 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,16 @@
1111

1212
import torch
1313
import torch.fft as fft
14+
from torch_cfd.finite_differences import curl_2d
15+
from torch_cfd.forcings import KolmogorovForcing
1416

15-
from torch_cfd.grids import *
17+
from torch_cfd.grids import Grid
18+
from torch_cfd.initial_conditions import filtered_velocity_field
1619
from torch_cfd.equations import *
17-
from torch_cfd.initial_conditions import *
18-
from torch_cfd.finite_differences import *
19-
from torch_cfd.forcings import *
20+
from data_gen.solvers import get_trajectory_imex
2021

2122
from data_utils import *
2223

23-
from solvers import get_trajectory_imex
24-
2524
from fno.pipeline import DATA_PATH, LOG_PATH
2625

2726

@@ -38,6 +37,9 @@ def main(args):
3837
3938
Testing dataset for plotting the enstrohpy spectrum:
4039
>>> python data_gen_Kolmogorov2d.py --num-samples 16 --batch-size 8 --grid-size 256 --subsample 1 --visc 1e-3 --dt 1e-3 --time 10 --time-warmup 4.5 --num-steps 100 --diam "2*torch.pi" --double
40+
41+
Testing if the data generation works:
42+
>>> python data_gen_Kolmogorov2d.py --num-samples 4 --batch-size 2 --grid-size 256 --subsample 1 --visc 1e-3 --dt 1e-3 --time 10 --time-warmup 4.5 --num-steps 100 --diam "2*torch.pi" --double --demo
4143
"""
4244
args = args.parse_args()
4345

@@ -47,29 +49,29 @@ def main(args):
4749
log_filename = os.path.join(LOG_PATH, f"{current_time}_{log_name}.log")
4850
logger = get_logger(log_filename)
4951

50-
logger.info(f"Using the following arguments: ")
51-
all_args = {k: v for k, v in vars(args).items() if not callable(v)}
52-
logger.info(" | ".join(f"{k}={v}" for k, v in all_args.items()))
53-
5452
total_samples = args.num_samples
5553
batch_size = args.batch_size # 128
54+
assert batch_size <= total_samples, "batch_size <= num_samples"
55+
assert total_samples % batch_size == 0, "total_samples divisible by batch_size"
5656
n = args.grid_size # 256
57-
scale = args.scale
58-
viscosity = args.visc
57+
viscosity = args.visc if args.Re is None else 1 / args.Re
58+
Re = 1 / viscosity
5959
dt = args.dt # 1e-3
6060
T = args.time # 10
61-
T_warmup = args.time_warmup # 4.5
62-
num_snapshots = args.num_steps # 100
6361
subsample = args.subsample # 4
6462
ns = n // subsample
63+
scale = args.scale # 1
64+
T_warmup = args.time_warmup # 4.5
65+
num_snapshots = args.num_steps # 100
6566
random_state = args.seed
6667
peak_wavenumber = args.peak_wavenumber # 4
67-
diam = args.diam # "2 * torch.pi" default
68-
diam = eval(diam) if isinstance(diam, str) else diam #
68+
diam = (
69+
eval(args.diam) if isinstance(args.diam, str) else args.diam
70+
) # "2 * torch.pi"
6971
force_rerun = args.force_rerun
7072

7173
logger = logging.getLogger()
72-
logger.info(f"Generating data for Kolmogorov 2d flow with {total_samples} samples")
74+
logger.info(f"Generating data for Kolmogorov2d flow with {total_samples} samples")
7375

7476
max_velocity = args.max_velocity # 5
7577
dt = stable_time_step(diam / n, dt, max_velocity, viscosity=viscosity)
@@ -80,46 +82,37 @@ def main(args):
8082
record_every_iters = int(total_steps / num_snapshots)
8183

8284
dtype = torch.float64 if args.double else torch.float32
85+
cdtype = torch.complex128 if args.double else torch.complex64
8386
dtype_str = "_fp64" if args.double else ""
8487
filename = args.filename
8588
if filename is None:
86-
filename = f"Kolmogorov2d{dtype_str}_{ns}x{ns}_N{total_samples}_v{viscosity:.0e}_T{num_snapshots}.pt".replace(
87-
"e-0", "e-"
88-
)
89+
filename = f"Kolmogorov2d{dtype_str}_{ns}x{ns}_N{total_samples}_Re{int(Re)}_T{num_snapshots}.pt"
8990
args.filename = filename
9091
data_filepath = os.path.join(DATA_PATH, filename)
91-
data_exist = os.path.exists(data_filepath)
92-
if data_exist and not force_rerun:
93-
logger.info(f"File {filename} exists with current data as follows:")
94-
data = torch.load(data_filepath)
95-
96-
for key, v in data.items():
97-
if isinstance(v, torch.Tensor):
98-
logger.info(f"{key:<12} | {v.shape} | {v.dtype}")
99-
else:
100-
logger.info(f"{key:<12} | {v.dtype}")
101-
if len(data[key]) == total_samples:
102-
return
103-
elif len(data[key]) < total_samples:
104-
total_samples -= len(data[key])
92+
if os.path.exists(data_filepath) and not force_rerun:
93+
logger.info(f"Data already exists at {data_filepath}")
94+
return
95+
elif os.path.exists(data_filepath) and force_rerun:
96+
logger.info(f"Force rerun and save data to {data_filepath}")
97+
os.remove(data_filepath)
10598
else:
106-
logger.info(f"Generating data and saving in {filename}")
99+
logger.info(f"Save data to {data_filepath}")
107100

108101
cuda = not args.no_cuda and torch.cuda.is_available()
109102
no_tqdm = args.no_tqdm
110103
device = torch.device("cuda:0" if cuda else "cpu")
111104

112105
torch.set_default_dtype(torch.float64)
113106
logger.info(
114-
f"Using device: {device} | save dtype: {dtype} | computge dtype: {torch.get_default_dtype()}"
107+
f"Using device: {device} | save dtype: {dtype} | compute dtype: {torch.get_default_dtype()}"
115108
)
116109

117110
grid = Grid(shape=(n, n), domain=((0, diam), (0, diam)), device=device)
118111

119112
forcing_fn = KolmogorovForcing(
120113
grid=grid,
121114
scale=scale,
122-
k=peak_wavenumber,
115+
wave_number=peak_wavenumber,
123116
swap_xy=False,
124117
)
125118

@@ -158,45 +151,47 @@ def main(args):
158151
for j in range(warmup_steps):
159152
vort_hat, _ = ns2d.step(vort_hat, dt)
160153
if j % 100 == 0:
161-
vort_norm = torch.linalg.norm(fft.irfft2(vort_hat)).item() / n
162-
desc = (
163-
datetime.now().strftime("%d-%b-%Y %H:%M:%S")
164-
+ f" - Warmup | vort_hat ell2 norm {vort_norm:.4e}"
165-
)
154+
vort_norm = torch.linalg.norm(fft.irfft2(vort_hat)).item()/n
155+
desc = datetime.now().strftime("%d-%b-%Y %H:%M:%S") + f" - Warmup | vort_hat ell2 norm {vort_norm:.4e}"
166156
pbar.set_description(desc)
167157
pbar.update(100)
168-
logger.info(f"generate data from {T_warmup} to {T}")
158+
169159
result = get_trajectory_imex(
170160
ns2d,
171161
vort_hat,
172162
dt,
173163
num_steps=total_steps,
174164
record_every_steps=record_every_iters,
175165
pbar=not no_tqdm,
166+
dtype=cdtype,
176167
)
177168

178169
for field, value in result.items():
170+
logger.info(
171+
f"freq variable: {field:<12} | shape: {value.shape} | dtype: {value.dtype}"
172+
)
179173
value = fft.irfft2(value).real.cpu().to(dtype)
180174
logger.info(
181-
f"variable: {field} | shape: {value.shape} | dtype: {value.dtype}"
175+
f"saved variable: {field:<12} | shape: {value.shape} | dtype: {value.dtype}"
182176
)
183177
if subsample > 1:
184-
assert (
185-
value.ndim == 4
186-
), f"Subsampling only works for (b, c, h, w) tensors, current shape: {value.shape}"
187-
value = F.interpolate(value, size=(ns, ns), mode="bilinear")
188-
result[field] = value
178+
result[field] = F.interpolate(value, size=(ns, ns), mode="bilinear")
179+
else:
180+
result[field] = value
189181

190182
result["random_states"] = torch.tensor(
191183
[random_state + idx + k for k in range(batch_size)], dtype=torch.int32
192184
)
193185
logger.info(f"Saving batch [{i+1}/{num_batches}] to {data_filepath}")
194-
save_pickle(result, data_filepath)
195-
del result
196-
197-
pickle_to_pt(data_filepath)
198-
logger.info(f"Done saving.")
199-
if args.demo_plots:
186+
if not args.demo:
187+
save_pickle(result, data_filepath, append=True)
188+
del result
189+
190+
191+
if not args.demo:
192+
pickle_to_pt(data_filepath)
193+
logger.info(f"Done saving.")
194+
else:
200195
try:
201196
verify_trajectories(
202197
data_filepath,
@@ -205,7 +200,8 @@ def main(args):
205200
n_samples=1,
206201
)
207202
except Exception as e:
208-
logger.error(f"Error in plotting: {e}")
203+
logger.error(f"Error in plotting sample trajectories: {e}")
204+
return 0
209205

210206

211207
if __name__ == "__main__":

fno/data_gen/data_gen_McWilliams2d.py

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,14 @@
1212
import torch
1313
import torch.fft as fft
1414

15-
from torch_cfd.grids import *
15+
from torch_cfd.grids import Grid
16+
from torch_cfd.initial_conditions import vorticity_field
1617
from torch_cfd.equations import *
17-
from torch_cfd.initial_conditions import *
18-
from torch_cfd.finite_differences import *
19-
from torch_cfd.forcings import *
2018

21-
from data_utils import *
2219
from solvers import get_trajectory_imex
20+
from data_utils import *
21+
22+
import logging
2323

2424
from fno.pipeline import DATA_PATH, LOG_PATH
2525

@@ -39,6 +39,9 @@ def main(args):
3939
Training dataset with Re=5k:
4040
>>> python data_gen_McWilliams2d.py --num-samples 1152 --batch-size 128 --grid-size 512 --subsample 1 --Re 5e3 --dt 5e-4 --time 10 --time-warmup 4.5 --num-steps 100 --diam "2*torch.pi"
4141
42+
Demo dataset to test if the data generation works:
43+
>>> python data_gen_McWilliams2d.py --num-samples 4 --batch-size 2 --grid-size 256 --subsample 1 --visc 1e-3 --dt 1e-3 --time 10 --time-warmup 4.5 --num-steps 100 --diam "2*torch.pi" --double --demo
44+
4245
"""
4346
args = args.parse_args()
4447

@@ -84,6 +87,9 @@ def main(args):
8487
dtype_str = "_fp64" if args.double else ""
8588
filename = args.filename
8689
if filename is None:
90+
# filename = f"McWilliams2d{dtype_str}_{ns}x{ns}_N{total_samples}_v{viscosity:.0e}_T{num_snapshots}.pt".replace(
91+
# "e-0", "e-"
92+
# )
8793
filename = f"McWilliams2d{dtype_str}_{ns}x{ns}_N{total_samples}_Re{int(Re)}_T{num_snapshots}.pt"
8894
args.filename = filename
8995
data_filepath = os.path.join(DATA_PATH, filename)
@@ -137,7 +143,11 @@ def main(args):
137143
for j in range(warmup_steps):
138144
vort_hat, _ = ns2d.step(vort_hat, dt)
139145
if j % 100 == 0:
140-
desc = datetime.now().strftime("%d-%b-%Y %H:%M:%S") + " - Warmup"
146+
vort_norm = torch.linalg.norm(fft.irfft2(vort_hat)).item() / n
147+
desc = (
148+
datetime.now().strftime("%d-%b-%Y %H:%M:%S")
149+
+ f" - Warmup | vort_hat ell2 norm {vort_norm:.4e}"
150+
)
141151
pbar.set_description(desc)
142152
pbar.update(100)
143153

@@ -167,25 +177,28 @@ def main(args):
167177
result["random_states"] = torch.tensor(
168178
[random_state + idx + k for k in range(batch_size)], dtype=torch.int32
169179
)
170-
logger.info(f"Saving batch [{i+1}/{num_batches}] to {data_filepath}")
171-
save_pickle(result, data_filepath)
172-
del result
180+
if not args.demo:
181+
save_pickle(result, data_filepath, append=True)
182+
del result
173183

174-
pickle_to_pt(data_filepath)
175-
logger.info(f"Done saving.")
176-
if args.demo_plots:
184+
if not args.demo:
185+
pickle_to_pt(data_filepath)
186+
logger.info(f"Done saving.")
187+
else:
177188
try:
178189
verify_trajectories(
179-
data_filepath,
190+
result,
180191
dt=record_every_iters * dt,
181192
T_warmup=T_warmup,
182193
n_samples=1,
183194
)
184195
except Exception as e:
185-
logger.error(f"Error in plotting: {e}")
196+
logger.error(f"Error in plotting sample trajectories: {e}")
186197
return 0
187198

188199

189200
if __name__ == "__main__":
190-
args = get_args_ns2d("Meta parameters for generating NSE 2d with McWilliams IV")
201+
args = get_args_ns2d(
202+
"Parameters for generating NSE 2d with McWilliams 2d example"
203+
)
191204
main(args)

fno/data_gen/data_gen_fno.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,21 @@
77

88
# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
99

10-
import argparse
11-
import math
1210
import os
13-
from functools import partial
1411

1512
import torch
1613
import torch.fft as fft
1714
import torch.nn.functional as F
1815

16+
from torch_cfd.grids import Grid
17+
from torch_cfd.equations import *
18+
from torch_cfd.forcings import SinCosForcing
19+
1920
from grf import GRF2d
2021
from solvers import get_trajectory_imex
2122
from data_utils import *
22-
from torch_cfd.grids import *
23-
from torch_cfd.equations import *
24-
from torch_cfd.forcings import *
23+
import logging
24+
2525
from fno.pipeline import DATA_PATH, LOG_PATH
2626

2727

@@ -47,7 +47,7 @@ def main(args):
4747
>>> python data_gen_fno.py --num-samples 2 --batch-size 1 --grid-size 512 --subsample 1 --double --extra-vars --time 50 --time-warmup 30 --num-steps 200 --dt 5e-4 --scale 0.1 --replicable-init --seed 42
4848
4949
- Testing if the code works
50-
>>> python data_gen/data_gen_fno.py --num-samples 4 --batch-size 2 --grid-size 128 --subsample 1 --double --extra-vars --time 2 --time-warmup 1 --num-steps 10 --dt 1e-3 --scale 0.1 --replicable-init --seed 42
50+
>>> python data_gen/data_gen_fno.py --num-samples 4 --batch-size 2 --grid-size 128 --subsample 1 --double --extra-vars --time 2 --time-warmup 1 --num-steps 10 --dt 1e-3 --scale 0.1 --replicable-init --seed 42 --demo
5151
5252
"""
5353

@@ -168,12 +168,16 @@ def main(args):
168168
device=device,
169169
dtype=torch.float64,
170170
)
171-
step_fn = IMEXStepper(order=2)
171+
172+
step_fn = IMEXStepper(order=2, requires_grad=False)
173+
172174
ns2d = NavierStokes2DSpectral(
173175
viscosity=visc,
174176
grid=grid,
175177
smooth=True,
176178
forcing_fn=forcing_fn,
179+
solver=step_fn,
180+
order=2,
177181
).to(device)
178182

179183
if os.path.exists(data_filepath) and not force_rerun:
@@ -248,24 +252,24 @@ def main(args):
248252
result["random_states"] = torch.as_tensor(seeds, dtype=torch.int32)
249253

250254
logger.info(f"Saving batch [{i+1}/{num_batches}] to {data_filepath}")
251-
save_pickle(result, data_filepath)
252-
del result
255+
if not args.demo:
256+
save_pickle(result, data_filepath, append=True)
257+
del result
253258

254-
pickle_to_pt(data_filepath)
255-
logger.info(f"Done converting to pt.")
256-
if args.demo_plots:
259+
if not args.demo:
260+
pickle_to_pt(data_filepath)
261+
logger.info(f"Done saving.")
262+
else:
257263
try:
258264
verify_trajectories(
259-
data_filepath,
260-
dt=T_new / record_steps,
265+
result,
266+
dt=record_every_iters * dt,
261267
T_warmup=T_warmup,
262268
n_samples=1,
263269
)
264270
except Exception as e:
265-
logger.error(f"Error in plotting: {e}")
266-
finally:
267-
pass
268-
return
271+
logger.error(f"Error in plotting sample trajectories: {e}")
272+
return 0
269273

270274

271275
if __name__ == "__main__":

0 commit comments

Comments
 (0)