Skip to content

Commit d319381

Browse files
Merge pull request #16 from renan-siqueira/develop
Merge develop into main
2 parents df866a3 + 6522923 commit d319381

File tree

10 files changed

+180
-35
lines changed

10 files changed

+180
-35
lines changed

json/params.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
"batch_size": 32,
33
"resolution": 64,
44
"encoding_dim": 16,
5-
"num_epochs": 1000,
5+
"num_epochs": 25,
66
"learning_rate": 0.001,
7-
"ae_type": "conv",
7+
"ae_type": "conv_dae",
88
"save_checkpoint": null
99
}

models/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from .autoencoder import Autoencoder
2+
from .autoencoder_dae import DenoisingAutoencoder
3+
from .autoencoder_vae import VariationalAutoencoder
24
from .convolutional_autoencoder import ConvolutionalAutoencoder
5+
from .convolutional_dae import DenoisingConvolutionalAutoencoder
36
from .convolutional_vae import ConvolutionalVAE
4-
from .variational_autoencoder import VariationalAutoencoder

models/autoencoder.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,11 @@
33

44
class Autoencoder(nn.Module):
55
def __init__(self, input_dim, encoding_dim):
6-
print('***** Autoencoder input_dim:', input_dim)
76
super(Autoencoder, self).__init__()
7+
8+
self.model_structure = 'linear'
9+
self.model_variant = 'vanilla'
10+
811
self.encoder = nn.Sequential(
912
nn.Linear(input_dim, 1024),
1013
nn.ReLU(),
@@ -18,6 +21,7 @@ def __init__(self, input_dim, encoding_dim):
1821
nn.ReLU(),
1922
nn.Linear(64, encoding_dim)
2023
)
24+
2125
self.decoder = nn.Sequential(
2226
nn.Linear(encoding_dim, 64),
2327
nn.ReLU(),

models/autoencoder_dae.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import torch.nn as nn
2+
import torch
3+
4+
5+
class DenoisingAutoencoder(nn.Module):
6+
def __init__(self, input_dim, encoding_dim):
7+
super(DenoisingAutoencoder, self).__init__()
8+
9+
self.model_structure = 'linear'
10+
self.model_variant = 'vanilla'
11+
12+
self.encoder = nn.Sequential(
13+
nn.Linear(input_dim, 1024),
14+
nn.ReLU(),
15+
nn.Linear(1024, 512),
16+
nn.ReLU(),
17+
nn.Linear(512, 256),
18+
nn.ReLU(),
19+
nn.Linear(256, 128),
20+
nn.ReLU(),
21+
nn.Linear(128, 64),
22+
nn.ReLU(),
23+
nn.Linear(64, encoding_dim)
24+
)
25+
26+
self.decoder = nn.Sequential(
27+
nn.Linear(encoding_dim, 64),
28+
nn.ReLU(),
29+
nn.Linear(64, 128),
30+
nn.ReLU(),
31+
nn.Linear(128, 256),
32+
nn.ReLU(),
33+
nn.Linear(256, 512),
34+
nn.ReLU(),
35+
nn.Linear(512, 1024),
36+
nn.ReLU(),
37+
nn.Linear(1024, input_dim),
38+
nn.Sigmoid()
39+
)
40+
41+
def forward(self, x):
42+
noise = torch.randn_like(x) * 0.1
43+
x_corrupted = x + noise
44+
45+
x_encoded = self.encoder(x_corrupted)
46+
x_decoded = self.decoder(x_encoded)
47+
48+
return x_decoded
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ class VariationalAutoencoder(nn.Module):
77
def __init__(self, encoding_dim=128):
88
super(VariationalAutoencoder, self).__init__()
99

10+
self.model_structure = 'linear'
11+
self.model_variant = 'vae'
12+
1013
# Encoder
1114
self.enc1 = nn.Linear(3 * 64 * 64, 512)
1215
self.enc2 = nn.Linear(512, 256)

models/convolutional_autoencoder.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ class ConvolutionalAutoencoder(nn.Module):
77
def __init__(self):
88
super(ConvolutionalAutoencoder, self).__init__()
99

10+
self.model_structure = 'convolutional'
11+
self.model_variant = 'vanilla'
12+
1013
# Encoder
1114
self.enc0 = nn.Conv2d(3, 256, kernel_size=3, padding=1)
1215
self.enc1 = nn.Conv2d(256, 128, kernel_size=3, padding=1)

models/convolutional_dae.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
5+
6+
class DenoisingConvolutionalAutoencoder(nn.Module):
7+
def __init__(self):
8+
super(DenoisingConvolutionalAutoencoder, self).__init__()
9+
10+
self.model_structure = 'convolutional'
11+
self.model_variant = 'vanilla'
12+
13+
# Encoder
14+
self.enc0 = nn.Conv2d(3, 256, kernel_size=3, padding=1)
15+
self.enc1 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
16+
self.enc2 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
17+
self.enc3 = nn.Conv2d(64, 32, kernel_size=3, padding=1)
18+
self.enc4 = nn.Conv2d(32, 16, kernel_size=3, padding=1)
19+
self.pool = nn.MaxPool2d(2, 2, return_indices=True)
20+
21+
# Decoder
22+
self.dec0 = nn.ConvTranspose2d(16, 32, kernel_size=2, stride=2)
23+
self.dec1 = nn.ConvTranspose2d(32, 64, kernel_size=2, stride=2)
24+
self.dec2 = nn.ConvTranspose2d(64, 128, kernel_size=2, stride=2)
25+
self.dec3 = nn.ConvTranspose2d(128, 256, kernel_size=2, stride=2)
26+
self.dec4 = nn.ConvTranspose2d(256, 3, kernel_size=2, stride=2)
27+
28+
def forward(self, x):
29+
noise = torch.randn_like(x) * 0.1
30+
x_corrupted = x + noise
31+
32+
# Encoder
33+
x, _ = self.pool(F.relu(self.enc0(x_corrupted)))
34+
x, _ = self.pool(F.relu(self.enc1(x)))
35+
x, _ = self.pool(F.relu(self.enc2(x)))
36+
x, _ = self.pool(F.relu(self.enc3(x)))
37+
x, _ = self.pool(F.relu(self.enc4(x)))
38+
39+
# Decoder
40+
x = F.relu(self.dec0(x))
41+
x = F.relu(self.dec1(x))
42+
x = F.relu(self.dec2(x))
43+
x = F.relu(self.dec3(x))
44+
x = torch.sigmoid(self.dec4(x))
45+
46+
return x

models/convolutional_vae.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ class ConvolutionalVAE(nn.Module):
77
def __init__(self):
88
super(ConvolutionalVAE, self).__init__()
99

10+
self.model_structure = 'convolutional'
11+
self.model_variant = 'vae'
12+
1013
# Encoder
1114
self.enc1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
1215
self.enc2 = nn.Conv2d(64, 32, kernel_size=3, padding=1)

run.py

Lines changed: 61 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,47 @@
22
import json
33
import random
44
import time
5+
import argparse
56

67
import numpy as np
78
import torch
89

9-
from models import Autoencoder, ConvolutionalAutoencoder, ConvolutionalVAE, VariationalAutoencoder
10+
from models import (
11+
Autoencoder,
12+
ConvolutionalAutoencoder,
13+
ConvolutionalVAE,
14+
DenoisingAutoencoder,
15+
VariationalAutoencoder,
16+
DenoisingConvolutionalAutoencoder,
17+
)
18+
1019
from settings import settings
1120
from utils.dataloader import get_dataloader
1221
from utils.trainer import train_autoencoder, visualize_reconstructions, load_checkpoint, evaluate_autoencoder
1322
from utils import utils
1423

1524

16-
def set_seed(seed=42):
25+
def get_model_by_type(ae_type=None, input_dim=None, encoding_dim=None, device=None):
26+
models = {
27+
'ae': lambda: Autoencoder(input_dim, encoding_dim),
28+
'dae': lambda: DenoisingAutoencoder(input_dim, encoding_dim),
29+
'vae': VariationalAutoencoder,
30+
'conv': ConvolutionalAutoencoder,
31+
'conv_dae': DenoisingConvolutionalAutoencoder,
32+
'conv_vae': ConvolutionalVAE,
33+
}
34+
35+
if ae_type is None:
36+
return list(models.keys())
37+
38+
if ae_type not in models:
39+
raise ValueError(f"Unknown AE type: {ae_type}")
40+
41+
model = models[ae_type]()
42+
return model.to(device)
43+
44+
45+
def set_seed(seed):
1746
torch.manual_seed(seed)
1847
torch.cuda.manual_seed_all(seed)
1948
np.random.seed(seed)
@@ -28,35 +57,28 @@ def load_params(path):
2857
return params
2958

3059

31-
def main(load_trained_model):
60+
def main(load_trained_model, ae_type=None, num_epochs=5, test_mode=True):
3261
set_seed(1)
3362
params = load_params(settings.PATH_PARAMS_JSON)
3463

3564
batch_size = params["batch_size"]
3665
resolution = params["resolution"]
3766
encoding_dim = params["encoding_dim"]
38-
num_epochs = params["num_epochs"]
3967
learning_rate = params.get("learning_rate", 0.001)
40-
ae_type = params["ae_type"]
4168
save_checkpoint = params["save_checkpoint"]
4269

70+
if not ae_type:
71+
ae_type = params["ae_type"]
72+
num_epochs = params["num_epochs"]
73+
test_mode = False
74+
4375
# Calculate input_dim based on resolution
4476
input_dim = 3 * resolution * resolution
4577

4678
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
4779
dataloader = get_dataloader(settings.DATA_PATH, batch_size, resolution)
4880

49-
if ae_type == 'ae':
50-
model = Autoencoder(input_dim, encoding_dim).to(device)
51-
elif ae_type == 'conv':
52-
model = ConvolutionalAutoencoder().to(device)
53-
elif ae_type == 'vae':
54-
model = VariationalAutoencoder().to(device)
55-
elif ae_type == 'conv_vae':
56-
model = ConvolutionalVAE().to(device)
57-
else:
58-
raise ValueError(f"Unknown AE type: {ae_type}")
59-
81+
model = get_model_by_type(ae_type, input_dim, encoding_dim, device)
6082
optimizer = torch.optim.Adam(model.parameters())
6183

6284
start_epoch = 0
@@ -78,25 +100,39 @@ def main(load_trained_model):
78100
device=device,
79101
start_epoch=start_epoch,
80102
optimizer=optimizer,
81-
ae_type=ae_type,
82103
save_checkpoint=save_checkpoint
83104
)
84105

