Skip to content

Commit 56346ce

Browse files
feat: improvements on models and trainer. Create test mode to run with all models
1 parent 7bafb58 commit 56346ce

File tree

10 files changed

+111
-26
lines changed

10 files changed

+111
-26
lines changed

json/params.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@
44
"encoding_dim": 16,
55
"num_epochs": 25,
66
"learning_rate": 0.001,
7-
"ae_type": "ae",
7+
"ae_type": "conv_dae",
88
"save_checkpoint": null
99
}

models/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .autoencoder import Autoencoder
2+
from .autoencoder_dae import DenoisingAutoencoder
3+
from .autoencoder_vae import VariationalAutoencoder
24
from .convolutional_autoencoder import ConvolutionalAutoencoder
5+
from .convolutional_dae import DenoisingConvolutionalAutoencoder
36
from .convolutional_vae import ConvolutionalVAE
4-
from .denoising_autoencoder import DenoisingAutoencoder
5-
from .variational_autoencoder import VariationalAutoencoder

models/autoencoder.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,11 @@
33

44
class Autoencoder(nn.Module):
55
def __init__(self, input_dim, encoding_dim):
6-
print('***** Autoencoder input_dim:', input_dim)
76
super(Autoencoder, self).__init__()
7+
8+
self.model_structure = 'linear'
9+
self.model_variant = 'vanilla'
10+
811
self.encoder = nn.Sequential(
912
nn.Linear(input_dim, 1024),
1013
nn.ReLU(),
@@ -18,6 +21,7 @@ def __init__(self, input_dim, encoding_dim):
1821
nn.ReLU(),
1922
nn.Linear(64, encoding_dim)
2023
)
24+
2125
self.decoder = nn.Sequential(
2226
nn.Linear(encoding_dim, 64),
2327
nn.ReLU(),
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44

55
class DenoisingAutoencoder(nn.Module):
66
def __init__(self, input_dim, encoding_dim):
7-
print('***** Denoising Autoencoder input_dim:', input_dim)
87
super(DenoisingAutoencoder, self).__init__()
98

9+
self.model_structure = 'linear'
10+
self.model_variant = 'vanilla'
11+
1012
self.encoder = nn.Sequential(
1113
nn.Linear(input_dim, 1024),
1214
nn.ReLU(),
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ class VariationalAutoencoder(nn.Module):
77
def __init__(self, encoding_dim=128):
88
super(VariationalAutoencoder, self).__init__()
99

10+
self.model_structure = 'linear'
11+
self.model_variant = 'vae'
12+
1013
# Encoder
1114
self.enc1 = nn.Linear(3 * 64 * 64, 512)
1215
self.enc2 = nn.Linear(512, 256)

models/convolutional_autoencoder.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ class ConvolutionalAutoencoder(nn.Module):
77
def __init__(self):
88
super(ConvolutionalAutoencoder, self).__init__()
99

10+
self.model_structure = 'convolutional'
11+
self.model_variant = 'vanilla'
12+
1013
# Encoder
1114
self.enc0 = nn.Conv2d(3, 256, kernel_size=3, padding=1)
1215
self.enc1 = nn.Conv2d(256, 128, kernel_size=3, padding=1)

models/convolutional_dae.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
5+
6+
class DenoisingConvolutionalAutoencoder(nn.Module):
7+
def __init__(self):
8+
super(DenoisingConvolutionalAutoencoder, self).__init__()
9+
10+
self.model_structure = 'convolutional'
11+
self.model_variant = 'vanilla'
12+
13+
# Encoder
14+
self.enc0 = nn.Conv2d(3, 256, kernel_size=3, padding=1)
15+
self.enc1 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
16+
self.enc2 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
17+
self.enc3 = nn.Conv2d(64, 32, kernel_size=3, padding=1)
18+
self.enc4 = nn.Conv2d(32, 16, kernel_size=3, padding=1)
19+
self.pool = nn.MaxPool2d(2, 2, return_indices=True)
20+
21+
# Decoder
22+
self.dec0 = nn.ConvTranspose2d(16, 32, kernel_size=2, stride=2)
23+
self.dec1 = nn.ConvTranspose2d(32, 64, kernel_size=2, stride=2)
24+
self.dec2 = nn.ConvTranspose2d(64, 128, kernel_size=2, stride=2)
25+
self.dec3 = nn.ConvTranspose2d(128, 256, kernel_size=2, stride=2)
26+
self.dec4 = nn.ConvTranspose2d(256, 3, kernel_size=2, stride=2)
27+
28+
def forward(self, x):
29+
noise = torch.randn_like(x) * 0.1
30+
x_corrupted = x + noise
31+
32+
# Encoder
33+
x, _ = self.pool(F.relu(self.enc0(x_corrupted)))
34+
x, _ = self.pool(F.relu(self.enc1(x)))
35+
x, _ = self.pool(F.relu(self.enc2(x)))
36+
x, _ = self.pool(F.relu(self.enc3(x)))
37+
x, _ = self.pool(F.relu(self.enc4(x)))
38+
39+
# Decoder
40+
x = F.relu(self.dec0(x))
41+
x = F.relu(self.dec1(x))
42+
x = F.relu(self.dec2(x))
43+
x = F.relu(self.dec3(x))
44+
x = torch.sigmoid(self.dec4(x))
45+
46+
return x

models/convolutional_vae.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ class ConvolutionalVAE(nn.Module):
77
def __init__(self):
88
super(ConvolutionalVAE, self).__init__()
99

10+
self.model_structure = 'convolutional'
11+
self.model_variant = 'vae'
12+
1013
# Encoder
1114
self.enc1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
1215
self.enc2 = nn.Conv2d(64, 32, kernel_size=3, padding=1)

run.py

Lines changed: 38 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import json
33
import random
44
import time
5+
import argparse
56

67
import numpy as np
78
import torch
@@ -12,6 +13,7 @@
1213
ConvolutionalVAE,
1314
DenoisingAutoencoder,
1415
VariationalAutoencoder,
16+
DenoisingConvolutionalAutoencoder,
1517
)
1618

1719
from settings import settings
@@ -20,14 +22,18 @@
2022
from utils import utils
2123

2224

23-
def get_model_by_type(ae_type, input_dim, encoding_dim, device):
25+
def get_model_by_type(ae_type=None, input_dim=None, encoding_dim=None, device=None):
2426
models = {
2527
'ae': lambda: Autoencoder(input_dim, encoding_dim),
26-
'conv': ConvolutionalAutoencoder,
27-
'conv_vae': ConvolutionalVAE,
2828
'dae': lambda: DenoisingAutoencoder(input_dim, encoding_dim),
2929
'vae': VariationalAutoencoder,
30+
'conv': ConvolutionalAutoencoder,
31+
'conv_dae': DenoisingConvolutionalAutoencoder,
32+
'conv_vae': ConvolutionalVAE,
3033
}
34+
35+
if ae_type is None:
36+
return list(models.keys())
3137

3238
if ae_type not in models:
3339
raise ValueError(f"Unknown AE type: {ae_type}")
@@ -51,18 +57,21 @@ def load_params(path):
5157
return params
5258

5359

54-
def main(load_trained_model):
60+
def main(load_trained_model, ae_type=None, num_epochs=5, test_mode=True):
5561
set_seed(1)
5662
params = load_params(settings.PATH_PARAMS_JSON)
5763

5864
batch_size = params["batch_size"]
5965
resolution = params["resolution"]
6066
encoding_dim = params["encoding_dim"]
61-
num_epochs = params["num_epochs"]
6267
learning_rate = params.get("learning_rate", 0.001)
63-
ae_type = params["ae_type"]
6468
save_checkpoint = params["save_checkpoint"]
6569

70+
if not ae_type:
71+
ae_type = params["ae_type"]
72+
num_epochs = params["num_epochs"]
73+
test_mode = False
74+
6675
# Calculate input_dim based on resolution
6776
input_dim = 3 * resolution * resolution
6877

@@ -91,25 +100,39 @@ def main(load_trained_model):
91100
device=device,
92101
start_epoch=start_epoch,
93102
optimizer=optimizer,
94-
ae_type=ae_type,
95103
save_checkpoint=save_checkpoint
96104
)
97105

