Skip to content

Commit 356f6a6

Browse files
committed
Refactor equations' __call__ to nn.Module
1 parent e898a2a commit 356f6a6

File tree

7 files changed

+535
-182
lines changed

7 files changed

+535
-182
lines changed

fno/data_gen/data_gen_Kolmogorov2d.py

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -51,22 +51,25 @@ def main(args):
5151
log_filename = os.path.join(LOG_PATH, f"{current_time}_{log_name}.log")
5252
logger = get_logger(log_filename)
5353

54+
logger.info(f"Using the following arguments: ")
55+
all_args = {k: v for k, v in vars(args).items() if not callable(v)}
56+
logger.info(" | ".join(f"{k}={v}" for k, v in all_args.items()))
57+
5458
total_samples = args.num_samples
5559
batch_size = args.batch_size # 128
5660
n = args.grid_size # 256
5761
scale = args.scale
5862
viscosity = args.visc
5963
dt = args.dt # 1e-3
6064
T = args.time # 10
61-
subsample = args.subsample # 4
62-
ns = n // subsample
6365
T_warmup = args.time_warmup # 4.5
6466
num_snapshots = args.num_steps # 100
67+
subsample = args.subsample # 4
68+
ns = n // subsample
6569
random_state = args.seed
6670
peak_wavenumber = args.peak_wavenumber # 4
67-
diam = (
68-
eval(args.diam) if isinstance(args.diam, str) else args.diam
69-
) # "2 * torch.pi"
71+
diam = args.diam # "2 * torch.pi" default
72+
diam = eval(diam) if isinstance(diam, str) else diam #
7073
force_rerun = args.force_rerun
7174

7275
logger = logging.getLogger()
@@ -89,21 +92,29 @@ def main(args):
8992
)
9093
args.filename = filename
9194
data_filepath = os.path.join(DATA_PATH, filename)
92-
if os.path.exists(data_filepath) and not force_rerun:
93-
logger.info(f"Data already exists at {data_filepath}")
94-
return
95-
elif os.path.exists(data_filepath) and force_rerun:
96-
logger.info(f"Force rerun and save data to {data_filepath}")
97-
os.remove(data_filepath)
95+
data_exist = os.path.exists(data_filepath)
96+
if data_exist and not force_rerun:
97+
logger.info(f"File {filename} exists with current data as follows:")
98+
data = torch.load(data_filepath)
99+
100+
for key, v in data.items():
101+
if isinstance(v, torch.Tensor):
102+
logger.info(f"{key:<12} | {v.shape} | {v.dtype}")
103+
else:
104+
logger.info(f"{key:<12} | {v.dtype}")
105+
if len(data[key]) == total_samples:
106+
return
107+
elif len(data[key]) < total_samples:
108+
total_samples -= len(data[key])
98109
else:
99-
logger.info(f"Save data to {data_filepath}")
110+
logger.info(f"Generating data and saving in {filename}")
100111

101112
cuda = not args.no_cuda and torch.cuda.is_available()
102113
no_tqdm = args.no_tqdm
103114
device = torch.device("cuda:0" if cuda else "cpu")
104115

105-
torch.set_default_dtype(dtype)
106-
logger.info(f"Using device: {device} | dtype: {dtype}")
116+
torch.set_default_dtype(torch.float64)
117+
logger.info(f"Using device: {device} | save dtype: {dtype} | computge dtype: {torch.get_default_dtype()}")
107118

108119
grid = Grid(shape=(n, n), domain=((0, diam), (0, diam)), device=device)
109120

@@ -120,7 +131,7 @@ def main(args):
120131
drag=0.1,
121132
smooth=True,
122133
forcing_fn=forcing_fn,
123-
solver=RK4CrankNicholson,
134+
solver=RK4CrankNicolsonStepper,
124135
).to(device)
125136

