
Commit d482ecf

Merge pull request #12 from renan-siqueira/feature/UpgradeArchitecture
Upgrade Architecture
2 parents 69dac33 + 8fb66eb commit d482ecf

13 files changed (+278 −194 lines)

data/train/__delete_me__

Lines changed: 0 additions & 1 deletion
This file was deleted.

data/valid/__delete_me__

Lines changed: 0 additions & 1 deletion
This file was deleted.

json/params.json

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+{
+    "batch_size": 32,
+    "resolution": 64,
+    "encoding_dim": 16,
+    "num_epochs": 1000,
+    "learning_rate": 0.001,
+    "ae_type": "ae",
+    "save_checkpoint": null
+}
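
The training entry point presumably reads this file to build the model. A minimal sketch of that wiring, assuming the json/params.json path and a load_params helper (both illustrative, not shown in this commit): with resolution 64 and RGB images, the flattened input for the linear Autoencoder works out to 3 * 64 * 64 = 12288.

    import json

    from models import Autoencoder

    def load_params(path="json/params.json"):
        # Hypothetical helper: parse the hyperparameter file added in this commit.
        with open(path) as f:
            return json.load(f)

    params = load_params()
    input_dim = 3 * params["resolution"] ** 2  # 3 * 64 * 64 = 12288 for RGB images
    model = Autoencoder(input_dim=input_dim, encoding_dim=params["encoding_dim"])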

models/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+from .autoencoder import Autoencoder
+from .convolutional_autoencoder import ConvolutionalAutoencoder
+from .convolutional_vae import ConvolutionalVAE
+from .variational_autoencoder import VariationalAutoencoder
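
With these re-exports in place, every architecture can be imported from the package root as well as from its own module:

    # Both import styles now resolve to the same class:
    from models import ConvolutionalVAE
    from models.convolutional_vae import ConvolutionalVAE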

models/autoencoder.py

Lines changed: 15 additions & 129 deletions
@@ -1,14 +1,18 @@
-import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 
-# Linear Autoencoder
 class Autoencoder(nn.Module):
     def __init__(self, input_dim, encoding_dim):
+        print('***** Autoencoder input_dim:', input_dim)
         super(Autoencoder, self).__init__()
         self.encoder = nn.Sequential(
-            nn.Linear(input_dim, 128),
+            nn.Linear(input_dim, 1024),
+            nn.ReLU(),
+            nn.Linear(1024, 512),
+            nn.ReLU(),
+            nn.Linear(512, 256),
+            nn.ReLU(),
+            nn.Linear(256, 128),
             nn.ReLU(),
             nn.Linear(128, 64),
             nn.ReLU(),
@@ -19,135 +23,17 @@ def __init__(self, input_dim, encoding_dim):
             nn.ReLU(),
             nn.Linear(64, 128),
             nn.ReLU(),
-            nn.Linear(128, input_dim),
+            nn.Linear(128, 256),
+            nn.ReLU(),
+            nn.Linear(256, 512),
+            nn.ReLU(),
+            nn.Linear(512, 1024),
+            nn.ReLU(),
+            nn.Linear(1024, input_dim),
             nn.Sigmoid()
         )
 
     def forward(self, x):
         x = self.encoder(x)
         x = self.decoder(x)
         return x
-
-
-# Convolutional Autoencoder
-class ConvolutionalAutoencoder(nn.Module):
-    def __init__(self):
-        super(ConvolutionalAutoencoder, self).__init__()
-
-        # Encoder
-        self.enc1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
-        self.enc2 = nn.Conv2d(64, 32, kernel_size=3, padding=1)
-        self.enc3 = nn.Conv2d(32, 16, kernel_size=3, padding=1)
-        self.pool = nn.MaxPool2d(2, 2, return_indices=True)
-
-        # Decoder
-        self.dec1 = nn.ConvTranspose2d(16, 32, kernel_size=2, stride=2)
-        self.dec2 = nn.ConvTranspose2d(32, 64, kernel_size=2, stride=2)
-        self.dec3 = nn.ConvTranspose2d(64, 3, kernel_size=2, stride=2)
-
-    def forward(self, x):
-        x, idxs1 = self.pool(F.relu(self.enc1(x)))
-        x, idxs2 = self.pool(F.relu(self.enc2(x)))
-        x, idxs3 = self.pool(F.relu(self.enc3(x)))
-
-        x = F.relu(self.dec1(x))
-        x = F.relu(self.dec2(x))
-        x = torch.sigmoid(self.dec3(x))
-        return x
-
-
-# Variational Autoencoder
-class VariationalAutoencoder(nn.Module):
-    def __init__(self, encoding_dim=128):
-        super(VariationalAutoencoder, self).__init__()
-
-        # Encoder
-        self.enc1 = nn.Linear(3 * 64 * 64, 512)
-        self.enc2 = nn.Linear(512, 256)
-        self.enc3 = nn.Linear(256, encoding_dim)
-
-        # Latent space
-        self.fc_mu = nn.Linear(encoding_dim, encoding_dim)
-        self.fc_log_var = nn.Linear(encoding_dim, encoding_dim)
-
-        # Decoder
-        self.dec1 = nn.Linear(encoding_dim, encoding_dim)
-        self.dec2 = nn.Linear(encoding_dim, 256)
-        self.dec3 = nn.Linear(256, 512)
-        self.dec4 = nn.Linear(512, 3 * 64 * 64)
-
-    def reparameterize(self, mu, log_var):
-        std = torch.exp(0.5 * log_var)
-        eps = torch.randn_like(std)
-        return mu + eps * std
-
-    def forward(self, x):
-        x = F.relu(self.enc1(x))
-        x = F.relu(self.enc2(x))
-        x = F.relu(self.enc3(x))
-
-        mu = self.fc_mu(x)
-        log_var = self.fc_log_var(x)
-        z = self.reparameterize(mu, log_var)
-
-        x = F.relu(self.dec1(z))
-        x = F.relu(self.dec2(x))
-        x = F.relu(self.dec3(x))
-        x = torch.sigmoid(self.dec4(x))
-
-        return x, mu, log_var
-
-
-# Convolutional Variational Autoencoder
-class ConvolutionalVAE(nn.Module):
-    def __init__(self):
-        super(ConvolutionalVAE, self).__init__()
-
-        # Encoder
-        self.enc1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
-        self.enc2 = nn.Conv2d(64, 32, kernel_size=3, padding=1)
-        self.enc3 = nn.Conv2d(32, 16, kernel_size=3, padding=1)
-        self.pool = nn.MaxPool2d(2, 2)
-
-        self.fc_mu = nn.Linear(16 * 8 * 8, 128)
-        self.fc_log_var = nn.Linear(16 * 8 * 8, 128)
-
-        # Decoder
-        self.decoder_input = nn.Linear(128, 16 * 8 * 8)
-        self.dec1 = nn.ConvTranspose2d(16, 32, kernel_size=3, padding=1)
-        self.dec2 = nn.ConvTranspose2d(32, 64, kernel_size=3, padding=1)
-        self.dec3 = nn.ConvTranspose2d(64, 3, kernel_size=3, padding=1)
-
-        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
-
-    def reparameterize(self, mu, log_var):
-        std = torch.exp(0.5 * log_var)
-        eps = torch.randn_like(std)
-        return mu + eps * std
-
-    def forward(self, x):
-        # Encoding
-        x = F.relu(self.enc1(x))
-        x = self.pool(x)
-        x = F.relu(self.enc2(x))
-        x = self.pool(x)
-        x = F.relu(self.enc3(x))
-        x = self.pool(x)
-
-        x = x.view(x.size(0), -1)  # Flatten
-
-        mu = self.fc_mu(x)
-        log_var = self.fc_log_var(x)
-        z = self.reparameterize(mu, log_var)
-
-        # Decoding
-        x = self.decoder_input(z)
-        x = x.view(x.size(0), 16, 8, 8)  # Unflatten
-        x = self.upsample(x)
-        x = F.relu(self.dec1(x))
-        x = self.upsample(x)
-        x = F.relu(self.dec2(x))
-        x = self.upsample(x)
-        x = torch.sigmoid(self.dec3(x))
-
-        return x, mu, log_var
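
The upgraded Autoencoder deepens both halves: the encoder now steps input_dim → 1024 → 512 → 256 → 128 → 64 before the encoding_dim bottleneck (the middle layers are unchanged context not shown in this diff), and the decoder mirrors the path back up to input_dim with a final Sigmoid. A quick shape check, using the values from json/params.json (the flattened-input convention is an assumption about the training code):

    import torch

    from models import Autoencoder

    input_dim = 3 * 64 * 64  # resolution 64, RGB, flattened
    model = Autoencoder(input_dim=input_dim, encoding_dim=16)

    batch = torch.rand(32, input_dim)  # Sigmoid output pairs naturally with [0, 1] inputs
    reconstruction = model(batch)
    print(reconstruction.shape)        # torch.Size([32, 12288])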
models/convolutional_autoencoder.py

Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ConvolutionalAutoencoder(nn.Module):
+    def __init__(self):
+        super(ConvolutionalAutoencoder, self).__init__()
+
+        # Encoder
+        self.enc1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
+        self.enc2 = nn.Conv2d(64, 32, kernel_size=3, padding=1)
+        self.enc3 = nn.Conv2d(32, 16, kernel_size=3, padding=1)
+        self.pool = nn.MaxPool2d(2, 2, return_indices=True)
+
+        # Decoder
+        self.dec1 = nn.ConvTranspose2d(16, 32, kernel_size=2, stride=2)
+        self.dec2 = nn.ConvTranspose2d(32, 64, kernel_size=2, stride=2)
+        self.dec3 = nn.ConvTranspose2d(64, 3, kernel_size=2, stride=2)
+
+    def forward(self, x):
+        x, _ = self.pool(F.relu(self.enc1(x)))
+        x, _ = self.pool(F.relu(self.enc2(x)))
+        x, _ = self.pool(F.relu(self.enc3(x)))
+
+        x = F.relu(self.dec1(x))
+        x = F.relu(self.dec2(x))
+        x = torch.sigmoid(self.dec3(x))
+        return x
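
The extracted ConvolutionalAutoencoder keeps return_indices=True on the pooling layer but now discards the indices (x, _ = ...), since the decoder upsamples with stride-2 transposed convolutions rather than nn.MaxUnpool2d. Each pooling halves the spatial size and each transposed convolution doubles it, so input and output resolutions match. A minimal shape check (the 64x64 size follows the resolution in json/params.json):

    import torch

    from models import ConvolutionalAutoencoder

    model = ConvolutionalAutoencoder()

    x = torch.rand(8, 3, 64, 64)  # 64 -> 32 -> 16 -> 8 through the three poolings
    out = model(x)                # 8 -> 16 -> 32 -> 64 through the transposed convolutions
    print(out.shape)              # torch.Size([8, 3, 64, 64])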

models/convolutional_vae.py

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ConvolutionalVAE(nn.Module):
+    def __init__(self):
+        super(ConvolutionalVAE, self).__init__()
+
+        # Encoder
+        self.enc1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
+        self.enc2 = nn.Conv2d(64, 32, kernel_size=3, padding=1)
+        self.enc3 = nn.Conv2d(32, 16, kernel_size=3, padding=1)
+        self.pool = nn.MaxPool2d(2, 2)
+
+        self.fc_mu = nn.Linear(16 * 8 * 8, 128)
+        self.fc_log_var = nn.Linear(16 * 8 * 8, 128)
+
+        # Decoder
+        self.decoder_input = nn.Linear(128, 16 * 8 * 8)
+        self.dec1 = nn.ConvTranspose2d(16, 32, kernel_size=3, padding=1)
+        self.dec2 = nn.ConvTranspose2d(32, 64, kernel_size=3, padding=1)
+        self.dec3 = nn.ConvTranspose2d(64, 3, kernel_size=3, padding=1)
+
+        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
+
+    def reparameterize(self, mu, log_var):
+        std = torch.exp(0.5 * log_var)
+        eps = torch.randn_like(std)
+        return mu + eps * std
+
+    def forward(self, x):
+        # Encoding
+        x = F.relu(self.enc1(x))
+        x = self.pool(x)
+        x = F.relu(self.enc2(x))
+        x = self.pool(x)
+        x = F.relu(self.enc3(x))
+        x = self.pool(x)
+
+        x = x.view(x.size(0), -1)  # Flatten
+
+        mu = self.fc_mu(x)
+        log_var = self.fc_log_var(x)
+        z = self.reparameterize(mu, log_var)
+
+        # Decoding
+        x = self.decoder_input(z)
+        x = x.view(x.size(0), 16, 8, 8)  # Unflatten
+        x = self.upsample(x)
+        x = F.relu(self.dec1(x))
+        x = self.upsample(x)
+        x = F.relu(self.dec2(x))
+        x = self.upsample(x)
+        x = torch.sigmoid(self.dec3(x))
+
+        return x, mu, log_var
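
ConvolutionalVAE returns the (reconstruction, mu, log_var) triple that the standard VAE objective needs. A minimal sketch of how a training step could combine them, assuming sum-reduced BCE and an unweighted KL term (this repo's actual loss function is not shown in this commit):

    import torch
    import torch.nn.functional as F

    from models import ConvolutionalVAE

    model = ConvolutionalVAE()
    x = torch.rand(4, 3, 64, 64)
    recon, mu, log_var = model(x)

    # Reconstruction term plus KL divergence between N(mu, sigma^2) and N(0, I).
    recon_loss = F.binary_cross_entropy(recon, x, reduction="sum")
    kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    loss = recon_loss + kl
    loss.backward()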

models/variational_autoencoder.py

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class VariationalAutoencoder(nn.Module):
+    def __init__(self, encoding_dim=128):
+        super(VariationalAutoencoder, self).__init__()
+
+        # Encoder
+        self.enc1 = nn.Linear(3 * 64 * 64, 512)
+        self.enc2 = nn.Linear(512, 256)
+        self.enc3 = nn.Linear(256, encoding_dim)
+
+        # Latent space
+        self.fc_mu = nn.Linear(encoding_dim, encoding_dim)
+        self.fc_log_var = nn.Linear(encoding_dim, encoding_dim)
+
+        # Decoder
+        self.dec1 = nn.Linear(encoding_dim, encoding_dim)
+        self.dec2 = nn.Linear(encoding_dim, 256)
+        self.dec3 = nn.Linear(256, 512)
+        self.dec4 = nn.Linear(512, 3 * 64 * 64)
+
+    def reparameterize(self, mu, log_var):
+        std = torch.exp(0.5 * log_var)
+        eps = torch.randn_like(std)
+        return mu + eps * std
+
+    def forward(self, x):
+        x = F.relu(self.enc1(x))
+        x = F.relu(self.enc2(x))
+        x = F.relu(self.enc3(x))
+
+        mu = self.fc_mu(x)
+        log_var = self.fc_log_var(x)
+        z = self.reparameterize(mu, log_var)
+
+        x = F.relu(self.dec1(z))
+        x = F.relu(self.dec2(x))
+        x = F.relu(self.dec3(x))
+        x = torch.sigmoid(self.dec4(x))
+
+        return x, mu, log_var
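
Unlike the convolutional variants, VariationalAutoencoder is hard-wired to flattened 3 * 64 * 64 inputs, so images must be reshaped before the forward pass:

    import torch

    from models import VariationalAutoencoder

    model = VariationalAutoencoder(encoding_dim=128)  # default latent width

    images = torch.rand(4, 3, 64, 64)
    flat = images.view(images.size(0), -1)            # (4, 12288)
    recon, mu, log_var = model(flat)
    print(recon.shape, mu.shape)                      # torch.Size([4, 12288]) torch.Size([4, 128])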