98106
elapsed_time = utils.format_time(time.time() - start_time)
99107
print(f"\nTraining took {elapsed_time}")
100108
print(f"Training complete up to epoch {num_epochs}!")
101-
109+
102110
except KeyboardInterrupt:
103111
print("\nTraining interrupted by user.")
104112

105-
valid_dataloader = get_dataloader(settings.VALID_DATA_PATH, batch_size, resolution)
106-
avg_valid_loss = evaluate_autoencoder(model, valid_dataloader, device, ae_type)
107-
print(f"\nAverage validation loss: {avg_valid_loss:.4f}\n")
113+
if not test_mode:
114+
valid_dataloader = get_dataloader(settings.VALID_DATA_PATH, batch_size, resolution)
115+
avg_valid_loss = evaluate_autoencoder(model, valid_dataloader, device, ae_type)
116+
print(f"\nAverage validation loss: {avg_valid_loss:.4f}\n")
108117

109-
visualize_reconstructions(
110-
model, valid_dataloader, num_samples=10, device=device, ae_type=ae_type, resolution=resolution
111-
)
118+
visualize_reconstructions(
119+
model, valid_dataloader, num_samples=10,
120+
device=device, ae_type=ae_type, resolution=resolution
121+
)
112122

113123

114124
if __name__ == "__main__":
115-
main(False)
125+
parser = argparse.ArgumentParser(description='Training and testing autoencoders.')
126+
parser.add_argument(
127+
'--test', action='store_true', help='Run the test routine for all autoencoders.'
128+
)
129+
130+
args = parser.parse_args()
131+
132+
if args.test:
133+
ae_types = get_model_by_type()
134+
for ae_type in ae_types:
135+
print(f"\n===== Training {ae_type} =====\n")
136+
main(load_trained_model=False, ae_type=ae_type)
137+
else:
138+
main(load_trained_model=False)