85106
elapsed_time = utils.format_time(time.time() - start_time)
86107
print(f"\nTraining took {elapsed_time}")
87108
print(f"Training complete up to epoch {num_epochs}!")
88-
109+
89110
except KeyboardInterrupt:
90111
print("\nTraining interrupted by user.")
91112

92-
valid_dataloader = get_dataloader(settings.VALID_DATA_PATH, batch_size, resolution)
93-
avg_valid_loss = evaluate_autoencoder(model, valid_dataloader, device, ae_type)
94-
print(f"\nAverage validation loss: {avg_valid_loss:.4f}\n")
113+
if not test_mode:
114+
valid_dataloader = get_dataloader(settings.VALID_DATA_PATH, batch_size, resolution)
115+
avg_valid_loss = evaluate_autoencoder(model, valid_dataloader, device, ae_type)
116+
print(f"\nAverage validation loss: {avg_valid_loss:.4f}\n")
95117

96-
visualize_reconstructions(
97-
model, valid_dataloader, num_samples=10, device=device, ae_type=ae_type, resolution=resolution
98-
)
118+
visualize_reconstructions(
119+
model, valid_dataloader, num_samples=10,
120+
device=device, ae_type=ae_type, resolution=resolution
121+
)
99122

100123

101124
if __name__ == "__main__":
102-
main(False)
125+
parser = argparse.ArgumentParser(description='Training and testing autoencoders.')
126+
parser.add_argument(
127+
'--test', action='store_true', help='Run the test routine for all autoencoders.'
128+
)
129+
130+
args = parser.parse_args()
131+
132+
if args.test:
133+
ae_types = get_model_by_type()
134+
for ae_type in ae_types:
135+
print(f"\n===== Training {ae_type} =====\n")
136+
main(load_trained_model=False, ae_type=ae_type)
137+
else:
138+
main(load_trained_model=False)

