Skip to content

Commit e203336

Browse files
committed
fixed mcwilliams data gen fp64 bug
1 parent 2cc47e5 commit e203336

File tree

5 files changed

+120
-101
lines changed

5 files changed

+120
-101
lines changed

README.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ This repository featuers two parts:
44
- The first part is a native PyTorch port of [Google's Computational Fluid Dynamics package in Jax](https://github.com/google/jax-cfd). The main changes are documented in the `README.md` under the [`torch_cfd` directory](torch_cfd). Most significant changes in all routines include:
55
- Routines that rely on the functional programming of Jax have been rewritten to be a more debugger-friendly PyTorch tensor-in-tensor-out style.
66
- Functions and operators are in general implemented as `nn.Module` like a factory template.
7-
- Extra fields computation and tracking are made easier, such as time derivatives and PDE residual $R(\boldsymbol{v}):=\boldsymbol{f}-\partial_t \boldsymbol{v}-(\boldsymbol{v}\cdot\nabla)\boldsymbol{v} + \nu \Delta \boldsymbol{v}$.
7+
- Jax-cfd's `funcutils.trajectory` function supports to track only one field variable (vorticity or velocity), Extra fields computation and tracking are made easier, such as time derivatives and PDE residual $R(\boldsymbol{v}):=\boldsymbol{f}-\partial_t \boldsymbol{v}-(\boldsymbol{v}\cdot\nabla)\boldsymbol{v} + \nu \Delta \boldsymbol{v}$.
8+
- All ops takes batch dimension of tensors into consideration, not a single trajectory.
89
- Neural Operator-Assisted Navier-Stokes Equations solver.
910
- The **Spatiotempoeral Fourier Neural Operator** (SFNO) that is a spacetime tensor-to-tensor learner (or trajectory-to-trajectory), inspiration drawn from the [3D FNO in Nvidia's Neural Operator repo](https://github.com/neuraloperator/neuraloperator).
1011
- Data generation for the meta-example of the isotropic turbulence with energy spectra matching the inverse cascade of Kolmogorov flow in a periodic box. Ref: McWilliams, J. C. (1984). The emergence of isolated coherent vortices in turbulent flow. *Journal of Fluid Mechanics*, 146, 21-43.
@@ -16,7 +17,7 @@ To install `torch-cfd`'s current release, simply do:
1617
```bash
1718
pip install torch-cfd
1819
```
19-
If one wants to play with the neural operator part, it is recommended cloning this repo and playing it local by creating a venv using `requirements.txt`. Note: it is recommended using PyTorch version >=2.0.0 for the broadcasting semantics.
20+
If one wants to play with the neural operator part, it is recommended to clone this repo and play it locally by creating a venv using `requirements.txt`. Note: using PyTorch version >=2.0.0 for the broadcasting semantics.
2021

2122
## Data
2223
The data are available at https://huggingface.co/datasets/scaomath/navier-stokes-dataset
@@ -52,6 +53,6 @@ PR welcome. Currently, the port of `torch-cfd` currently includes:
5253

5354
## Acknowledgments
5455
The research of Brarda and Xi is supported by the National Science Foundation award DMS-2208412.
55-
The work of Li was performed under the auspices of
56+
The work of Li was performed under the auspices of
5657
the U.S. Department of Energy by Lawrence Livermore National Laboratory under Contract DEAC52-07NA27344 and was supported by the LLNL-LDRD program under Project No. 24ERD033. Cao is greatful for the support from [Long Chen (UC Irvine)](https://github.com/lyc102/ifem) and
5758
[Ludmil Zikatanov (Penn State)](https://github.com/HAZmathTeam/hazmath) over the years, and their efforts in open-sourcing scientific computing codes. Cao also appreciates the support from the National Science Foundation DMS-2309778, and the free A6000 credits at the SSE ML cluster from the University of Missouri.

fno/data/data_gen_McWilliams2d.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,6 @@
1919
from torch_cfd.initial_conditions import *
2020
from torch_cfd.finite_differences import *
2121
from torch_cfd.forcings import *
22-
23-
from collections import defaultdict
24-
25-
import matplotlib.pyplot as plt
26-
import seaborn as sns
27-
import xarray
2822
from tqdm import tqdm
2923
from .data_gen import *
3024

@@ -79,10 +73,10 @@ def main(args):
7973
filename = args.filename
8074
if filename is None:
8175
filename = (
82-
f"McWilliams2d{dtype_str}"
83-
+ f"_N{total_samples}_n{ns}"
84-
+ f"_v{viscosity:.0e}_T{T}.pt"
85-
).replace("e-0", "e-")
76+
f"McWilliams2d{dtype_str}_{ns}x{ns}"
77+
+ f"_N{total_samples}_v{viscosity:.0e}"
78+
+ f"_T{num_snapshots}.pt".replace("e-0", "e-")
79+
)
8680
args.filename = filename
8781
data_filepath = os.path.join(DATA_PATH, filename)
8882
if os.path.exists(data_filepath) and not force_rerun:
@@ -112,10 +106,10 @@ def main(args):
112106
).to(device)
113107

114108
for i, idx in enumerate(range(0, total_samples, batch_size)):
109+
logger.info(f"Generate trajectory for {i+1}-th batch of {total_samples}")
115110
logger.info(
116-
f"Generate trajectory for {i+1}-th batch of {total_samples}"
111+
f"random state: {random_state + idx} to {random_state + idx + batch_size-1}"
117112
)
118-
logger.info(f"random state: {random_state + idx} to {random_state + idx + batch_size-1}")
119113

120114
vort_init = torch.stack(
121115
[
@@ -143,7 +137,10 @@ def main(args):
143137
)
144138

145139
for field, value in result.items():
146-
value = fft.irfft2(value).real.cpu().to(torch.float32)
140+
value = fft.irfft2(value).real.cpu().to(dtype)
141+
logger.info(
142+
f"variable: {field} | shape: {value.shape} | dtype: {value.dtype}"
143+
)
147144
if subsample > 1:
148145
result[field] = F.interpolate(value, size=(ns, ns), mode="bilinear")
149146
else:
@@ -158,7 +155,9 @@ def main(args):
158155

159156
pickle_to_pt(data_filepath)
160157

161-
verify_trajectories(data_filepath, dt=record_every_iters * dt, T_warmup=T_warmup)
158+
verify_trajectories(
159+
data_filepath, dt=record_every_iters * dt, T_warmup=T_warmup, n_samples=1
160+
)
162161

163162

164163
if __name__ == "__main__":

fno/sfno.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def forward(self, x, out_steps=None):
207207
nx,
208208
ny,
209209
nt // 2 + 1,
210-
dtype=torch.cfloat,
210+
dtype=x_ft.dtype,
211211
device=x.device,
212212
)
213213
slice_x = [slice(0, self.modes_x), slice(-self.modes_x, None)]
@@ -319,6 +319,13 @@ def hook(model, input, output):
319319
for k, block in enumerate(blocks):
320320
block.register_forward_hook(_get_latent_tensors(f"{layer_name}_{k}"))
321321

322+
def double(self):
323+
for param in self.parameters():
324+
if param.dtype == torch.float32:
325+
param.data = param.data.to(torch.float64)
326+
elif param.dtype == torch.complex64:
327+
param.data = param.data.to(torch.complex128)
328+
322329
def forward(self, x, out_steps=None):
323330
"""
324331
if out_steps is None, it will try to use self.out_steps
@@ -327,13 +334,16 @@ def forward(self, x, out_steps=None):
327334
if out_steps is None:
328335
out_steps = self.out_steps if self.out_steps is not None else x.size(-1)
329336
x_res = x # save skip connection
330-
x = self.p(x.unsqueeze(1)) # [b, 1, n, n, T] -> [b, H, n, n, T]
337+
x = self.p(x.unsqueeze(1) ) # [b, 1, n, n, T] -> [b, H, n, n, T]
338+
339+
if self.debug:
340+
print(f"in proj: {x.size()}")
331341

332342
x = F.pad(
333343
x,
334344
[0, 0, self.padding, self.padding, self.padding, self.padding],
335345
mode="circular",
336-
)
346+
) # pad the domain if input is non-periodic
337347

338348
for conv, mlp, w, nonlinear in zip(
339349
self.spectral_conv, self.mlp, self.w, self.activations

fno/visualizations.py

Lines changed: 82 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,17 @@
1+
import matplotlib.pyplot as plt
2+
import numpy as np
13
import plotly.express as px
24
import plotly.figure_factory as ff
35
import plotly.graph_objects as go
4-
5-
import xarray
6-
import numpy as np
6+
import seaborn as sns
77
import torch
88
import torch.fft as fft
9-
import matplotlib.pyplot as plt
10-
import seaborn as sns
9+
10+
import xarray
1111
from mpl_toolkits.axes_grid1 import make_axes_locatable
1212

13-
def plot_contour(z,
14-
func=plt.imshow,
15-
**kwargs):
13+
14+
def plot_contour(z, func=plt.imshow, **kwargs):
1615
if isinstance(z, torch.Tensor):
1716
z = z.cpu().numpy()
1817
_, ax = plt.subplots(figsize=(3, 3))
@@ -23,9 +22,10 @@ def plot_contour(z,
2322
cax = divider.append_axes("right", size="7%", pad=0.1)
2423
cbar = plt.colorbar(f, ax=ax, cax=cax)
2524
cbar.ax.tick_params(labelsize=10)
26-
cbar.ax.locator_params(nbins=9)
25+
cbar.ax.locator_params(nbins=9)
2726
cbar.update_ticks()
2827

28+
2929
def plot_contour_plotly(
3030
z,
3131
colorscale="RdYlBu",
@@ -83,7 +83,7 @@ def plot_contour_plotly(
8383
contour_kwargs["colorbar"] = dict(
8484
thickness=0.15 * layout_kwargs["height"],
8585
tickwidth=0.3,
86-
exponentformat = 'e',
86+
exponentformat="e",
8787
)
8888
layout_kwargs["width"] = 1.32 * layout_kwargs["height"]
8989
else:
@@ -111,122 +111,129 @@ def plot_contour_plotly(
111111
return fig
112112

113113

114-
115114
def get_enstrophy_spectrum(vorticity, h):
116115
if isinstance(vorticity, np.ndarray):
117116
vorticity = torch.from_numpy(vorticity)
118117
n = vorticity.shape[0]
119118
kx = fft.fftfreq(n, d=h)
120119
ky = fft.fftfreq(n, d=h)
121120
kx, ky = torch.meshgrid([kx, ky], indexing="ij")
122-
kmax = n//2
121+
kmax = n // 2
123122
kx = kx[..., : kmax + 1]
124123
ky = ky[..., : kmax + 1]
125-
k2 = (4*torch.pi**2)*(kx**2 + ky**2)
124+
k2 = (4 * torch.pi**2) * (kx**2 + ky**2)
126125
k2[0, 0] = 1.0
127126

128127
wh = fft.rfft2(vorticity)
129128

130-
tke = (0.5*wh*wh.conj()).real
129+
tke = (0.5 * wh * wh.conj()).real
131130
kmod = torch.sqrt(k2)
132-
k = torch.arange(1, kmax, dtype=torch.float64) # Nyquist limit for this grid
131+
k = torch.arange(1, kmax, dtype=torch.float64) # Nyquist limit for this grid
133132
Ens = torch.zeros_like(k)
134-
dk = (torch.max(k)-torch.min(k))/(2*n)
133+
dk = (torch.max(k) - torch.min(k)) / (2 * n)
135134
for i in range(len(k)):
136-
Ens[i] += (tke[(kmod<k[i]+dk) & (kmod>=k[i]-dk)]).sum()
135+
Ens[i] += (tke[(kmod < k[i] + dk) & (kmod >= k[i] - dk)]).sum()
137136

138-
Ens = Ens/Ens.sum()
137+
Ens = Ens / Ens.sum()
139138
return Ens
140139

141140

142-
def plot_enstrophy_spectrum(fields:list,
143-
h=None,
144-
slope=5,
145-
factor=None,
146-
cutoff=1e-15,
147-
plot_cutoff_factor=1/8,
148-
labels=None,
149-
title=None,
150-
legend_loc="upper right",
151-
fontsize=15,
152-
subplot_kw={"figsize": (5, 5), "dpi": 100, "facecolor": "w"},
153-
**kwargs):
141+
def plot_enstrophy_spectrum(
142+
fields: list,
143+
h=None,
144+
slope=5,
145+
factor=None,
146+
cutoff=1e-15,
147+
plot_cutoff_factor=1 / 8,
148+
labels=None,
149+
title=None,
150+
legend_loc="upper right",
151+
fontsize=15,
152+
subplot_kw={"figsize": (5, 5), "dpi": 100, "facecolor": "w"},
153+
**kwargs,
154+
):
154155
for k, field in enumerate(fields):
155156
if isinstance(field, np.ndarray):
156157
fields[k] = torch.from_numpy(field)
157158
if labels is None:
158159
labels = [f"Field {i}" for i in range(len(fields))]
159160
n = fields[0].shape[0]
160-
if h is None: h = 1 / n
161-
kmax = n//2
162-
k = torch.arange(1, kmax, dtype=torch.float64) # Nyquist limit for this grid
161+
if h is None:
162+
h = 1 / n
163+
kmax = n // 2
164+
k = torch.arange(1, kmax, dtype=torch.float64) # Nyquist limit for this grid
163165
Es = [get_enstrophy_spectrum(field, h) for field in fields]
164166
if factor is None:
165-
factor = Es[-1].quantile(0.8)/(k[-1] ** (-slope))
167+
factor = Es[-1].quantile(0.8) / (k[-1] ** (-slope))
166168
# print(factor)
167-
169+
168170
fig, ax = plt.subplots(**subplot_kw)
169-
plot_cutoff = int(n*plot_cutoff_factor)
171+
plot_cutoff = int(n * plot_cutoff_factor)
170172
for i, E in enumerate(Es):
171173
if cutoff is not None:
172174
E[E < cutoff] = np.nan
173175
E[-plot_cutoff:] = np.nan
174176
plt.loglog(k, E, label=f"{labels[i]}")
175177

176-
plt.loglog(k[:-plot_cutoff], (factor*k ** (-slope))[:-plot_cutoff], "b--",
177-
label=f"$O(k^{{{-slope:.3g}}})$",)
178+
plt.loglog(
179+
k[:-plot_cutoff],
180+
(factor * k ** (-slope))[:-plot_cutoff],
181+
"b--",
182+
label=f"$O(k^{{{-slope:.3g}}})$",
183+
)
178184
plt.grid(True, which="both", ls="--", linewidth=0.4)
179-
plt.autoscale(enable=True, axis='x', tight=True)
185+
plt.autoscale(enable=True, axis="x", tight=True)
180186
plt.legend(fontsize=fontsize, loc=legend_loc)
181187
plt.title(title, fontsize=fontsize)
182188
plt.xlabel("Wavenumber", fontsize=fontsize)
183189
ax.xaxis.set_tick_params(labelsize=fontsize)
184190
ax.yaxis.set_tick_params(labelsize=fontsize)
185191

186192

187-
def plot_contour_trajectory(field,
188-
num_snapshots=5,
189-
contourf=False,
190-
T_start=4.5,
191-
dt = 1e-1,
192-
**kwargs):
193+
def plot_contour_trajectory(
194+
field,
195+
num_snapshots=5,
196+
contourf=False,
197+
T_start=4.5,
198+
dt=1e-1,
199+
cb_kws=dict(orientation="vertical", pad=0.01, aspect=10),
200+
subplot_kws=dict(
201+
xticks=[],
202+
yticks=[],
203+
ylabel="",
204+
xlabel="",
205+
),
206+
plot_kws=dict(
207+
col_wrap=5,
208+
cmap=sns.cm.icefire,
209+
robust=True,
210+
add_colorbar=True,
211+
xticks=None,
212+
yticks=None,
213+
size=3,
214+
aspect=1,
215+
),
216+
**kwargs,
217+
):
193218
"""
194219
plot trajectory using xarray's imshow or contourf wrapper
195220
"""
196221
field = field.detach().cpu().numpy()
197222
*size, T = field.shape
198-
grid = np.linspace(0, 1, size[0]+1)[:-1]
223+
grid = np.linspace(0, 1, size[0] + 1)[:-1]
199224
time = np.arange(T) * dt + T_start
200225
coords = {
201-
'x': grid,
202-
'y': grid,
203-
't': time,
226+
"x": grid,
227+
"y": grid,
228+
"t": time,
204229
}
205-
ds = xarray.DataArray(
206-
field,
207-
dims=["x", "y", "t"],
208-
coords=coords
209-
)
210-
t_steps = T//num_snapshots
230+
ds = xarray.DataArray(field, dims=["x", "y", "t"], coords=coords)
231+
t_steps = T // num_snapshots
211232
ds = ds.thin({"t": t_steps})
212233
plot_func = ds.plot.contourf if contourf else ds.plot.imshow
213-
plot_func(col='t',
214-
col_wrap=5,
215-
cmap=sns.cm.icefire,
216-
robust=True,
217-
add_colorbar=True,
218-
xticks=None,
219-
yticks=None,
220-
size=3,
221-
aspect=1,
222-
interpolation="hermite",
223-
subplot_kws=dict(xticks=[],
224-
yticks=[],
225-
ylabel= "",
226-
xlabel= "",
227-
),
228-
cbar_kwargs=dict(orientation="vertical",
229-
pad=0.01,
230-
aspect= 10
231-
),
232-
**kwargs)
234+
plot_func(
235+
col="t",
236+
subplot_kws=subplot_kws,
237+
cbar_kwargs=cb_kws,
238+
**plot_kws,
239+
)

0 commit comments

Comments
 (0)