Skip to content

Commit 7bafb58

Browse files
feat: create denoising autoencoder model and refactory run.py to create smart function of model selection
1 parent 51b36e2 commit 7bafb58

File tree

4 files changed

+75
-15
lines changed

4 files changed

+75
-15
lines changed

json/params.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
"batch_size": 32,
33
"resolution": 64,
44
"encoding_dim": 16,
5-
"num_epochs": 1000,
5+
"num_epochs": 25,
66
"learning_rate": 0.001,
7-
"ae_type": "conv",
7+
"ae_type": "ae",
88
"save_checkpoint": null
99
}

models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .autoencoder import Autoencoder
22
from .convolutional_autoencoder import ConvolutionalAutoencoder
33
from .convolutional_vae import ConvolutionalVAE
4+
from .denoising_autoencoder import DenoisingAutoencoder
45
from .variational_autoencoder import VariationalAutoencoder

models/denoising_autoencoder.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import torch.nn as nn
2+
import torch
3+
4+
5+
class DenoisingAutoencoder(nn.Module):
6+
def __init__(self, input_dim, encoding_dim):
7+
print('***** Denoising Autoencoder input_dim:', input_dim)
8+
super(DenoisingAutoencoder, self).__init__()
9+
10+
self.encoder = nn.Sequential(
11+
nn.Linear(input_dim, 1024),
12+
nn.ReLU(),
13+
nn.Linear(1024, 512),
14+
nn.ReLU(),
15+
nn.Linear(512, 256),
16+
nn.ReLU(),
17+
nn.Linear(256, 128),
18+
nn.ReLU(),
19+
nn.Linear(128, 64),
20+
nn.ReLU(),
21+
nn.Linear(64, encoding_dim)
22+
)
23+
24+
self.decoder = nn.Sequential(
25+
nn.Linear(encoding_dim, 64),
26+
nn.ReLU(),
27+
nn.Linear(64, 128),
28+
nn.ReLU(),
29+
nn.Linear(128, 256),
30+
nn.ReLU(),
31+
nn.Linear(256, 512),
32+
nn.ReLU(),
33+
nn.Linear(512, 1024),
34+
nn.ReLU(),
35+
nn.Linear(1024, input_dim),
36+
nn.Sigmoid()
37+
)
38+
39+
def forward(self, x):
40+
noise = torch.randn_like(x) * 0.1
41+
x_corrupted = x + noise
42+
43+
x_encoded = self.encoder(x_corrupted)
44+
x_decoded = self.decoder(x_encoded)
45+
46+
return x_decoded

run.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,37 @@
66
import numpy as np
77
import torch
88

9-
from models import Autoencoder, ConvolutionalAutoencoder, ConvolutionalVAE, VariationalAutoencoder
9+
from models import (
10+
Autoencoder,
11+
ConvolutionalAutoencoder,
12+
ConvolutionalVAE,
13+
DenoisingAutoencoder,
14+
VariationalAutoencoder,
15+
)
16+
1017
from settings import settings
1118
from utils.dataloader import get_dataloader
1219
from utils.trainer import train_autoencoder, visualize_reconstructions, load_checkpoint, evaluate_autoencoder
1320
from utils import utils
1421

1522

16-
def set_seed(seed=42):
23+
def get_model_by_type(ae_type, input_dim, encoding_dim, device):
24+
models = {
25+
'ae': lambda: Autoencoder(input_dim, encoding_dim),
26+
'conv': ConvolutionalAutoencoder,
27+
'conv_vae': ConvolutionalVAE,
28+
'dae': lambda: DenoisingAutoencoder(input_dim, encoding_dim),
29+
'vae': VariationalAutoencoder,
30+
}
31+
32+
if ae_type not in models:
33+
raise ValueError(f"Unknown AE type: {ae_type}")
34+
35+
model = models[ae_type]()
36+
return model.to(device)
37+
38+
39+
def set_seed(seed):
1740
torch.manual_seed(seed)
1841
torch.cuda.manual_seed_all(seed)
1942
np.random.seed(seed)
@@ -46,17 +69,7 @@ def main(load_trained_model):
4669
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
4770
dataloader = get_dataloader(settings.DATA_PATH, batch_size, resolution)
4871

49-
if ae_type == 'ae':
50-
model = Autoencoder(input_dim, encoding_dim).to(device)
51-
elif ae_type == 'conv':
52-
model = ConvolutionalAutoencoder().to(device)
53-
elif ae_type == 'vae':
54-
model = VariationalAutoencoder().to(device)
55-
elif ae_type == 'conv_vae':
56-
model = ConvolutionalVAE().to(device)
57-
else:
58-
raise ValueError(f"Unknown AE type: {ae_type}")
59-
72+
model = get_model_by_type(ae_type, input_dim, encoding_dim, device)
6073
optimizer = torch.optim.Adam(model.parameters())
6174

6275
start_epoch = 0

0 commit comments

Comments
 (0)