22import json
33import random
44import time
5+ import argparse
56
67import numpy as np
78import torch
1213 ConvolutionalVAE ,
1314 DenoisingAutoencoder ,
1415 VariationalAutoencoder ,
16+ DenoisingConvolutionalAutoencoder ,
1517)
1618
1719from settings import settings
2022from utils import utils
2123
2224
23- def get_model_by_type (ae_type , input_dim , encoding_dim , device ):
25+ def get_model_by_type (ae_type = None , input_dim = None , encoding_dim = None , device = None ):
2426 models = {
2527 'ae' : lambda : Autoencoder (input_dim , encoding_dim ),
26- 'conv' : ConvolutionalAutoencoder ,
27- 'conv_vae' : ConvolutionalVAE ,
2828 'dae' : lambda : DenoisingAutoencoder (input_dim , encoding_dim ),
2929 'vae' : VariationalAutoencoder ,
30+ 'conv' : ConvolutionalAutoencoder ,
31+ 'conv_dae' : DenoisingConvolutionalAutoencoder ,
32+ 'conv_vae' : ConvolutionalVAE ,
3033 }
34+
35+ if ae_type is None :
36+ return list (models .keys ())
3137
3238 if ae_type not in models :
3339 raise ValueError (f"Unknown AE type: { ae_type } " )
@@ -51,18 +57,21 @@ def load_params(path):
5157 return params
5258
5359
54- def main (load_trained_model ):
60+ def main (load_trained_model , ae_type = None , num_epochs = 5 , test_mode = True ):
5561 set_seed (1 )
5662 params = load_params (settings .PATH_PARAMS_JSON )
5763
5864 batch_size = params ["batch_size" ]
5965 resolution = params ["resolution" ]
6066 encoding_dim = params ["encoding_dim" ]
61- num_epochs = params ["num_epochs" ]
6267 learning_rate = params .get ("learning_rate" , 0.001 )
63- ae_type = params ["ae_type" ]
6468 save_checkpoint = params ["save_checkpoint" ]
6569
70+ if not ae_type :
71+ ae_type = params ["ae_type" ]
72+ num_epochs = params ["num_epochs" ]
73+ test_mode = False
74+
6675 # Calculate input_dim based on resolution
6776 input_dim = 3 * resolution * resolution
6877
@@ -91,25 +100,39 @@ def main(load_trained_model):
91100 device = device ,
92101 start_epoch = start_epoch ,
93102 optimizer = optimizer ,
94- ae_type = ae_type ,
95103 save_checkpoint = save_checkpoint
96104 )
97105
98106 elapsed_time = utils .format_time (time .time () - start_time )
99107 print (f"\n Training took { elapsed_time } " )
100108 print (f"Training complete up to epoch { num_epochs } !" )
101-
109+
102110 except KeyboardInterrupt :
103111 print ("\n Training interrupted by user." )
104112
105- valid_dataloader = get_dataloader (settings .VALID_DATA_PATH , batch_size , resolution )
106- avg_valid_loss = evaluate_autoencoder (model , valid_dataloader , device , ae_type )
107- print (f"\n Average validation loss: { avg_valid_loss :.4f} \n " )
113+ if not test_mode :
114+ valid_dataloader = get_dataloader (settings .VALID_DATA_PATH , batch_size , resolution )
115+ avg_valid_loss = evaluate_autoencoder (model , valid_dataloader , device , ae_type )
116+ print (f"\n Average validation loss: { avg_valid_loss :.4f} \n " )
108117
109- visualize_reconstructions (
110- model , valid_dataloader , num_samples = 10 , device = device , ae_type = ae_type , resolution = resolution
111- )
118+ visualize_reconstructions (
119+ model , valid_dataloader , num_samples = 10 ,
120+ device = device , ae_type = ae_type , resolution = resolution
121+ )
112122
113123
114124if __name__ == "__main__" :
115- main (False )
125+ parser = argparse .ArgumentParser (description = 'Training and testing autoencoders.' )
126+ parser .add_argument (
127+ '--test' , action = 'store_true' , help = 'Run the test routine for all autoencoders.'
128+ )
129+
130+ args = parser .parse_args ()
131+
132+ if args .test :
133+ ae_types = get_model_by_type ()
134+ for ae_type in ae_types :
135+ print (f"\n ===== Training { ae_type } =====\n " )
136+ main (load_trained_model = False , ae_type = ae_type )
137+ else :
138+ main (load_trained_model = False )
0 commit comments