2020from utils import *
2121from pipeline import *
2222from data_gen import *
23+ import matplotlib .pyplot as plt
2324from datasets import BochnerDataset
2425from losses import SobolevLoss
25- import matplotlib .pyplot as plt
26- from fno .sfno import SFNO
2726from torch .utils .data import DataLoader
2827
28+ from fno .sfno import SFNO
29+
2930device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
3031
3132
4445
4546
4647def main (args ):
47-
48+
4849 current_time = datetime .now ().strftime ("%d_%b_%Y_%Hh%Mm" )
4950 log_name = "" .join (os .path .basename (__file__ ).split ("." )[:- 1 ])
5051
5152 log_filename = os .path .join (LOG_PATH , f"{ current_time } _{ log_name } .log" )
5253 logger = get_logger (log_filename )
5354 logger .info (f"Saving log at { log_filename } " )
5455
55-
5656 all_args = {k : v for k , v in vars (args ).items () if not callable (v )}
57- logger .info ("Arguments: " + " | " .join (f"{ k } ={ v } " for k , v in all_args .items ()))
57+ logger .info ("Arguments: " + " | " .join (f"{ k } ={ v } " for k , v in all_args .items ()))
5858
5959 example = args .example
6060 Ntrain = args .num_samples
@@ -122,14 +122,20 @@ def main(args):
122122 val_loader = DataLoader (val_dataset , batch_size = batch_size , shuffle = False )
123123
124124 torch .cuda .empty_cache ()
125- model = SFNO (modes , modes , modes_t , width , beta ,
126- num_spectral_layers = num_layers ,
127- output_steps = out_steps ,
128- spatial_padding = spatial_padding ,
129- activation = activation ,
130- pe_trainable = pe_trainable ,
131- spatial_random_feats = spatial_random_feats ,
132- lift_activation = lift_activation )
125+ model = SFNO (
126+ modes ,
127+ modes ,
128+ modes_t ,
129+ width ,
130+ beta = beta ,
131+ num_spectral_layers = num_layers ,
132+ output_steps = out_steps ,
133+ spatial_padding = spatial_padding ,
134+ activation = activation ,
135+ pe_trainable = pe_trainable ,
136+ spatial_random_feats = spatial_random_feats ,
137+ lift_activation = lift_activation ,
138+ )
133139 logger .info (f"Number of parameters: { get_num_params (model )} " )
134140 model .to (device )
135141
@@ -155,7 +161,9 @@ def main(args):
155161 with tqdm (train_loader ) as pbar :
156162 t_ep = datetime .now ().strftime ("%d-%b-%Y %H:%M:%S" )
157163 tr_loss_str = f"current train rel L2: 0.0"
158- pbar .set_description (f"{ t_ep } - Epoch [{ ep + 1 :3d} /{ epochs } ] { tr_loss_str :>35} " )
164+ pbar .set_description (
165+ f"{ t_ep } - Epoch [{ ep + 1 :3d} /{ epochs } ] { tr_loss_str :>35} "
166+ )
159167 for i , data in enumerate (train_loader ):
160168 l2 = train_batch_ns (
161169 model ,
@@ -173,7 +181,9 @@ def main(args):
173181
174182 if i % 4 == 0 :
175183 tr_loss_str = f"current train rel L2: { l2 .item ():.4e} "
176- pbar .set_description (f"{ t_ep } - Epoch [{ ep + 1 :3d} /{ epochs } ] { tr_loss_str :>35} " )
184+ pbar .set_description (
185+ f"{ t_ep } - Epoch [{ ep + 1 :3d} /{ epochs } ] { tr_loss_str :>35} "
186+ )
177187 pbar .update (4 )
178188 val_l2_min = 1e4
179189 val_l2 = eval_epoch_ns (
@@ -214,13 +224,19 @@ def main(args):
214224 )
215225 test_loader = DataLoader (test_dataset , batch_size = 1 , shuffle = False )
216226 torch .cuda .empty_cache ()
217- model = SFNO (modes , modes , modes_t , width , beta ,
218- num_spectral_layers = num_layers ,
219- spatial_padding = spatial_padding ,
220- activation = activation ,
221- pe_trainable = pe_trainable ,
222- spatial_random_feats = spatial_random_feats ,
223- lift_activation = lift_activation ).to (device )
227+ model = SFNO (
228+ modes ,
229+ modes ,
230+ modes_t ,
231+ width ,
232+ beta = beta ,
233+ num_spectral_layers = num_layers ,
234+ spatial_padding = spatial_padding ,
235+ activation = activation ,
236+ pe_trainable = pe_trainable ,
237+ spatial_random_feats = spatial_random_feats ,
238+ lift_activation = lift_activation ,
239+ ).to (device )
224240 model .load_state_dict (torch .load (path_model ))
225241 logger .info (f"Loaded model from { path_model } " )
226242 eval_metric = SobolevLoss (n_grid = n_test , norm_order = norm_order , relative = True )
@@ -238,20 +254,21 @@ def main(args):
238254 if args .demo_plots > 0 :
239255 try :
240256 from visualizations import plot_contour_trajectory
257+
241258 idx = np .random .randint (0 , args .num_test_samples )
242259 im1 = plot_contour_trajectory (
243260 preds [idx ],
244261 num_snapshots = args .demo_plots ,
245262 T_start = args .time_warmup ,
246263 dt = args .dt ,
247- title = "SFNO predictions"
264+ title = "SFNO predictions" ,
248265 )
249266 im2 = plot_contour_trajectory (
250267 gt_solns [idx ],
251268 num_snapshots = args .demo_plots ,
252269 T_start = args .time_warmup ,
253270 dt = args .dt ,
254- title = "Ground truth generated by IMEX"
271+ title = "Ground truth generated by IMEX" ,
255272 )
256273 plt .show ()
257274 except Exception as e :
@@ -294,4 +311,4 @@ def main(args):
294311 parser .add_argument ("--demo-plots" , type = int , default = 0 )
295312
296313 args = parser .parse_args ()
297- main (args )
314+ main (args )
0 commit comments