2424from torch_cfd .forcings import *
2525from fno .pipeline import DATA_PATH , LOG_PATH
2626
27+
2728def main (args ):
2829 """
2930 Generate the original FNO data
@@ -73,7 +74,7 @@ def main(args):
7374 f"Grid size { n } is larger than the maximum allowed { n_grid_max } "
7475 )
7576 scale = args .scale
76- visc = args .visc if args .Re is None else 1 / args .Re # 1e-3
77+ visc = args .visc if args .Re is None else 1 / args .Re # 1e-3
7778 T = args .time # 50
7879 T_warmup = args .time_warmup # 30
7980 T_new = T - T_warmup
@@ -88,7 +89,7 @@ def main(args):
8889 alpha = args .alpha # 2.5
8990 tau = args .tau # 7
9091 peak_wavenumber = args .peak_wavenumber
91-
92+
9293 dtype = torch .float64 if args .double else torch .float32
9394 normalize = args .normalize
9495 filename = args .filename
@@ -107,11 +108,11 @@ def main(args):
107108 dtype_str = "_fp64" if args .double else ""
108109 if filename is None :
109110 filename = (
110- f"fnodata{ extra } { dtype_str } _{ ns } x{ ns } _N{ total_samples } "
111+ f"fnodata{ extra } { dtype_str } _{ ns } x{ ns } _N{ total_samples } "
111112 + f"_v{ visc :.0e} _T{ int (T )} _steps{ record_steps } _alpha{ alpha :.1f} _tau{ tau :.0f} .pt"
112113 ).replace ("e-0" , "e-" )
113114 args .filename = filename
114-
115+
115116 filepath = args .filepath if args .filepath is not None else DATA_PATH
116117 for p in [filepath ]:
117118 if not os .path .exists (p ):
@@ -123,7 +124,7 @@ def main(args):
123124 if data_exist and not force_rerun :
124125 logger .info (f"File { filename } exists with current data as follows:" )
125126 data = torch .load (data_filepath )
126-
127+
127128 for key , v in data .items ():
128129 if isinstance (v , torch .Tensor ):
129130 logger .info (f"{ key :<12} | { v .shape } | { v .dtype } " )
@@ -141,13 +142,15 @@ def main(args):
141142 device = torch .device ("cuda:0" if cuda else "cpu" )
142143
143144 torch .set_default_dtype (torch .float64 )
144- logger .info (f"Using device: { device } | save dtype: { dtype } | computge dtype: { torch .get_default_dtype ()} " )
145+ logger .info (
146+ f"Using device: { device } | save dtype: { dtype } | computge dtype: { torch .get_default_dtype ()} "
147+ )
145148 # Set up 2d GRF with covariance parameters
146149 # Parameters of covariance C = tau^0.5*(2*alpha-2)*(-Laplacian + tau^2 I)^(-alpha)
147150 # Note that we need alpha > d/2 (here d= 2)
148151
149152 grid = Grid (shape = (n , n ), domain = ((0 , diam ), (0 , diam )), device = device )
150-
153+
151154 forcing_fn = SinCosForcing (
152155 grid = grid ,
153156 scale = scale ,
@@ -156,7 +159,7 @@ def main(args):
156159 vorticity = True ,
157160 )
158161 # Forcing function: 0.1*(sin(2pi(x+y)) + cos(2pi(x+y)))
159-
162+
160163 grf = GRF2d (
161164 n = n ,
162165 alpha = alpha ,
@@ -187,12 +190,16 @@ def main(args):
187190 num_batches = total_samples // batch_size
188191 for i , idx in enumerate (range (0 , total_samples , batch_size )):
189192 logger .info (f"Generate trajectory for batch [{ i + 1 } /{ num_batches } ]" )
190- logger .info (f"random states: { args .seed + idx } to { args .seed + idx + batch_size - 1 } " )
193+ logger .info (
194+ f"random states: { args .seed + idx } to { args .seed + idx + batch_size - 1 } "
195+ )
191196
192197 # Sample random fields
193198 seeds = [args .seed + idx + k for k in range (batch_size )]
194199 n0 = n_grid_max if replicate_init else n
195- vort_init = [grf .sample (1 , n0 , random_state = s ) for _ , s in zip (range (batch_size ), seeds )]
200+ vort_init = [
201+ grf .sample (1 , n0 , random_state = s ) for _ , s in zip (range (batch_size ), seeds )
202+ ]
196203 vort_init = torch .stack (vort_init )
197204 if n != n0 :
198205 vort_init = F .interpolate (vort_init , size = (n , n ), mode = "nearest" )
@@ -230,7 +237,9 @@ def main(args):
230237 f"variable: { field } | shape: { value .shape } | dtype: { value .dtype } "
231238 )
232239 if subsample > 1 :
233- assert value .ndim == 4 , f"Subsampling only works for (b, c, h, w) tensors, current shape: { value .shape } "
240+ assert (
241+ value .ndim == 4
242+ ), f"Subsampling only works for (b, c, h, w) tensors, current shape: { value .shape } "
234243 value = F .interpolate (value , size = (ns , ns ), mode = "bilinear" )
235244 result [field ] = value
236245 logger .info (f"{ field :<15} | { value .shape } | { value .dtype } " )
@@ -250,7 +259,7 @@ def main(args):
250259 try :
251260 verify_trajectories (
252261 data_filepath ,
253- dt = T_new / record_steps ,
262+ dt = T_new / record_steps ,
254263 T_warmup = T_warmup ,
255264 n_samples = 1 ,
256265 )
0 commit comments