@@ -2,17 +2,16 @@
 import torch
 import torch.optim as optim
 import torch.nn as nn
-from torchvision import transforms
 from torchvision.utils import save_image, make_grid
 import matplotlib.pyplot as plt
-from PIL import Image
 
 
-def train_autoencoder(model, dataloader, num_epochs=5, learning_rate=0.001, device='cpu'):
+def train_autoencoder(model, dataloader, num_epochs=5, learning_rate=0.001, device='cpu', start_epoch=0, optimizer=None):
     criterion = nn.MSELoss()
-    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
+    if optimizer is None:
+        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
 
-    for epoch in range(num_epochs):
+    for epoch in range(start_epoch, num_epochs):
         for data in dataloader:
             img = data.to(device)
             img = img.view(img.size(0), -1)
@@ -24,6 +23,7 @@ def train_autoencoder(model, dataloader, num_epochs=5, learning_rate=0.001, devi
             optimizer.step()
 
         print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')
+        save_checkpoint(model, optimizer, epoch, './autoencoder_checkpoint.pth')
 
     return model
 
@@ -38,11 +38,9 @@ def visualize_reconstructions(model, dataloader, num_samples=10, device='cpu', s
     samples = samples.view(-1, 3, 64, 64)
     reconstructions = reconstructions.view(-1, 3, 64, 64)
 
-    # Combine the samples and reconstructions into a single grid
     combined = torch.cat([samples, reconstructions], dim=0)
     grid_img = make_grid(combined, nrow=num_samples)
 
-    # Visualization using Matplotlib
     plt.imshow(grid_img.permute(1, 2, 0).cpu().detach().numpy())
     plt.axis('off')
     plt.show()
@@ -62,6 +60,23 @@ def load_model(model, path, device):
     return model
 
 
+def save_checkpoint(model, optimizer, epoch, path):
+    checkpoint = {
+        'epoch': epoch,
+        'model_state_dict': model.state_dict(),
+        'optimizer_state_dict': optimizer.state_dict(),
+    }
+    torch.save(checkpoint, path)
+
+
+def load_checkpoint(model, optimizer, path, device):
+    checkpoint = torch.load(path, map_location=device)
+    model.load_state_dict(checkpoint['model_state_dict'])
+    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+    epoch = checkpoint['epoch']
+    return model, optimizer, epoch + 1
+
+
 def evaluate_autoencoder(model, dataloader, device):
     model.eval()
     total_loss = 0
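
For context, a minimal sketch of how the new checkpoint API is meant to be used to resume an interrupted run. `Autoencoder` and `dataloader` are placeholders here, not names from this commit; the actual model class and data pipeline live elsewhere in the repository.

# Hypothetical usage sketch: resume training from a saved checkpoint.
# `Autoencoder` and `dataloader` are assumed to be defined elsewhere.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = Autoencoder().to(device)                      # assumed model class
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Restore weights, optimizer state, and the epoch to resume from.
model, optimizer, start_epoch = load_checkpoint(
    model, optimizer, './autoencoder_checkpoint.pth', device
)

# Continue training from where the checkpoint left off; each epoch is
# re-saved by train_autoencoder via save_checkpoint.
model = train_autoencoder(
    model, dataloader,
    num_epochs=10,
    device=device,
    start_epoch=start_epoch,
    optimizer=optimizer,
)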