126137
num_batches = total_samples // batch_size
@@ -156,8 +167,8 @@ def main(args):
156167
)
157168
pbar.set_description(desc)
158169
pbar.update(100)
159-
160-
result = get_trajectory_rk4(
170+
logger.info(f"generate data from {T_warmup} to {T}")
171+
result = get_trajectory_imex(
161172
ns2d,
162173
vort_hat,
163174
dt,
@@ -172,9 +183,9 @@ def main(args):
172183
f"variable: {field} | shape: {value.shape} | dtype: {value.dtype}"
173184
)
174185
if subsample > 1:
175-
result[field] = F.interpolate(value, size=(ns, ns), mode="bilinear")
176-
else:
177-
result[field] = value
186+
assert value.ndim == 4, f"Subsampling only works for (b, c, h, w) tensors, current shape: {value.shape}"
187+
value = F.interpolate(value, size=(ns, ns), mode="bilinear")
188+
result[field] = value
178189

179190
result["random_states"] = torch.tensor(
180191
[random_state + idx + k for k in range(batch_size)], dtype=torch.int32
@@ -198,5 +209,5 @@ def main(args):
198209

199210

200211
if __name__ == "__main__":
201-
args = get_args("Params Kolmogorov 2d flow data generation")
212+
args = get_args_ns2d("Params Kolmogorov 2d flow data generation")
202213
main(args)

fno/data_gen/data_gen_McWilliams2d.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def main(args):
117117
drag=0,
118118
smooth=True,
119119
forcing_fn=None,
120-
solver=RK4CrankNicholson,
120+
solver=RK4CrankNicolsonStepper,
121121
).to(device)
122122

123123
num_batches = total_samples // batch_size
@@ -145,7 +145,7 @@ def main(args):
145145
pbar.set_description(desc)
146146
pbar.update(100)
147147

148-
result = get_trajectory_rk4(
148+
result = get_trajectory_imex(
149149
ns2d,
150150
vort_hat,
151151
dt,
@@ -191,5 +191,5 @@ def main(args):
191191

192192

193193
if __name__ == "__main__":
194-
args = get_args("Meta parameters for generating NSE 2d with McWilliams IV")
194+
args = get_args_ns2d("Meta parameters for generating NSE 2d with McWilliams IV")
195195
main(args)

fno/data_gen/data_gen_fno.py

Lines changed: 87 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,12 @@
1+
# The MIT License (MIT)
2+
# Copyright © 2025 Shuhao Cao
3+
4+
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
5+
6+
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
7+
8+
# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
9+
110
import argparse
211
import math
312
import os
@@ -10,6 +19,9 @@
1019
from grf import GRF2d
1120
from solvers import *
1221
from data_utils import *
22+
from torch_cfd.grids import *
23+
from torch_cfd.equations import *
24+
from torch_cfd.forcings import *
1325
from fno.pipeline import DATA_PATH, LOG_PATH
1426

1527
def main(args):
@@ -25,13 +37,16 @@ def main(args):
2537
Sample usage:
2638
2739
- Training data for Spectral-Refiner ICLR 2025 paper 'fnodata_extra_64x64_N1280_v1e-3_T50_steps100_alpha2.5_tau7.pt'
28-
>>> python data_gen_fno.py --num-samples 1280 --batch-size 256 --grid-size 256 --subsample 4 --extra-vars --time 50 --time-warmup 30 --num-steps 100 --dt 1e-3 --visc 1e-3
40+
>>> python data_gen_fno.py --num-samples 1280 --batch-size 256 --grid-size 256 --subsample 4 --extra-vars --time 50 --time-warmup 30 --num-steps 100 --dt 1e-3 --visc 1e-3 --scale 0.1
2941
3042
- Test data
31-
>>> python data_gen_fno.py --num-samples 16 --batch-size 8 --grid-size 256 --subsample 1 --double --extra-vars --time 50 --time-warmup 30 --num-steps 100 --dt 1e-3 --replicable-init --seed 42
43+
>>> python data_gen_fno.py --num-samples 16 --batch-size 8 --grid-size 256 --subsample 1 --double --extra-vars --time 50 --time-warmup 30 --num-steps 100 --dt 1e-3 --scale 0.1 --replicable-init --seed 42
3244
3345
- Test data fine
34-
>>> python data_gen_fno.py --num-samples 2 --batch-size 1 --grid-size 512 --subsample 1 --double --extra-vars --time 50 --time-warmup 30 --num-steps 200 --dt 5e-4 --replicable-init --seed 42
46+
>>> python data_gen_fno.py --num-samples 2 --batch-size 1 --grid-size 512 --subsample 1 --double --extra-vars --time 50 --time-warmup 30 --num-steps 200 --dt 5e-4 --scale 0.1 --replicable-init --seed 42
47+
48+
- Testing if the code works
49+
>>> python data_gen/data_gen_fno.py --num-samples 4 --batch-size 2 --grid-size 128 --subsample 1 --double --extra-vars --time 2 --time-warmup 1 --num-steps 10 --dt 1e-3 --scale 0.1 --replicable-init --seed 42
3550
3651
"""
3752

@@ -43,9 +58,6 @@ def main(args):
4358
log_filename = os.path.join(logpath, f"{current_time}_{log_name}.log")
4459
logger = get_logger(log_filename)
4560

46-
cuda = not args.no_cuda and torch.cuda.is_available()
47-
device = torch.device("cuda" if cuda else "cpu")
48-
logger.info(f"Using device: {device}")
4961
logger.info(f"Using the following arguments: ")
5062
all_args = {k: v for k, v in vars(args).items() if not callable(v)}
5163
logger.info(" | ".join(f"{k}={v}" for k, v in all_args.items()))
@@ -60,40 +72,37 @@ def main(args):
6072
raise ValueError(
6173
f"Grid size {n} is larger than the maximum allowed {n_grid_max}"
6274
)
75+
scale = args.scale
6376
visc = args.visc if args.Re is None else 1/args.Re # 1e-3
6477
T = args.time # 50
6578
T_warmup = args.time_warmup # 30
6679
T_new = T - T_warmup
67-
delta_t = args.dt # 1e-4
80+
record_steps = args.num_steps
81+
dt = args.dt # 1e-4
82+
logger.info(f"Using dt = {dt}")
83+
84+
warmup_steps = int(T_warmup / dt)
85+
total_steps = int(T_new / dt)
86+
record_every_iters = int(total_steps / record_steps)
6887

6988
alpha = args.alpha # 2.5
7089
tau = args.tau # 7
71-
f = args.forcing # FNO's default sin+cos
90+
peak_wavenumber = args.peak_wavenumber
91+
7292
dtype = torch.float64 if args.double else torch.float32
7393
normalize = args.normalize
7494
filename = args.filename
7595
force_rerun = args.force_rerun
7696
replicate_init = args.replicable_init
7797
dealias = not args.no_dealias
7898
pbar = not args.no_tqdm
79-
torch.set_default_dtype(torch.float64)
80-
logger.info(f"Using device: {device} | save dtype: {dtype} | computge dtype: {torch.get_default_dtype()}")
8199

82100
# Number of solutions to generate
83101
total_samples = args.num_samples # 8
84102

85-
# Number of snapshots from solution
86-
record_steps = args.num_steps
87-
88103
# Batch size
89104
batch_size = args.batch_size # 8
90105

91-
solver_kws = dict(visc=visc,
92-
delta_t=delta_t,
93-
diam=diam,
94-
dealias=dealias,
95-
dtype=torch.float64)
96-
97106
extra = "_extra" if args.extra_vars else ""
98107
dtype_str = "_fp64" if args.double else ""
99108
if filename is None:
@@ -127,9 +136,27 @@ def main(args):
127136
else:
128137
logger.info(f"Generating data and saving in {filename}")
129138

139+
cuda = not args.no_cuda and torch.cuda.is_available()
140+
no_tqdm = args.no_tqdm
141+
device = torch.device("cuda:0" if cuda else "cpu")
142+
143+
torch.set_default_dtype(torch.float64)
144+
logger.info(f"Using device: {device} | save dtype: {dtype} | computge dtype: {torch.get_default_dtype()}")
130145
# Set up 2d GRF with covariance parameters
131146
# Parameters of covariance C = tau^0.5*(2*alpha-2)*(-Laplacian + tau^2 I)^(-alpha)
132147
# Note that we need alpha > d/2 (here d= 2)
148+
149+
grid = Grid(shape=(n, n), domain=((0, diam), (0, diam)), device=device)
150+
151+
forcing_fn = SinCosForcing(
152+
grid=grid,
153+
scale=scale,
154+
diam=diam,
155+
k=peak_wavenumber,
156+
vorticity=True,
157+
)
158+
# Forcing function: 0.1*(sin(2pi(x+y)) + cos(2pi(x+y)))
159+
133160
grf = GRF2d(
134161
n=n,
135162
alpha=alpha,
@@ -139,14 +166,14 @@ def main(args):
139166
dtype=torch.float64,
140167
)
141168

142-
# Forcing function: 0.1*(sin(2pi(x+y)) + cos(2pi(x+y)))
143-
grid = torch.linspace(0, 1, n + 1, device=device)
144-
grid = grid[0:-1]
145-
146-
X, Y = torch.meshgrid(grid, grid, indexing="ij")
147-
# FNO's original implementation
148-
# fh = 0.1 * (torch.sin(2 * math.pi * (X + Y)) + torch.cos(2 * math.pi * (X + Y)))
149-
fh = f(X, Y)
169+
ns2d = NavierStokes2DSpectral(
170+
viscosity=visc,
171+
grid=grid,
172+
smooth=True,
173+
forcing_fn=forcing_fn,
174+
solver=IMEXStepper,
175+
order=2,
176+
).to(device)
150177

151178
if os.path.exists(data_filepath) and not force_rerun:
152179
logger.info(f"Data already exists at {data_filepath}")
@@ -165,46 +192,48 @@ def main(args):
165192
# Sample random fields
166193
seeds = [args.seed + idx + k for k in range(batch_size)]
167194
n0 = n_grid_max if replicate_init else n
168-
w0 = [grf.sample(1, n0, random_state=s) for _, s in zip(range(batch_size), seeds)]
169-
w0 = torch.stack(w0)
195+
vort_init = [grf.sample(1, n0, random_state=s) for _, s in zip(range(batch_size), seeds)]
196+
vort_init = torch.stack(vort_init)
170197
if n != n0:
171-
w0 = F.interpolate(w0, size=(n, n), mode="nearest")
172-
w0 = w0.squeeze(1)
173-
174-
logger.info(f"initial condition {w0.shape}")
198+
vort_init = F.interpolate(vort_init, size=(n, n), mode="nearest")
199+
vort_init = vort_init.squeeze(1)
200+
vort_hat = fft.rfft2(vort_init).to(device)
201+
202+
logger.info(f"initial condition {vort_init.shape}")
175203

176204
if T_warmup > 0:
177-
logger.info(f"warm up till {T_warmup}")
178-
tmp = get_trajectory_imex_crank_nicolson(
179-
w0,
180-
fh,
181-
T=T_warmup,
182-
record_steps=record_steps,
183-
subsample=1,
184-
pbar=pbar,
185-
**solver_kws,
186-
)
187-
w0 = tmp["vorticity"][:, -1].to(device)
188-
del tmp
189-
logger.info(f"warmup initial condition {w0.shape}")
205+
with tqdm(total=warmup_steps, disable=no_tqdm) as pbar:
206+
for j in range(warmup_steps):
207+
vort_hat, _ = ns2d.step(vort_hat, dt)
208+
if j % 100 == 0:
209+
vort_norm = torch.linalg.norm(fft.irfft2(vort_hat)).item() / n
210+
desc = (
211+
datetime.now().strftime("%d-%b-%Y %H:%M:%S")
212+
+ f" - Warmup | vort_hat ell2 norm {vort_norm:.4e}"
213+
)
214+
pbar.set_description(desc)
215+
pbar.update(100)
190216

191217
logger.info(f"generate data from {T_warmup} to {T}")
192-
result = get_trajectory_imex_crank_nicolson(
193-
w0,
194-
fh,
195-
T=T_new,
196-
record_steps=record_steps,
197-
subsample=subsample,
198-
pbar=pbar,
199-
**solver_kws,
218+
result = get_trajectory_imex(
219+
ns2d,
220+
vort_hat,
221+
dt,
222+
num_steps=total_steps,
223+
record_every_steps=record_every_iters,
224+
pbar=not no_tqdm,
200225
)
201226

202227
for field, value in result.items():
203-
if subsample > 1 and value.ndim == 4:
228+
value = fft.irfft2(value).real.cpu().to(dtype)
229+
logger.info(
230+
f"variable: {field} | shape: {value.shape} | dtype: {value.dtype}"
231+
)
232+
if subsample > 1:
233+
assert value.ndim == 4, f"Subsampling only works for (b, c, h, w) tensors, current shape: {value.shape}"
204234
value = F.interpolate(value, size=(ns, ns), mode="bilinear")
205-
result[field] = value.cpu().to(dtype)
235+
result[field] = value
206236
logger.info(f"{field:<15} | {value.shape} | {value.dtype}")
207-
208237

209238
if not extra:
210239
for key in ["vort_t", "stream", "residual"]:
@@ -233,5 +262,5 @@ def main(args):
233262

234263

235264
if __name__ == "__main__":
236-
args = get_args("Generate the original FNO data for NSE in 2D")
265+
args = get_args_ns2d("Generate the original FNO data for NSE in 2D")
237266
main(args)

0 commit comments

Comments
 (0)