Skip to content

Commit 579a764

Browse files
committed
fixed a mis-aligned args in SFNO
1 parent c8d7d10 commit 579a764

File tree

8 files changed

+53940
-33
lines changed

8 files changed

+53940
-33
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# Neural Operator-Assisted Computational Fluid Dynamics in PyTorch
22

3+
![A decaying turbulence (McWilliams 1984)](examples/McWilliams2d.svg)
4+
35
## Summary
46

57
This repository contains mainly two parts:

examples/Kolmogrov2d_rk4_cn_forced_turbulence.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"cells": [
33
{
44
"cell_type": "code",
5-
"execution_count": 1,
5+
"execution_count": null,
66
"metadata": {},
77
"outputs": [],
88
"source": [
@@ -14,7 +14,7 @@
1414
"from torch_cfd.initial_conditions import *\n",
1515
"from torch_cfd.finite_differences import *\n",
1616
"from torch_cfd.forcings import *\n",
17-
"from sfno.data_gen import get_trajectory_rk4\n",
17+
"from fno.data_gen import get_trajectory_rk4\n",
1818
"\n",
1919
"import xarray\n",
2020
"import seaborn as sns\n",

examples/McWilliams2d.svg

Lines changed: 53888 additions & 0 deletions
Loading

examples/ex2_SFNO_finetune_McWilliams2d.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@
128128
],
129129
"source": [
130130
"torch.cuda.empty_cache()\n",
131-
"model = SFNO(modes, modes, modes_t, width, beta).to(device)\n",
131+
"model = SFNO(modes, modes, modes_t, width, beta=beta).to(device)\n",
132132
"print(get_num_params(model))\n",
133133
"\n",
134134
"model.load_state_dict(torch.load(path_model))\n",
@@ -211,7 +211,7 @@
211211
"outputs": [],
212212
"source": [
213213
"torch.cuda.empty_cache()\n",
214-
"model = SFNO(modes, modes, modes_t, width, beta).to(device)\n",
214+
"model = SFNO(modes, modes, modes_t, width, beta=beta).to(device)\n",
215215
"model.load_state_dict(torch.load(path_model))\n",
216216
"model.to(dtype)\n",
217217
"f = torch.zeros((n, n))[None, ...].to(device)"

examples/ex2_SFNO_finetune_fnodata.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@
109109
],
110110
"source": [
111111
"torch.cuda.empty_cache()\n",
112-
"model = SFNO(modes, modes, modes_t, width, beta, output_steps=T_out).to(device)\n",
112+
"model = SFNO(modes, modes, modes_t, width, beta=beta, output_steps=T_out).to(device)\n",
113113
"print(get_num_params(model))\n",
114114
"\n",
115115
"model.load_state_dict(torch.load(path_model))\n",

examples/ex2_SFNO_train.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@
9595
},
9696
{
9797
"cell_type": "code",
98-
"execution_count": 12,
98+
"execution_count": null,
9999
"metadata": {},
100100
"outputs": [
101101
{
@@ -108,7 +108,7 @@
108108
],
109109
"source": [
110110
"torch.cuda.empty_cache()\n",
111-
"model = SFNO(modes, modes, modes_t, width, beta).to(device)\n",
111+
"model = SFNO(modes, modes, modes_t, width, beta=beta).to(device)\n",
112112
"\n",
113113
"print(get_num_params(model))\n",
114114
"\n",

examples/ex2_SFNO_train_fnodata.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@
110110
],
111111
"source": [
112112
"torch.cuda.empty_cache()\n",
113-
"model = SFNO(modes, modes, modes_t, width, beta, \n",
113+
"model = SFNO(modes, modes, modes_t, width, beta=beta, \n",
114114
" output_steps=T_out).to(device)\n",
115115
"\n",
116116
"print(get_num_params(model))\n",

fno/train.py

Lines changed: 42 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,13 @@
2020
from utils import *
2121
from pipeline import *
2222
from data_gen import *
23+
import matplotlib.pyplot as plt
2324
from datasets import BochnerDataset
2425
from losses import SobolevLoss
25-
import matplotlib.pyplot as plt
26-
from fno.sfno import SFNO
2726
from torch.utils.data import DataLoader
2827

28+
from fno.sfno import SFNO
29+
2930
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
3031

3132

@@ -44,17 +45,16 @@
4445

4546

4647
def main(args):
47-
48+
4849
current_time = datetime.now().strftime("%d_%b_%Y_%Hh%Mm")
4950
log_name = "".join(os.path.basename(__file__).split(".")[:-1])
5051

5152
log_filename = os.path.join(LOG_PATH, f"{current_time}_{log_name}.log")
5253
logger = get_logger(log_filename)
5354
logger.info(f"Saving log at {log_filename}")
5455

55-
5656
all_args = {k: v for k, v in vars(args).items() if not callable(v)}
57-
logger.info("Arguments: "+" | ".join(f"{k}={v}" for k, v in all_args.items()))
57+
logger.info("Arguments: " + " | ".join(f"{k}={v}" for k, v in all_args.items()))
5858

5959
example = args.example
6060
Ntrain = args.num_samples
@@ -122,14 +122,20 @@ def main(args):
122122
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
123123

124124
torch.cuda.empty_cache()
125-
model = SFNO(modes, modes, modes_t, width, beta,
126-
num_spectral_layers=num_layers,
127-
output_steps=out_steps,
128-
spatial_padding=spatial_padding,
129-
activation=activation,
130-
pe_trainable=pe_trainable,
131-
spatial_random_feats=spatial_random_feats,
132-
lift_activation=lift_activation)
125+
model = SFNO(
126+
modes,
127+
modes,
128+
modes_t,
129+
width,
130+
beta=beta,
131+
num_spectral_layers=num_layers,
132+
output_steps=out_steps,
133+
spatial_padding=spatial_padding,
134+
activation=activation,
135+
pe_trainable=pe_trainable,
136+
spatial_random_feats=spatial_random_feats,
137+
lift_activation=lift_activation,
138+
)
133139
logger.info(f"Number of parameters: {get_num_params(model)}")
134140
model.to(device)
135141

@@ -155,7 +161,9 @@ def main(args):
155161
with tqdm(train_loader) as pbar:
156162
t_ep = datetime.now().strftime("%d-%b-%Y %H:%M:%S")
157163
tr_loss_str = f"current train rel L2: 0.0"
158-
pbar.set_description(f"{t_ep} - Epoch [{ep+1:3d}/{epochs}] {tr_loss_str:>35}")
164+
pbar.set_description(
165+
f"{t_ep} - Epoch [{ep+1:3d}/{epochs}] {tr_loss_str:>35}"
166+
)
159167
for i, data in enumerate(train_loader):
160168
l2 = train_batch_ns(
161169
model,
@@ -173,7 +181,9 @@ def main(args):
173181

174182
if i % 4 == 0:
175183
tr_loss_str = f"current train rel L2: {l2.item():.4e}"
176-
pbar.set_description(f"{t_ep} - Epoch [{ep+1:3d}/{epochs}] {tr_loss_str:>35}")
184+
pbar.set_description(
185+
f"{t_ep} - Epoch [{ep+1:3d}/{epochs}] {tr_loss_str:>35}"
186+
)
177187
pbar.update(4)
178188
val_l2_min = 1e4
179189
val_l2 = eval_epoch_ns(
@@ -214,13 +224,19 @@ def main(args):
214224
)
215225
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
216226
torch.cuda.empty_cache()
217-
model = SFNO(modes, modes, modes_t, width, beta,
218-
num_spectral_layers=num_layers,
219-
spatial_padding=spatial_padding,
220-
activation=activation,
221-
pe_trainable=pe_trainable,
222-
spatial_random_feats=spatial_random_feats,
223-
lift_activation=lift_activation).to(device)
227+
model = SFNO(
228+
modes,
229+
modes,
230+
modes_t,
231+
width,
232+
beta=beta,
233+
num_spectral_layers=num_layers,
234+
spatial_padding=spatial_padding,
235+
activation=activation,
236+
pe_trainable=pe_trainable,
237+
spatial_random_feats=spatial_random_feats,
238+
lift_activation=lift_activation,
239+
).to(device)
224240
model.load_state_dict(torch.load(path_model))
225241
logger.info(f"Loaded model from {path_model}")
226242
eval_metric = SobolevLoss(n_grid=n_test, norm_order=norm_order, relative=True)
@@ -238,20 +254,21 @@ def main(args):
238254
if args.demo_plots > 0:
239255
try:
240256
from visualizations import plot_contour_trajectory
257+
241258
idx = np.random.randint(0, args.num_test_samples)
242259
im1 = plot_contour_trajectory(
243260
preds[idx],
244261
num_snapshots=args.demo_plots,
245262
T_start=args.time_warmup,
246263
dt=args.dt,
247-
title="SFNO predictions"
264+
title="SFNO predictions",
248265
)
249266
im2 = plot_contour_trajectory(
250267
gt_solns[idx],
251268
num_snapshots=args.demo_plots,
252269
T_start=args.time_warmup,
253270
dt=args.dt,
254-
title="Ground truth generated by IMEX"
271+
title="Ground truth generated by IMEX",
255272
)
256273
plt.show()
257274
except Exception as e:
@@ -294,4 +311,4 @@ def main(args):
294311
parser.add_argument("--demo-plots", type=int, default=0)
295312

296313
args = parser.parse_args()
297-
main(args)
314+
main(args)

0 commit comments

Comments
 (0)