Skip to content

Commit 28cc4a9

Browse files
committed
removed nested imports
1 parent d02e4bf commit 28cc4a9

File tree

5 files changed

+14
-20
lines changed

5 files changed

+14
-20
lines changed

fno/data_gen/data_gen_Kolmogorov2d.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@
77

88
# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
99

10-
import os, sys
11-
12-
import dill
10+
import os
1311

1412
import torch
1513
import torch.fft as fft
@@ -20,11 +18,9 @@
2018
from torch_cfd.finite_differences import *
2119
from torch_cfd.forcings import *
2220

23-
from tqdm import tqdm
2421
from data_utils import *
25-
from solvers import *
2622

27-
import logging
23+
from solvers import get_trajectory_imex
2824

2925
from fno.pipeline import DATA_PATH, LOG_PATH
3026

@@ -69,7 +65,7 @@ def main(args):
6965
random_state = args.seed
7066
peak_wavenumber = args.peak_wavenumber # 4
7167
diam = args.diam # "2 * torch.pi" default
72-
diam = eval(diam) if isinstance(diam, str) else diam #
68+
diam = eval(diam) if isinstance(diam, str) else diam #
7369
force_rerun = args.force_rerun
7470

7571
logger = logging.getLogger()
@@ -96,7 +92,7 @@ def main(args):
9692
if data_exist and not force_rerun:
9793
logger.info(f"File {filename} exists with current data as follows:")
9894
data = torch.load(data_filepath)
99-
95+
10096
for key, v in data.items():
10197
if isinstance(v, torch.Tensor):
10298
logger.info(f"{key:<12} | {v.shape} | {v.dtype}")
@@ -114,7 +110,9 @@ def main(args):
114110
device = torch.device("cuda:0" if cuda else "cpu")
115111

116112
torch.set_default_dtype(torch.float64)
117-
logger.info(f"Using device: {device} | save dtype: {dtype} | computge dtype: {torch.get_default_dtype()}")
113+
logger.info(
114+
f"Using device: {device} | save dtype: {dtype} | computge dtype: {torch.get_default_dtype()}"
115+
)
118116

119117
grid = Grid(shape=(n, n), domain=((0, diam), (0, diam)), device=device)
120118

@@ -183,7 +181,9 @@ def main(args):
183181
f"variable: {field} | shape: {value.shape} | dtype: {value.dtype}"
184182
)
185183
if subsample > 1:
186-
assert value.ndim == 4, f"Subsampling only works for (b, c, h, w) tensors, current shape: {value.shape}"
184+
assert (
185+
value.ndim == 4
186+
), f"Subsampling only works for (b, c, h, w) tensors, current shape: {value.shape}"
187187
value = F.interpolate(value, size=(ns, ns), mode="bilinear")
188188
result[field] = value
189189

fno/data_gen/data_gen_McWilliams2d.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@
77

88
# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
99

10-
import os, sys
11-
12-
import dill
10+
import os
1311

1412
import torch
1513
import torch.fft as fft
@@ -20,10 +18,8 @@
2018
from torch_cfd.finite_differences import *
2119
from torch_cfd.forcings import *
2220

23-
from tqdm import tqdm
2421
from data_utils import *
25-
from solvers import *
26-
import logging
22+
from solvers import get_trajectory_imex
2723

2824
from fno.pipeline import DATA_PATH, LOG_PATH
2925

fno/data_gen/data_gen_fno.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import torch.nn.functional as F
1818

1919
from grf import GRF2d
20-
from solvers import *
20+
from solvers import get_trajectory_imex
2121
from data_utils import *
2222
from torch_cfd.grids import *
2323
from torch_cfd.equations import *

fno/data_gen/data_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,8 @@
1313
import numpy as np
1414
import seaborn as sns
1515
import torch
16-
import torch.fft as fft
1716
import xarray
18-
from tqdm import tqdm
17+
from tqdm.auto import tqdm
1918

2019
feval = lambda s: eval("lambda x, y:" + s, globals())
2120

torch_cfd/pressure.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from torch_cfd import finite_differences as fdm
2727

2828

29-
Array = grids.Array
3029
GridArray = grids.GridArray
3130
GridArrayVector = grids.GridArrayVector
3231
GridVariable = grids.GridVariable

0 commit comments

Comments
 (0)