Commit 13eb9e0

feat: improvements on training and create new models and tools file
1 parent 4410e0b

File tree

12 files changed: +283 −26 lines

.gitignore

Lines changed: 1 addition & 0 deletions

@@ -3,6 +3,7 @@
 __pycache__

 samples
+params.json

 *.pth
 *.png

copy_randomic_files.py

Lines changed: 50 additions & 0 deletions

@@ -0,0 +1,50 @@
+import os
+import shutil
+import random
+
+from settings.settings import (
+    COPY_DESTINATION_FOLDER,
+    COPY_SOURCE_FOLDER,
+    COPY_PERCENTAGE_TO_COPY,
+    COPY_RANDOM_MODE,
+    COPY_FIXED_NUMBER_TO_COPY
+)
+
+
+def copy_files(
+    source_folder,
+    destination_folder,
+    percentage=None,
+    fixed_number=None,
+    random_mode=True
+):
+
+    files = [file_name for file_name in os.listdir(source_folder)
+             if os.path.isfile(os.path.join(source_folder, file_name))]
+
+    if percentage is not None:
+        total_to_copy = int(len(files) * percentage / 100)
+    elif fixed_number is not None:
+        total_to_copy = min(fixed_number, len(files))
+    else:
+        raise ValueError("Either percentage or fixed_number must be provided!")
+
+    if random_mode:
+        chosen_files = random.sample(files, total_to_copy)
+    else:
+        chosen_files = files[:total_to_copy]
+
+    for file_name in chosen_files:
+        shutil.copy2(os.path.join(source_folder, file_name), destination_folder)
+
+    print(f"{total_to_copy} files have been copied from {source_folder} to {destination_folder}.")
+
+
+if __name__ == '__main__':
+    copy_files(
+        COPY_SOURCE_FOLDER,
+        COPY_DESTINATION_FOLDER,
+        percentage=COPY_PERCENTAGE_TO_COPY,
+        fixed_number=COPY_FIXED_NUMBER_TO_COPY,
+        random_mode=COPY_RANDOM_MODE
+    )
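
Note that because percentage is checked before fixed_number, the __main__ block above, which passes both, will use the percentage and silently ignore COPY_FIXED_NUMBER_TO_COPY; the fixed count applies only when COPY_PERCENTAGE_TO_COPY is None. The settings module itself is not part of this commit, so the following is only a sketch of what settings/settings.py plausibly contains (all values here are hypothetical):

# settings/settings.py -- hypothetical values, not shown in this commit.
COPY_SOURCE_FOLDER = "./dataset/all"         # folder the files are read from
COPY_DESTINATION_FOLDER = "./dataset/train"  # folder the copies land in
COPY_PERCENTAGE_TO_COPY = 10                 # percent to copy; set to None to use the fixed count
COPY_FIXED_NUMBER_TO_COPY = 500              # used only when the percentage is None
COPY_RANDOM_MODE = True                      # sample randomly instead of taking the first N files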

json/params.example.json

Lines changed: 9 additions & 0 deletions

@@ -0,0 +1,9 @@
+{
+    "batch_size": 32,
+    "resolution": 64,
+    "encoding_dim": 16,
+    "num_epochs": 100,
+    "learning_rate": 0.001,
+    "ae_type": "ae",
+    "save_checkpoint": null
+}
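
This example file documents the expected training parameters, while the real params.json becomes git-ignored in this same commit. A minimal loader sketch, assuming the file is read with the standard json module; the function name is hypothetical, since the commit does not show how the repo consumes it:

import json

# Hypothetical loader -- this commit does not show how params.json is read.
def load_params(path="json/params.json"):
    with open(path) as f:
        return json.load(f)

params = load_params()
print(params["ae_type"], params["num_epochs"])  # e.g. "vae" 10 with the values committed below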

json/params.json

Lines changed: 3 additions & 3 deletions

@@ -2,8 +2,8 @@
     "batch_size": 32,
     "resolution": 64,
     "encoding_dim": 16,
-    "num_epochs": 25,
-    "learning_rate": 0.001,
-    "ae_type": "conv_dae",
+    "num_epochs": 10,
+    "learning_rate": 0.0001,
+    "ae_type": "vae",
     "save_checkpoint": null
 }

models/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -1,6 +1,8 @@
 from .autoencoder import Autoencoder
 from .autoencoder_dae import DenoisingAutoencoder
+from .autoencoder_sparse import SparseAutoencoder
 from .autoencoder_vae import VariationalAutoencoder
 from .convolutional_autoencoder import ConvolutionalAutoencoder
 from .convolutional_dae import DenoisingConvolutionalAutoencoder
 from .convolutional_vae import ConvolutionalVAE
+from .convolutional_sparse import SparseConvolutionalAutoencoder

models/autoencoder_sparse.py

Lines changed: 43 additions & 0 deletions

@@ -0,0 +1,43 @@
+import torch.nn as nn
+
+
+class SparseAutoencoder(nn.Module):
+    def __init__(self, input_dim, encoding_dim):
+        super(SparseAutoencoder, self).__init__()
+
+        self.model_structure = 'linear'
+        self.model_variant = 'sparse'
+
+        self.encoder = nn.Sequential(
+            nn.Linear(input_dim, 1024),
+            nn.LeakyReLU(0.01),
+            nn.Linear(1024, 512),
+            nn.LeakyReLU(0.01),
+            nn.Linear(512, 256),
+            nn.LeakyReLU(0.01),
+            nn.Linear(256, 128),
+            nn.LeakyReLU(0.01),
+            nn.Linear(128, 64),
+            nn.LeakyReLU(0.01),
+            nn.Linear(64, encoding_dim)
+        )
+
+        self.decoder = nn.Sequential(
+            nn.Linear(encoding_dim, 64),
+            nn.ReLU(),
+            nn.Linear(64, 128),
+            nn.ReLU(),
+            nn.Linear(128, 256),
+            nn.ReLU(),
+            nn.Linear(256, 512),
+            nn.ReLU(),
+            nn.Linear(512, 1024),
+            nn.ReLU(),
+            nn.Linear(1024, input_dim),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        encoded = self.encoder(x)
+        decoded = self.decoder(encoded)
+        return decoded, encoded
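
The class defines only the architecture; nothing here enforces sparsity. In a sparse autoencoder that constraint normally lives in the training loss, typically as an L1 penalty on the encoding. Since this commit does not show train_autoencoder, the snippet below is a generic sketch of that idea, not this repo's implementation (sparsity_weight is an assumed hyperparameter):

import torch
import torch.nn as nn

# Generic sparse-autoencoder loss sketch -- not this repo's training code.
mse = nn.MSELoss()
sparsity_weight = 1e-4  # assumed hyperparameter

def sparse_loss(decoded, encoded, target):
    # Reconstruction error plus an L1 penalty that pushes encodings toward zero.
    return mse(decoded, target) + sparsity_weight * encoded.abs().mean()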

models/convolutional_sparse.py

Lines changed: 41 additions & 0 deletions

@@ -0,0 +1,41 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class SparseConvolutionalAutoencoder(nn.Module):
+    def __init__(self):
+        super(SparseConvolutionalAutoencoder, self).__init__()
+
+        self.model_structure = 'convolutional'
+        self.model_variant = 'sparse'
+
+        # Encoder
+        self.enc0 = nn.Conv2d(3, 256, kernel_size=3, padding=1)
+        self.enc1 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
+        self.enc2 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
+        self.enc3 = nn.Conv2d(64, 32, kernel_size=3, padding=1)
+        self.enc4 = nn.Conv2d(32, 16, kernel_size=3, padding=1)
+        self.pool = nn.MaxPool2d(2, 2, return_indices=True)
+
+        # Decoder
+        self.dec0 = nn.ConvTranspose2d(16, 32, kernel_size=2, stride=2)
+        self.dec1 = nn.ConvTranspose2d(32, 64, kernel_size=2, stride=2)
+        self.dec2 = nn.ConvTranspose2d(64, 128, kernel_size=2, stride=2)
+        self.dec3 = nn.ConvTranspose2d(128, 256, kernel_size=2, stride=2)
+        self.dec4 = nn.ConvTranspose2d(256, 3, kernel_size=2, stride=2)
+
+    def forward(self, x):
+        x, _ = self.pool(F.relu(self.enc0(x)))
+        x, _ = self.pool(F.relu(self.enc1(x)))
+        x, _ = self.pool(F.relu(self.enc2(x)))
+        x, _ = self.pool(F.relu(self.enc3(x)))
+        encoded = self.enc4(x)
+        x, _ = self.pool(F.relu(encoded))
+
+        x = F.relu(self.dec0(x))
+        x = F.relu(self.dec1(x))
+        x = F.relu(self.dec2(x))
+        x = F.relu(self.dec3(x))
+        x = torch.sigmoid(self.dec4(x))
+        return x, encoded
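
Two details worth noting: the pooling layer is built with return_indices=True, but the indices are discarded in forward (they would only matter for an nn.MaxUnpool2d decoder, which this model does not use), and the five pooling steps halve each spatial side five times while the five stride-2 transposed convolutions double it back, so input sides must be divisible by 32. The resolution of 64 in the committed params.json satisfies that. A quick shape check, assuming the package imports work as in models/__init__.py above:

import torch
from models import SparseConvolutionalAutoencoder

model = SparseConvolutionalAutoencoder()
x = torch.randn(1, 3, 64, 64)    # matches "resolution": 64 in json/params.json
out, encoded = model(x)
print(out.shape)      # expected: torch.Size([1, 3, 64, 64])
print(encoded.shape)  # expected: torch.Size([1, 16, 4, 4])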

reset_dataset.py

Lines changed: 22 additions & 0 deletions

@@ -0,0 +1,22 @@
+import os
+import shutil
+
+def reset_dataset():
+    dataset_path = "./dataset"
+    train_path = os.path.join(dataset_path, "train")
+    valid_path = os.path.join(dataset_path, "valid")
+
+    # Remove the folders if they exist
+    if os.path.exists(train_path):
+        shutil.rmtree(train_path)
+    if os.path.exists(valid_path):
+        shutil.rmtree(valid_path)
+
+    # Recreate the folders
+    os.makedirs(train_path)
+    os.makedirs(valid_path)
+
+    print("The dataset was reset successfully!")
+
+if __name__ == "__main__":
+    reset_dataset()

run.py

Lines changed: 16 additions & 11 deletions

@@ -12,8 +12,10 @@
     ConvolutionalAutoencoder,
     ConvolutionalVAE,
     DenoisingAutoencoder,
+    SparseAutoencoder,
     VariationalAutoencoder,
     DenoisingConvolutionalAutoencoder,
+    SparseConvolutionalAutoencoder
 )

 from settings import settings

@@ -26,12 +28,14 @@ def get_model_by_type(ae_type=None, input_dim=None, encoding_dim=None, device=None):
     models = {
         'ae': lambda: Autoencoder(input_dim, encoding_dim),
         'dae': lambda: DenoisingAutoencoder(input_dim, encoding_dim),
+        'sparse': lambda: SparseAutoencoder(input_dim, encoding_dim),
         'vae': VariationalAutoencoder,
         'conv': ConvolutionalAutoencoder,
         'conv_dae': DenoisingConvolutionalAutoencoder,
         'conv_vae': ConvolutionalVAE,
+        'conv_sparse': SparseConvolutionalAutoencoder,
     }
-
+
     if ae_type is None:
         return list(models.keys())

@@ -81,15 +85,15 @@ def main(load_trained_model, ae_type=None, num_epochs=5, test_mode=True):
     model = get_model_by_type(ae_type, input_dim, encoding_dim, device)
     optimizer = torch.optim.Adam(model.parameters())

-    start_epoch = 0
-    if os.path.exists(settings.PATH_SAVED_MODEL):
-        model, optimizer, start_epoch = load_checkpoint(
-            model, optimizer, settings.PATH_SAVED_MODEL, device
-        )
-        print(f"Loaded checkpoint and continuing training from epoch {start_epoch}.")
-
     try:
         if not load_trained_model:
+            start_epoch = 1
+            if os.path.exists(settings.PATH_SAVED_MODEL):
+                model, optimizer, start_epoch = load_checkpoint(
+                    model, optimizer, settings.PATH_SAVED_MODEL, device
+                )
+                print(f"Loaded checkpoint and continuing training from epoch {start_epoch}.")
+
             start_time = time.time()

             train_autoencoder(

@@ -100,7 +104,8 @@ def main(load_trained_model, ae_type=None, num_epochs=5, test_mode=True):
                 device=device,
                 start_epoch=start_epoch,
                 optimizer=optimizer,
-                save_checkpoint=save_checkpoint
+                save_checkpoint=save_checkpoint,
+                ae_type=ae_type
             )

             elapsed_time = utils.format_time(time.time() - start_time)

@@ -112,12 +117,12 @@ def main(load_trained_model, ae_type=None, num_epochs=5, test_mode=True):

         if not test_mode:
             valid_dataloader = get_dataloader(settings.VALID_DATA_PATH, batch_size, resolution)
-            avg_valid_loss = evaluate_autoencoder(model, valid_dataloader, device, ae_type)
+            avg_valid_loss = evaluate_autoencoder(model, valid_dataloader, device)
             print(f"\nAverage validation loss: {avg_valid_loss:.4f}\n")

             visualize_reconstructions(
                 model, valid_dataloader, num_samples=10,
-                device=device, ae_type=ae_type, resolution=resolution
+                device=device, resolution=resolution
             )
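
Beyond the new registry entries, this diff changes behavior in two ways: checkpoint auto-resume now happens only on the training path (inside if not load_trained_model), with the default start epoch moving from 0 to 1, and ae_type is now threaded into train_autoencoder while evaluate_autoencoder and visualize_reconstructions no longer accept it. One usage of the registry that is grounded in the hunks above (call sites beyond this diff are not shown):

# With no ae_type, get_model_by_type returns the registered keys, per the diff.
available = get_model_by_type()
print(available)
# ['ae', 'dae', 'sparse', 'vae', 'conv', 'conv_dae', 'conv_vae', 'conv_sparse']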

settings/__init__.py

Whitespace-only changes.
