22import json
33import random
44import time
5+ import argparse
56
67import numpy as np
78import torch
89
9- from models import Autoencoder , ConvolutionalAutoencoder , ConvolutionalVAE , VariationalAutoencoder
10+ from models import (
11+ Autoencoder ,
12+ ConvolutionalAutoencoder ,
13+ ConvolutionalVAE ,
14+ DenoisingAutoencoder ,
15+ VariationalAutoencoder ,
16+ DenoisingConvolutionalAutoencoder ,
17+ )
18+
1019from settings import settings
1120from utils .dataloader import get_dataloader
1221from utils .trainer import train_autoencoder , visualize_reconstructions , load_checkpoint , evaluate_autoencoder
1322from utils import utils
1423
1524
16- def set_seed (seed = 42 ):
25+ def get_model_by_type (ae_type = None , input_dim = None , encoding_dim = None , device = None ):
26+ models = {
27+ 'ae' : lambda : Autoencoder (input_dim , encoding_dim ),
28+ 'dae' : lambda : DenoisingAutoencoder (input_dim , encoding_dim ),
29+ 'vae' : VariationalAutoencoder ,
30+ 'conv' : ConvolutionalAutoencoder ,
31+ 'conv_dae' : DenoisingConvolutionalAutoencoder ,
32+ 'conv_vae' : ConvolutionalVAE ,
33+ }
34+
35+ if ae_type is None :
36+ return list (models .keys ())
37+
38+ if ae_type not in models :
39+ raise ValueError (f"Unknown AE type: { ae_type } " )
40+
41+ model = models [ae_type ]()
42+ return model .to (device )
43+
44+
45+ def set_seed (seed ):
1746 torch .manual_seed (seed )
1847 torch .cuda .manual_seed_all (seed )
1948 np .random .seed (seed )
@@ -28,35 +57,28 @@ def load_params(path):
2857 return params
2958
3059
31- def main (load_trained_model ):
60+ def main (load_trained_model , ae_type = None , num_epochs = 5 , test_mode = True ):
3261 set_seed (1 )
3362 params = load_params (settings .PATH_PARAMS_JSON )
3463
3564 batch_size = params ["batch_size" ]
3665 resolution = params ["resolution" ]
3766 encoding_dim = params ["encoding_dim" ]
38- num_epochs = params ["num_epochs" ]
3967 learning_rate = params .get ("learning_rate" , 0.001 )
40- ae_type = params ["ae_type" ]
4168 save_checkpoint = params ["save_checkpoint" ]
4269
70+ if not ae_type :
71+ ae_type = params ["ae_type" ]
72+ num_epochs = params ["num_epochs" ]
73+ test_mode = False
74+
4375 # Calculate input_dim based on resolution
4476 input_dim = 3 * resolution * resolution
4577
4678 device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
4779 dataloader = get_dataloader (settings .DATA_PATH , batch_size , resolution )
4880
49- if ae_type == 'ae' :
50- model = Autoencoder (input_dim , encoding_dim ).to (device )
51- elif ae_type == 'conv' :
52- model = ConvolutionalAutoencoder ().to (device )
53- elif ae_type == 'vae' :
54- model = VariationalAutoencoder ().to (device )
55- elif ae_type == 'conv_vae' :
56- model = ConvolutionalVAE ().to (device )
57- else :
58- raise ValueError (f"Unknown AE type: { ae_type } " )
59-
81+ model = get_model_by_type (ae_type , input_dim , encoding_dim , device )
6082 optimizer = torch .optim .Adam (model .parameters ())
6183
6284 start_epoch = 0
@@ -78,25 +100,39 @@ def main(load_trained_model):
78100 device = device ,
79101 start_epoch = start_epoch ,
80102 optimizer = optimizer ,
81- ae_type = ae_type ,
82103 save_checkpoint = save_checkpoint
83104 )
84105
85106 elapsed_time = utils .format_time (time .time () - start_time )
86107 print (f"\n Training took { elapsed_time } " )
87108 print (f"Training complete up to epoch { num_epochs } !" )
88-
109+
89110 except KeyboardInterrupt :
90111 print ("\n Training interrupted by user." )
91112
92- valid_dataloader = get_dataloader (settings .VALID_DATA_PATH , batch_size , resolution )
93- avg_valid_loss = evaluate_autoencoder (model , valid_dataloader , device , ae_type )
94- print (f"\n Average validation loss: { avg_valid_loss :.4f} \n " )
113+ if not test_mode :
114+ valid_dataloader = get_dataloader (settings .VALID_DATA_PATH , batch_size , resolution )
115+ avg_valid_loss = evaluate_autoencoder (model , valid_dataloader , device , ae_type )
116+ print (f"\n Average validation loss: { avg_valid_loss :.4f} \n " )
95117
96- visualize_reconstructions (
97- model , valid_dataloader , num_samples = 10 , device = device , ae_type = ae_type , resolution = resolution
98- )
118+ visualize_reconstructions (
119+ model , valid_dataloader , num_samples = 10 ,
120+ device = device , ae_type = ae_type , resolution = resolution
121+ )
99122
100123
101124if __name__ == "__main__" :
102- main (False )
125+ parser = argparse .ArgumentParser (description = 'Training and testing autoencoders.' )
126+ parser .add_argument (
127+ '--test' , action = 'store_true' , help = 'Run the test routine for all autoencoders.'
128+ )
129+
130+ args = parser .parse_args ()
131+
132+ if args .test :
133+ ae_types = get_model_by_type ()
134+ for ae_type in ae_types :
135+ print (f"\n ===== Training { ae_type } =====\n " )
136+ main (load_trained_model = False , ae_type = ae_type )
137+ else :
138+ main (load_trained_model = False )
0 commit comments