1+ # The MIT License (MIT)
2+ # Copyright © 2025 Shuhao Cao
3+
4+ # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
5+
6+ # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
7+
8+ # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
9+
110import argparse
211import math
312import os
1019from grf import GRF2d
1120from solvers import *
1221from data_utils import *
22+ from torch_cfd .grids import *
23+ from torch_cfd .equations import *
24+ from torch_cfd .forcings import *
1325from fno .pipeline import DATA_PATH , LOG_PATH
1426
1527def main (args ):
@@ -25,13 +37,16 @@ def main(args):
2537 Sample usage:
2638
2739 - Training data for Spectral-Refiner ICLR 2025 paper 'fnodata_extra_64x64_N1280_v1e-3_T50_steps100_alpha2.5_tau7.pt'
28- >>> python data_gen_fno.py --num-samples 1280 --batch-size 256 --grid-size 256 --subsample 4 --extra-vars --time 50 --time-warmup 30 --num-steps 100 --dt 1e-3 --visc 1e-3
40+ >>> python data_gen_fno.py --num-samples 1280 --batch-size 256 --grid-size 256 --subsample 4 --extra-vars --time 50 --time-warmup 30 --num-steps 100 --dt 1e-3 --visc 1e-3 --scale 0.1
2941
3042 - Test data
31- >>> python data_gen_fno.py --num-samples 16 --batch-size 8 --grid-size 256 --subsample 1 --double --extra-vars --time 50 --time-warmup 30 --num-steps 100 --dt 1e-3 --replicable-init --seed 42
43+ >>> python data_gen_fno.py --num-samples 16 --batch-size 8 --grid-size 256 --subsample 1 --double --extra-vars --time 50 --time-warmup 30 --num-steps 100 --dt 1e-3 --scale 0.1 -- replicable-init --seed 42
3244
3345 - Test data fine
34- >>> python data_gen_fno.py --num-samples 2 --batch-size 1 --grid-size 512 --subsample 1 --double --extra-vars --time 50 --time-warmup 30 --num-steps 200 --dt 5e-4 --replicable-init --seed 42
46+ >>> python data_gen_fno.py --num-samples 2 --batch-size 1 --grid-size 512 --subsample 1 --double --extra-vars --time 50 --time-warmup 30 --num-steps 200 --dt 5e-4 --scale 0.1 --replicable-init --seed 42
47+
48+ - Testing if the code works
49+ >>> python data_gen/data_gen_fno.py --num-samples 4 --batch-size 2 --grid-size 128 --subsample 1 --double --extra-vars --time 2 --time-warmup 1 --num-steps 10 --dt 1e-3 --scale 0.1 --replicable-init --seed 42
3550
3651 """
3752
@@ -43,9 +58,6 @@ def main(args):
4358 log_filename = os .path .join (logpath , f"{ current_time } _{ log_name } .log" )
4459 logger = get_logger (log_filename )
4560
46- cuda = not args .no_cuda and torch .cuda .is_available ()
47- device = torch .device ("cuda" if cuda else "cpu" )
48- logger .info (f"Using device: { device } " )
4961 logger .info (f"Using the following arguments: " )
5062 all_args = {k : v for k , v in vars (args ).items () if not callable (v )}
5163 logger .info (" | " .join (f"{ k } ={ v } " for k , v in all_args .items ()))
@@ -60,40 +72,37 @@ def main(args):
6072 raise ValueError (
6173 f"Grid size { n } is larger than the maximum allowed { n_grid_max } "
6274 )
75+ scale = args .scale
6376 visc = args .visc if args .Re is None else 1 / args .Re # 1e-3
6477 T = args .time # 50
6578 T_warmup = args .time_warmup # 30
6679 T_new = T - T_warmup
67- delta_t = args .dt # 1e-4
80+ record_steps = args .num_steps
81+ dt = args .dt # 1e-4
82+ logger .info (f"Using dt = { dt } " )
83+
84+ warmup_steps = int (T_warmup / dt )
85+ total_steps = int (T_new / dt )
86+ record_every_iters = int (total_steps / record_steps )
6887
6988 alpha = args .alpha # 2.5
7089 tau = args .tau # 7
71- f = args .forcing # FNO's default sin+cos
90+ peak_wavenumber = args .peak_wavenumber
91+
7292 dtype = torch .float64 if args .double else torch .float32
7393 normalize = args .normalize
7494 filename = args .filename
7595 force_rerun = args .force_rerun
7696 replicate_init = args .replicable_init
7797 dealias = not args .no_dealias
7898 pbar = not args .no_tqdm
79- torch .set_default_dtype (torch .float64 )
80- logger .info (f"Using device: { device } | save dtype: { dtype } | computge dtype: { torch .get_default_dtype ()} " )
8199
82100 # Number of solutions to generate
83101 total_samples = args .num_samples # 8
84102
85- # Number of snapshots from solution
86- record_steps = args .num_steps
87-
88103 # Batch size
89104 batch_size = args .batch_size # 8
90105
91- solver_kws = dict (visc = visc ,
92- delta_t = delta_t ,
93- diam = diam ,
94- dealias = dealias ,
95- dtype = torch .float64 )
96-
97106 extra = "_extra" if args .extra_vars else ""
98107 dtype_str = "_fp64" if args .double else ""
99108 if filename is None :
@@ -127,9 +136,27 @@ def main(args):
127136 else :
128137 logger .info (f"Generating data and saving in { filename } " )
129138
139+ cuda = not args .no_cuda and torch .cuda .is_available ()
140+ no_tqdm = args .no_tqdm
141+ device = torch .device ("cuda:0" if cuda else "cpu" )
142+
143+ torch .set_default_dtype (torch .float64 )
144+ logger .info (f"Using device: { device } | save dtype: { dtype } | computge dtype: { torch .get_default_dtype ()} " )
130145 # Set up 2d GRF with covariance parameters
131146 # Parameters of covariance C = tau^0.5*(2*alpha-2)*(-Laplacian + tau^2 I)^(-alpha)
132147 # Note that we need alpha > d/2 (here d= 2)
148+
149+ grid = Grid (shape = (n , n ), domain = ((0 , diam ), (0 , diam )), device = device )
150+
151+ forcing_fn = SinCosForcing (
152+ grid = grid ,
153+ scale = scale ,
154+ diam = diam ,
155+ k = peak_wavenumber ,
156+ vorticity = True ,
157+ )
158+ # Forcing function: 0.1*(sin(2pi(x+y)) + cos(2pi(x+y)))
159+
133160 grf = GRF2d (
134161 n = n ,
135162 alpha = alpha ,
@@ -139,14 +166,14 @@ def main(args):
139166 dtype = torch .float64 ,
140167 )
141168
142- # Forcing function: 0.1*(sin(2pi(x+y)) + cos(2pi(x+y)))
143- grid = torch . linspace ( 0 , 1 , n + 1 , device = device )
144- grid = grid [ 0 : - 1 ]
145-
146- X , Y = torch . meshgrid ( grid , grid , indexing = "ij" )
147- # FNO's original implementation
148- # fh = 0.1 * (torch.sin(2 * math.pi * (X + Y)) + torch.cos(2 * math.pi * (X + Y)))
149- fh = f ( X , Y )
169+ ns2d = NavierStokes2DSpectral (
170+ viscosity = visc ,
171+ grid = grid ,
172+ smooth = True ,
173+ forcing_fn = forcing_fn ,
174+ solver = IMEXStepper ,
175+ order = 2 ,
176+ ). to ( device )
150177
151178 if os .path .exists (data_filepath ) and not force_rerun :
152179 logger .info (f"Data already exists at { data_filepath } " )
@@ -165,46 +192,48 @@ def main(args):
165192 # Sample random fields
166193 seeds = [args .seed + idx + k for k in range (batch_size )]
167194 n0 = n_grid_max if replicate_init else n
168- w0 = [grf .sample (1 , n0 , random_state = s ) for _ , s in zip (range (batch_size ), seeds )]
169- w0 = torch .stack (w0 )
195+ vort_init = [grf .sample (1 , n0 , random_state = s ) for _ , s in zip (range (batch_size ), seeds )]
196+ vort_init = torch .stack (vort_init )
170197 if n != n0 :
171- w0 = F .interpolate (w0 , size = (n , n ), mode = "nearest" )
172- w0 = w0 .squeeze (1 )
173-
174- logger .info (f"initial condition { w0 .shape } " )
198+ vort_init = F .interpolate (vort_init , size = (n , n ), mode = "nearest" )
199+ vort_init = vort_init .squeeze (1 )
200+ vort_hat = fft .rfft2 (vort_init ).to (device )
201+
202+ logger .info (f"initial condition { vort_init .shape } " )
175203
176204 if T_warmup > 0 :
177- logger .info (f"warm up till { T_warmup } " )
178- tmp = get_trajectory_imex_crank_nicolson (
179- w0 ,
180- fh ,
181- T = T_warmup ,
182- record_steps = record_steps ,
183- subsample = 1 ,
184- pbar = pbar ,
185- ** solver_kws ,
186- )
187- w0 = tmp ["vorticity" ][:, - 1 ].to (device )
188- del tmp
189- logger .info (f"warmup initial condition { w0 .shape } " )
205+ with tqdm (total = warmup_steps , disable = no_tqdm ) as pbar :
206+ for j in range (warmup_steps ):
207+ vort_hat , _ = ns2d .step (vort_hat , dt )
208+ if j % 100 == 0 :
209+ vort_norm = torch .linalg .norm (fft .irfft2 (vort_hat )).item () / n
210+ desc = (
211+ datetime .now ().strftime ("%d-%b-%Y %H:%M:%S" )
212+ + f" - Warmup | vort_hat ell2 norm { vort_norm :.4e} "
213+ )
214+ pbar .set_description (desc )
215+ pbar .update (100 )
190216
191217 logger .info (f"generate data from { T_warmup } to { T } " )
192- result = get_trajectory_imex_crank_nicolson (
193- w0 ,
194- fh ,
195- T = T_new ,
196- record_steps = record_steps ,
197- subsample = subsample ,
198- pbar = pbar ,
199- ** solver_kws ,
218+ result = get_trajectory_imex (
219+ ns2d ,
220+ vort_hat ,
221+ dt ,
222+ num_steps = total_steps ,
223+ record_every_steps = record_every_iters ,
224+ pbar = not no_tqdm ,
200225 )
201226
202227 for field , value in result .items ():
203- if subsample > 1 and value .ndim == 4 :
228+ value = fft .irfft2 (value ).real .cpu ().to (dtype )
229+ logger .info (
230+ f"variable: { field } | shape: { value .shape } | dtype: { value .dtype } "
231+ )
232+ if subsample > 1 :
233+ assert value .ndim == 4 , f"Subsampling only works for (b, c, h, w) tensors, current shape: { value .shape } "
204234 value = F .interpolate (value , size = (ns , ns ), mode = "bilinear" )
205- result [field ] = value . cpu (). to ( dtype )
235+ result [field ] = value
206236 logger .info (f"{ field :<15} | { value .shape } | { value .dtype } " )
207-
208237
209238 if not extra :
210239 for key in ["vort_t" , "stream" , "residual" ]:
@@ -233,5 +262,5 @@ def main(args):
233262
234263
235264if __name__ == "__main__" :
236- args = get_args ("Generate the original FNO data for NSE in 2D" )
265+ args = get_args_ns2d ("Generate the original FNO data for NSE in 2D" )
237266 main (args )
0 commit comments