Skip to content

Commit 84cd3e1

Browse files
Merge pull request #7 from renan-siqueira/feature/Improvements
Checkpoint Feature
2 parents b225ba1 + d3f7df4 commit 84cd3e1

File tree

2 files changed

+39
-23
lines changed

2 files changed

+39
-23
lines changed

run.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,36 +3,37 @@
33

44
from models.autoencoder import Autoencoder
55
from utils.dataloader import get_dataloader
6-
from utils.trainer import train_autoencoder, visualize_reconstructions, save_model, load_model, evaluate_autoencoder
6+
from utils.trainer import train_autoencoder, visualize_reconstructions, save_checkpoint, load_checkpoint, evaluate_autoencoder
77
from settings import settings
88

99

1010
def main(load_trained_model):
    """Train (or resume training of) the autoencoder, then evaluate it.

    Args:
        load_trained_model: when truthy, skip training and only evaluate the
            model restored from the checkpoint (if one exists).
    """
    BATCH_SIZE = 32
    INPUT_DIM = 3 * 64 * 64  # flattened 64x64 RGB image
    ENCODING_DIM = 64
    NUM_EPOCHS = 200

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataloader = get_dataloader(settings.DATA_PATH, BATCH_SIZE)

    model = Autoencoder(INPUT_DIM, ENCODING_DIM).to(device)
    optimizer = torch.optim.Adam(model.parameters())

    # Resume from an existing checkpoint so an interrupted run continues
    # where it left off instead of restarting from epoch 0.
    start_epoch = 0
    if os.path.exists(settings.PATH_SAVED_MODEL):
        model, optimizer, start_epoch = load_checkpoint(model, optimizer, settings.PATH_SAVED_MODEL, device)
        print(f"Loaded checkpoint and continuing training from epoch {start_epoch}.")

    if not load_trained_model:
        for epoch in range(start_epoch, NUM_EPOCHS):
            # Fix: run exactly ONE training epoch per loop iteration and
            # reuse the shared optimizer. The previous call relied on
            # train_autoencoder's defaults (num_epochs=5, optimizer=None →
            # fresh Adam each call), so every outer "epoch" silently trained
            # five epochs and discarded the optimizer's momentum state.
            train_autoencoder(model, dataloader, num_epochs=1, device=device, optimizer=optimizer)
            print(f"Epoch [{epoch + 1}/{NUM_EPOCHS}] complete!")
            save_checkpoint(model, optimizer, epoch, settings.PATH_SAVED_MODEL)
            # NOTE(review): train_autoencoder also writes its own
            # './autoencoder_checkpoint.pth' internally — consider removing
            # that duplicate, hard-coded save inside the trainer.

    # Evaluate on the held-out set and show sample reconstructions.
    valid_dataloader = get_dataloader(settings.VALID_DATA_PATH, BATCH_SIZE)
    avg_valid_loss = evaluate_autoencoder(model, valid_dataloader, device)
    print(f"Average validation loss: {avg_valid_loss:.4f}")
    visualize_reconstructions(model, valid_dataloader, num_samples=10, device=device)
3637

3738

3839
if __name__ == "__main__":

utils/trainer.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,16 @@
22
import torch
33
import torch.optim as optim
44
import torch.nn as nn
5-
from torchvision import transforms
65
from torchvision.utils import save_image, make_grid
76
import matplotlib.pyplot as plt
8-
from PIL import Image
97

108

11-
def train_autoencoder(model, dataloader, num_epochs=5, learning_rate=0.001, device='cpu'):
9+
def train_autoencoder(model, dataloader, num_epochs=5, learning_rate=0.001, device='cpu', start_epoch=0, optimizer=None):
1210
criterion = nn.MSELoss()
13-
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
11+
if optimizer is None:
12+
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
1413

15-
for epoch in range(num_epochs):
14+
for epoch in range(start_epoch, num_epochs):
1615
for data in dataloader:
1716
img = data.to(device)
1817
img = img.view(img.size(0), -1)
@@ -24,6 +23,7 @@ def train_autoencoder(model, dataloader, num_epochs=5, learning_rate=0.001, devi
2423
optimizer.step()
2524

2625
print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')
26+
save_checkpoint(model, optimizer, epoch, './autoencoder_checkpoint.pth')
2727

2828
return model
2929

@@ -38,11 +38,9 @@ def visualize_reconstructions(model, dataloader, num_samples=10, device='cpu', s
3838
samples = samples.view(-1, 3, 64, 64)
3939
reconstructions = reconstructions.view(-1, 3, 64, 64)
4040

41-
# Combine as amostras e reconstruções em uma única grade
4241
combined = torch.cat([samples, reconstructions], dim=0)
4342
grid_img = make_grid(combined, nrow=num_samples)
4443

45-
# Visualização usando Matplotlib
4644
plt.imshow(grid_img.permute(1, 2, 0).cpu().detach().numpy())
4745
plt.axis('off')
4846
plt.show()
@@ -62,6 +60,23 @@ def load_model(model, path, device):
6260
return model
6361

6462

63+
def save_checkpoint(model, optimizer, epoch, path):
    """Write a resumable training checkpoint to *path*.

    The checkpoint bundles the epoch index with the model's and the
    optimizer's state dicts so training can pick up exactly where it
    stopped (see load_checkpoint).
    """
    torch.save(
        {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        },
        path,
    )
70+
71+
72+
def load_checkpoint(model, optimizer, path, device):
    """Restore a checkpoint written by save_checkpoint.

    Loads the model and optimizer state dicts from *path* (mapped onto
    *device*) in place, and returns ``(model, optimizer, next_epoch)``
    where ``next_epoch`` is one past the epoch stored in the checkpoint,
    i.e. the epoch training should resume at.
    """
    state = torch.load(path, map_location=device)
    model.load_state_dict(state['model_state_dict'])
    optimizer.load_state_dict(state['optimizer_state_dict'])
    # Resume from the epoch AFTER the last completed one.
    return model, optimizer, state['epoch'] + 1
78+
79+
6580
def evaluate_autoencoder(model, dataloader, device):
6681
model.eval()
6782
total_loss = 0

0 commit comments

Comments
 (0)