Skip to content

Commit bcc68c2

Browse files
committed
reformat
1 parent ae1635d commit bcc68c2

File tree

3 files changed

+81
-63
lines changed

3 files changed

+81
-63
lines changed

fno/data_gen/data_gen_fno.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from torch_cfd.forcings import *
2525
from fno.pipeline import DATA_PATH, LOG_PATH
2626

27+
2728
def main(args):
2829
"""
2930
Generate the original FNO data
@@ -73,7 +74,7 @@ def main(args):
7374
f"Grid size {n} is larger than the maximum allowed {n_grid_max}"
7475
)
7576
scale = args.scale
76-
visc = args.visc if args.Re is None else 1/args.Re # 1e-3
77+
visc = args.visc if args.Re is None else 1 / args.Re # 1e-3
7778
T = args.time # 50
7879
T_warmup = args.time_warmup # 30
7980
T_new = T - T_warmup
@@ -88,7 +89,7 @@ def main(args):
8889
alpha = args.alpha # 2.5
8990
tau = args.tau # 7
9091
peak_wavenumber = args.peak_wavenumber
91-
92+
9293
dtype = torch.float64 if args.double else torch.float32
9394
normalize = args.normalize
9495
filename = args.filename
@@ -107,11 +108,11 @@ def main(args):
107108
dtype_str = "_fp64" if args.double else ""
108109
if filename is None:
109110
filename = (
110-
f"fnodata{extra}{dtype_str}_{ns}x{ns}_N{total_samples}"
111+
f"fnodata{extra}{dtype_str}_{ns}x{ns}_N{total_samples}"
111112
+ f"_v{visc:.0e}_T{int(T)}_steps{record_steps}_alpha{alpha:.1f}_tau{tau:.0f}.pt"
112113
).replace("e-0", "e-")
113114
args.filename = filename
114-
115+
115116
filepath = args.filepath if args.filepath is not None else DATA_PATH
116117
for p in [filepath]:
117118
if not os.path.exists(p):
@@ -123,7 +124,7 @@ def main(args):
123124
if data_exist and not force_rerun:
124125
logger.info(f"File {filename} exists with current data as follows:")
125126
data = torch.load(data_filepath)
126-
127+
127128
for key, v in data.items():
128129
if isinstance(v, torch.Tensor):
129130
logger.info(f"{key:<12} | {v.shape} | {v.dtype}")
@@ -141,13 +142,15 @@ def main(args):
141142
device = torch.device("cuda:0" if cuda else "cpu")
142143

143144
torch.set_default_dtype(torch.float64)
144-
logger.info(f"Using device: {device} | save dtype: {dtype} | computge dtype: {torch.get_default_dtype()}")
145+
logger.info(
146+
f"Using device: {device} | save dtype: {dtype} | computge dtype: {torch.get_default_dtype()}"
147+
)
145148
# Set up 2d GRF with covariance parameters
146149
# Parameters of covariance C = tau^0.5*(2*alpha-2)*(-Laplacian + tau^2 I)^(-alpha)
147150
# Note that we need alpha > d/2 (here d= 2)
148151

149152
grid = Grid(shape=(n, n), domain=((0, diam), (0, diam)), device=device)
150-
153+
151154
forcing_fn = SinCosForcing(
152155
grid=grid,
153156
scale=scale,
@@ -156,7 +159,7 @@ def main(args):
156159
vorticity=True,
157160
)
158161
# Forcing function: 0.1*(sin(2pi(x+y)) + cos(2pi(x+y)))
159-
162+
160163
grf = GRF2d(
161164
n=n,
162165
alpha=alpha,
@@ -187,12 +190,16 @@ def main(args):
187190
num_batches = total_samples // batch_size
188191
for i, idx in enumerate(range(0, total_samples, batch_size)):
189192
logger.info(f"Generate trajectory for batch [{i+1}/{num_batches}]")
190-
logger.info(f"random states: {args.seed + idx} to {args.seed + idx + batch_size-1}")
193+
logger.info(
194+
f"random states: {args.seed + idx} to {args.seed + idx + batch_size-1}"
195+
)
191196

192197
# Sample random fields
193198
seeds = [args.seed + idx + k for k in range(batch_size)]
194199
n0 = n_grid_max if replicate_init else n
195-
vort_init = [grf.sample(1, n0, random_state=s) for _, s in zip(range(batch_size), seeds)]
200+
vort_init = [
201+
grf.sample(1, n0, random_state=s) for _, s in zip(range(batch_size), seeds)
202+
]
196203
vort_init = torch.stack(vort_init)
197204
if n != n0:
198205
vort_init = F.interpolate(vort_init, size=(n, n), mode="nearest")
@@ -230,7 +237,9 @@ def main(args):
230237
f"variable: {field} | shape: {value.shape} | dtype: {value.dtype}"
231238
)
232239
if subsample > 1:
233-
assert value.ndim == 4, f"Subsampling only works for (b, c, h, w) tensors, current shape: {value.shape}"
240+
assert (
241+
value.ndim == 4
242+
), f"Subsampling only works for (b, c, h, w) tensors, current shape: {value.shape}"
234243
value = F.interpolate(value, size=(ns, ns), mode="bilinear")
235244
result[field] = value
236245
logger.info(f"{field:<15} | {value.shape} | {value.dtype}")
@@ -250,7 +259,7 @@ def main(args):
250259
try:
251260
verify_trajectories(
252261
data_filepath,
253-
dt=T_new/record_steps,
262+
dt=T_new / record_steps,
254263
T_warmup=T_warmup,
255264
n_samples=1,
256265
)

fno/datasets.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def _set_params(self, **params):
4444
setattr(self, k, v)
4545
return self
4646

47-
def _fit_transform(self, x:torch.Tensor):
47+
def _fit_transform(self, x: torch.Tensor):
4848
mean = torch.as_tensor(x.mean(0), dtype=torch.float32)
4949
std = torch.as_tensor(x.std(0), dtype=torch.float32)
5050
x_transformed = (x - mean) / (std + self.eps)
@@ -55,7 +55,7 @@ def _fit_transform(self, x:torch.Tensor):
5555
def fit_transform(self, *args, **kwargs):
5656
return self._fit_transform(*args, **kwargs)
5757

58-
def _transform(self, x:torch.Tensor, align_shapes=False, **kwargs):
58+
def _transform(self, x: torch.Tensor, align_shapes=False, **kwargs):
5959
if hasattr(self, "mean"):
6060
mean, std = self.mean, self.std
6161
if align_shapes:
@@ -68,7 +68,9 @@ def _transform(self, x:torch.Tensor, align_shapes=False, **kwargs):
6868
def transform(self, *args, **kwargs):
6969
return self._transform(*args, **kwargs)
7070

71-
def inverse_transform(self, x:torch.Tensor, sample_idx=None, align_shapes=True, **kwargs):
71+
def inverse_transform(
72+
self, x: torch.Tensor, sample_idx=None, align_shapes=True, **kwargs
73+
):
7274
std = (self.std + self.eps).to(x.device)
7375
mean = self.mean.to(x.device)
7476
if align_shapes:
@@ -88,7 +90,7 @@ def forward(self, *args, **kwargs):
8890
return self.inverse_transform(*args, **kwargs)
8991

9092
@staticmethod
91-
def _align_shapes(x:torch.Tensor, mean:torch.Tensor, std:torch.Tensor, **kwargs):
93+
def _align_shapes(x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor, **kwargs):
9294
"""
9395
x: (bsz, m, m, C) or (bsz, m, m) or (bsz, C, m, m)
9496
mean: (n, n, C) or (n, n) or (C, n, n)
@@ -111,7 +113,7 @@ def __init__(self, eps=1e-7):
111113
"""
112114
self.device = None
113115

114-
def _fit_transform(self, x:torch.Tensor):
116+
def _fit_transform(self, x: torch.Tensor):
115117
mean = x.mean((0, -1)).unsqueeze(-1)
116118
std = x.std((0, -1)).unsqueeze(-1)
117119
self.register_buffer("mean", mean)

0 commit comments

Comments
 (0)