utils/trainer.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def load_checkpoint(model, optimizer, path, device):
3636

3737
def train_autoencoder(
3838
model, dataloader, num_epochs, learning_rate, device,
39-
start_epoch, optimizer, ae_type, save_checkpoint
39+
start_epoch, optimizer, save_checkpoint
4040
):
4141
criterion = nn.MSELoss()
4242
if optimizer is None:
@@ -46,10 +46,10 @@ def train_autoencoder(
4646
for data in dataloader:
4747
img = data.to(device)
4848

49-
if ae_type not in ['conv', 'conv_vae']:
49+
if model.model_structure == 'linear':
5050
img = img.view(img.size(0), -1)
5151

52-
if ae_type in ['vae', 'conv_vae']:
52+
if model.model_variant == 'vae':
5353
recon_x, mu, log_var = model(img)
5454
loss = loss_function_vae(recon_x, img, mu, log_var)
5555
else:
@@ -81,7 +81,7 @@ def evaluate_autoencoder(model, dataloader, device, ae_type):
8181
for data in dataloader:
8282
img = data.to(device)
8383

84-
if ae_type not in ['conv', 'conv_vae']:
84+
if model.model_structure == 'linear':
8585
img = img.view(img.size(0), -1)
8686

8787
if ae_type in ['vae', 'conv_vae']:
@@ -99,10 +99,10 @@ def visualize_reconstructions(model, dataloader, num_samples=10, device='cpu', s
9999
samples = next(iter(dataloader))
100100
samples = samples[:num_samples].to(device)
101101

102-
if ae_type not in ['conv', 'conv_vae']:
102+
if model.model_structure == 'linear':
103103
samples = samples.view(samples.size(0), -1)
104104

105-
if ae_type in ['vae', 'conv_vae']:
105+
if model.model_variant == 'vae':
106106
reconstructions, _, _ = model(samples)
107107
else:
108108
reconstructions = model(samples)

0 commit comments

Comments
 (0)