utils/trainer.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def load_checkpoint(model, optimizer, path, device):
3636

3737
def train_autoencoder(
3838
model, dataloader, num_epochs, learning_rate, device,
39-
start_epoch, optimizer, ae_type, save_checkpoint
39+
start_epoch, optimizer, save_checkpoint
4040
):
4141
criterion = nn.MSELoss()
4242
if optimizer is None:
@@ -46,10 +46,10 @@ def train_autoencoder(
4646
for data in dataloader:
4747
img = data.to(device)
4848

49-
if ae_type not in ['conv', 'conv_vae']:
49+
if model.model_structure == 'linear':
5050
img = img.view(img.size(0), -1)
5151

52-
if ae_type in ['vae', 'conv_vae']:
52+
if model.model_variant == 'vae':
5353
recon_x, mu, log_var = model(img)
5454
loss = loss_function_vae(recon_x, img, mu, log_var)
5555
else:
@@ -81,7 +81,7 @@ def evaluate_autoencoder(model, dataloader, device, ae_type):
8181
for data in dataloader:
8282
img = data.to(device)
8383

84-
if ae_type not in ['conv', 'conv_vae']:
84+
if model.model_structure == 'linear':
8585
img = img.view(img.size(0), -1)
8686

8787
if ae_type in ['vae', 'conv_vae']:
@@ -99,10 +99,10 @@ def visualize_reconstructions(model, dataloader, num_samples=10, device='cpu', s
9999
samples = next(iter(dataloader))
100100
samples = samples[:num_samples].to(device)
101101

102-
if ae_type not in ['conv', 'conv_vae']:
102+
if model.model_structure == 'linear':
103103
samples = samples.view(samples.size(0), -1)
104104

105-
if ae_type in ['vae', 'conv_vae']:
105+
if model.model_variant == 'vae':
106106
reconstructions, _, _ = model(samples)
107107
else:
108108
reconstructions = model(samples)

0 commit comments

Comments (0)