     ConvolutionalAutoencoder,
     ConvolutionalVAE,
     DenoisingAutoencoder,
+    SparseAutoencoder,
     VariationalAutoencoder,
     DenoisingConvolutionalAutoencoder,
+    SparseConvolutionalAutoencoder
 )
 
 from settings import settings
@@ -26,12 +28,14 @@ def get_model_by_type(ae_type=None, input_dim=None, encoding_dim=None, device=No
     models = {
         'ae': lambda: Autoencoder(input_dim, encoding_dim),
         'dae': lambda: DenoisingAutoencoder(input_dim, encoding_dim),
+        'sparse': lambda: SparseAutoencoder(input_dim, encoding_dim),
         'vae': VariationalAutoencoder,
         'conv': ConvolutionalAutoencoder,
         'conv_dae': DenoisingConvolutionalAutoencoder,
         'conv_vae': ConvolutionalVAE,
+        'conv_sparse': SparseConvolutionalAutoencoder,
     }
-
+
     if ae_type is None:
         return list(models.keys())
 
@@ -81,15 +85,15 @@ def main(load_trained_model, ae_type=None, num_epochs=5, test_mode=True):
     model = get_model_by_type(ae_type, input_dim, encoding_dim, device)
     optimizer = torch.optim.Adam(model.parameters())
 
-    start_epoch = 0
-    if os.path.exists(settings.PATH_SAVED_MODEL):
-        model, optimizer, start_epoch = load_checkpoint(
-            model, optimizer, settings.PATH_SAVED_MODEL, device
-        )
-        print(f"Loaded checkpoint and continuing training from epoch {start_epoch}.")
-
     try:
         if not load_trained_model:
+            start_epoch = 1
+            if os.path.exists(settings.PATH_SAVED_MODEL):
+                model, optimizer, start_epoch = load_checkpoint(
+                    model, optimizer, settings.PATH_SAVED_MODEL, device
+                )
+                print(f"Loaded checkpoint and continuing training from epoch {start_epoch}.")
+
             start_time = time.time()
 
             train_autoencoder(
@@ -100,7 +104,8 @@ def main(load_trained_model, ae_type=None, num_epochs=5, test_mode=True):
                 device=device,
                 start_epoch=start_epoch,
                 optimizer=optimizer,
-                save_checkpoint=save_checkpoint
+                save_checkpoint=save_checkpoint,
+                ae_type=ae_type
             )
 
             elapsed_time = utils.format_time(time.time() - start_time)
@@ -112,12 +117,12 @@ def main(load_trained_model, ae_type=None, num_epochs=5, test_mode=True):
 
     if not test_mode:
         valid_dataloader = get_dataloader(settings.VALID_DATA_PATH, batch_size, resolution)
-        avg_valid_loss = evaluate_autoencoder(model, valid_dataloader, device, ae_type)
+        avg_valid_loss = evaluate_autoencoder(model, valid_dataloader, device)
         print(f"\nAverage validation loss: {avg_valid_loss:.4f}\n")
 
         visualize_reconstructions(
             model, valid_dataloader, num_samples=10,
-            device=device, ae_type=ae_type, resolution=resolution
+            device=device, resolution=resolution
         )
 
 
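The new 'sparse' and 'conv_sparse' factory entries rely on the SparseAutoencoder and SparseConvolutionalAutoencoder classes imported at the top of the file but defined elsewhere in the models package, so their implementation is not visible in this diff. As a rough sketch only, a dense sparse autoencoder matching the (input_dim, encoding_dim) constructor used by the other dense models might add an L1 penalty on the latent activations; the class body and the sparsity_weight parameter below are assumptions for illustration, not code from this PR.

import torch.nn as nn
import torch.nn.functional as F

class SparseAutoencoder(nn.Module):
    """Illustrative dense autoencoder with an L1 sparsity penalty on the latent code."""

    def __init__(self, input_dim, encoding_dim, sparsity_weight=1e-4):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(input_dim, encoding_dim), nn.ReLU())
        self.decoder = nn.Sequential(nn.Linear(encoding_dim, input_dim), nn.Sigmoid())
        self.sparsity_weight = sparsity_weight

    def forward(self, x):
        encoded = self.encoder(x)
        return self.decoder(encoded), encoded

    def loss(self, x, reconstruction, encoded):
        # Reconstruction error plus an L1 term that pushes most latent units toward zero,
        # which is what makes the learned code sparse.
        return F.mse_loss(reconstruction, x) + self.sparsity_weight * encoded.abs().mean()

Passing ae_type into train_autoencoder, as this diff now does, would let the training loop select such a penalized loss for the sparse variants.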