1111
1212import torch
1313import torch .fft as fft
14+ from torch_cfd .finite_differences import curl_2d
15+ from torch_cfd .forcings import KolmogorovForcing
1416
15- from torch_cfd .grids import *
17+ from torch_cfd .grids import Grid
18+ from torch_cfd .initial_conditions import filtered_velocity_field
1619from torch_cfd .equations import *
17- from torch_cfd .initial_conditions import *
18- from torch_cfd .finite_differences import *
19- from torch_cfd .forcings import *
20+ from data_gen .solvers import get_trajectory_imex
2021
2122from data_utils import *
2223
23- from solvers import get_trajectory_imex
24-
2524from fno .pipeline import DATA_PATH , LOG_PATH
2625
2726
@@ -38,6 +37,9 @@ def main(args):
3837
3938 Testing dataset for plotting the enstrohpy spectrum:
4039 >>> python data_gen_Kolmogorov2d.py --num-samples 16 --batch-size 8 --grid-size 256 --subsample 1 --visc 1e-3 --dt 1e-3 --time 10 --time-warmup 4.5 --num-steps 100 --diam "2*torch.pi" --double
40+
41+ Testing if the data generation works:
42+ >>> python data_gen_Kolmogorov2d.py --num-samples 4 --batch-size 2 --grid-size 256 --subsample 1 --visc 1e-3 --dt 1e-3 --time 10 --time-warmup 4.5 --num-steps 100 --diam "2*torch.pi" --double --demo
4143 """
4244 args = args .parse_args ()
4345
@@ -47,29 +49,29 @@ def main(args):
4749 log_filename = os .path .join (LOG_PATH , f"{ current_time } _{ log_name } .log" )
4850 logger = get_logger (log_filename )
4951
50- logger .info (f"Using the following arguments: " )
51- all_args = {k : v for k , v in vars (args ).items () if not callable (v )}
52- logger .info (" | " .join (f"{ k } ={ v } " for k , v in all_args .items ()))
53-
5452 total_samples = args .num_samples
5553 batch_size = args .batch_size # 128
54+ assert batch_size <= total_samples , "batch_size <= num_samples"
55+ assert total_samples % batch_size == 0 , "total_samples divisible by batch_size"
5656 n = args .grid_size # 256
57- scale = args .scale
58- viscosity = args . visc
57+ viscosity = args .visc if args . Re is None else 1 / args . Re
58+ Re = 1 / viscosity
5959 dt = args .dt # 1e-3
6060 T = args .time # 10
61- T_warmup = args .time_warmup # 4.5
62- num_snapshots = args .num_steps # 100
6361 subsample = args .subsample # 4
6462 ns = n // subsample
63+ scale = args .scale # 1
64+ T_warmup = args .time_warmup # 4.5
65+ num_snapshots = args .num_steps # 100
6566 random_state = args .seed
6667 peak_wavenumber = args .peak_wavenumber # 4
67- diam = args .diam # "2 * torch.pi" default
68- diam = eval (diam ) if isinstance (diam , str ) else diam #
68+ diam = (
69+ eval (args .diam ) if isinstance (args .diam , str ) else args .diam
70+ ) # "2 * torch.pi"
6971 force_rerun = args .force_rerun
7072
7173 logger = logging .getLogger ()
72- logger .info (f"Generating data for Kolmogorov 2d flow with { total_samples } samples" )
74+ logger .info (f"Generating data for Kolmogorov2d flow with { total_samples } samples" )
7375
7476 max_velocity = args .max_velocity # 5
7577 dt = stable_time_step (diam / n , dt , max_velocity , viscosity = viscosity )
@@ -80,46 +82,37 @@ def main(args):
8082 record_every_iters = int (total_steps / num_snapshots )
8183
8284 dtype = torch .float64 if args .double else torch .float32
85+ cdtype = torch .complex128 if args .double else torch .complex64
8386 dtype_str = "_fp64" if args .double else ""
8487 filename = args .filename
8588 if filename is None :
86- filename = f"Kolmogorov2d{ dtype_str } _{ ns } x{ ns } _N{ total_samples } _v{ viscosity :.0e} _T{ num_snapshots } .pt" .replace (
87- "e-0" , "e-"
88- )
89+ filename = f"Kolmogorov2d{ dtype_str } _{ ns } x{ ns } _N{ total_samples } _Re{ int (Re )} _T{ num_snapshots } .pt"
8990 args .filename = filename
9091 data_filepath = os .path .join (DATA_PATH , filename )
91- data_exist = os .path .exists (data_filepath )
92- if data_exist and not force_rerun :
93- logger .info (f"File { filename } exists with current data as follows:" )
94- data = torch .load (data_filepath )
95-
96- for key , v in data .items ():
97- if isinstance (v , torch .Tensor ):
98- logger .info (f"{ key :<12} | { v .shape } | { v .dtype } " )
99- else :
100- logger .info (f"{ key :<12} | { v .dtype } " )
101- if len (data [key ]) == total_samples :
102- return
103- elif len (data [key ]) < total_samples :
104- total_samples -= len (data [key ])
92+ if os .path .exists (data_filepath ) and not force_rerun :
93+ logger .info (f"Data already exists at { data_filepath } " )
94+ return
95+ elif os .path .exists (data_filepath ) and force_rerun :
96+ logger .info (f"Force rerun and save data to { data_filepath } " )
97+ os .remove (data_filepath )
10598 else :
106- logger .info (f"Generating data and saving in { filename } " )
99+ logger .info (f"Save data to { data_filepath } " )
107100
108101 cuda = not args .no_cuda and torch .cuda .is_available ()
109102 no_tqdm = args .no_tqdm
110103 device = torch .device ("cuda:0" if cuda else "cpu" )
111104
112105 torch .set_default_dtype (torch .float64 )
113106 logger .info (
114- f"Using device: { device } | save dtype: { dtype } | computge dtype: { torch .get_default_dtype ()} "
107+ f"Using device: { device } | save dtype: { dtype } | compute dtype: { torch .get_default_dtype ()} "
115108 )
116109
117110 grid = Grid (shape = (n , n ), domain = ((0 , diam ), (0 , diam )), device = device )
118111
119112 forcing_fn = KolmogorovForcing (
120113 grid = grid ,
121114 scale = scale ,
122- k = peak_wavenumber ,
115+ wave_number = peak_wavenumber ,
123116 swap_xy = False ,
124117 )
125118
@@ -158,45 +151,47 @@ def main(args):
158151 for j in range (warmup_steps ):
159152 vort_hat , _ = ns2d .step (vort_hat , dt )
160153 if j % 100 == 0 :
161- vort_norm = torch .linalg .norm (fft .irfft2 (vort_hat )).item () / n
162- desc = (
163- datetime .now ().strftime ("%d-%b-%Y %H:%M:%S" )
164- + f" - Warmup | vort_hat ell2 norm { vort_norm :.4e} "
165- )
154+ vort_norm = torch .linalg .norm (fft .irfft2 (vort_hat )).item ()/ n
155+ desc = datetime .now ().strftime ("%d-%b-%Y %H:%M:%S" ) + f" - Warmup | vort_hat ell2 norm { vort_norm :.4e} "
166156 pbar .set_description (desc )
167157 pbar .update (100 )
168- logger . info ( f"generate data from { T_warmup } to { T } " )
158+
169159 result = get_trajectory_imex (
170160 ns2d ,
171161 vort_hat ,
172162 dt ,
173163 num_steps = total_steps ,
174164 record_every_steps = record_every_iters ,
175165 pbar = not no_tqdm ,
166+ dtype = cdtype ,
176167 )
177168
178169 for field , value in result .items ():
170+ logger .info (
171+ f"freq variable: { field :<12} | shape: { value .shape } | dtype: { value .dtype } "
172+ )
179173 value = fft .irfft2 (value ).real .cpu ().to (dtype )
180174 logger .info (
181- f"variable: { field } | shape: { value .shape } | dtype: { value .dtype } "
175+ f"saved variable: { field :<12 } | shape: { value .shape } | dtype: { value .dtype } "
182176 )
183177 if subsample > 1 :
184- assert (
185- value .ndim == 4
186- ), f"Subsampling only works for (b, c, h, w) tensors, current shape: { value .shape } "
187- value = F .interpolate (value , size = (ns , ns ), mode = "bilinear" )
188- result [field ] = value
178+ result [field ] = F .interpolate (value , size = (ns , ns ), mode = "bilinear" )
179+ else :
180+ result [field ] = value
189181
190182 result ["random_states" ] = torch .tensor (
191183 [random_state + idx + k for k in range (batch_size )], dtype = torch .int32
192184 )
193185 logger .info (f"Saving batch [{ i + 1 } /{ num_batches } ] to { data_filepath } " )
194- save_pickle (result , data_filepath )
195- del result
196-
197- pickle_to_pt (data_filepath )
198- logger .info (f"Done saving." )
199- if args .demo_plots :
186+ if not args .demo :
187+ save_pickle (result , data_filepath , append = True )
188+ del result
189+
190+
191+ if not args .demo :
192+ pickle_to_pt (data_filepath )
193+ logger .info (f"Done saving." )
194+ else :
200195 try :
201196 verify_trajectories (
202197 data_filepath ,
@@ -205,7 +200,8 @@ def main(args):
205200 n_samples = 1 ,
206201 )
207202 except Exception as e :
208- logger .error (f"Error in plotting: { e } " )
203+ logger .error (f"Error in plotting sample trajectories: { e } " )
204+ return 0
209205
210206
211207if __name__ == "__main__" :
0 commit comments