
Commit 8fb66eb

feat: upgrade process and adapted utils
1 parent a3cc936 commit 8fb66eb

File tree

json/params.json
run.py
utils/dataloader.py
utils/trainer.py
utils/utils.py

5 files changed: +100 -59 lines changed

json/params.json

Lines changed: 4 additions & 3 deletions
@@ -1,8 +1,9 @@
 {
     "batch_size": 32,
     "resolution": 64,
-    "encoding_dim": 128,
-    "num_epochs": 500,
+    "encoding_dim": 16,
+    "num_epochs": 1000,
     "learning_rate": 0.001,
-    "ae_type": "conv_vae"
+    "ae_type": "ae",
+    "save_checkpoint": null
 }
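
How run.py consumes this file, as a minimal sketch (the load_params body is elided in the diff below, so the json.load implementation here is an assumption):

import json

def load_params(path):
    # Assumed body: parse the JSON hyperparameter file into a dict.
    with open(path) as f:
        return json.load(f)

params = load_params("json/params.json")
save_checkpoint = params["save_checkpoint"]  # JSON null -> Python None
# A falsy flag makes the trainer skip per-epoch checkpointing
# (see the `if save_checkpoint:` guard in utils/trainer.py below).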

run.py

Lines changed: 43 additions & 18 deletions
@@ -1,12 +1,25 @@
 import os
 import json
+import random
+import time
 
+import numpy as np
 import torch
 
 from models import Autoencoder, ConvolutionalAutoencoder, ConvolutionalVAE, VariationalAutoencoder
+from settings import settings
 from utils.dataloader import get_dataloader
 from utils.trainer import train_autoencoder, visualize_reconstructions, load_checkpoint, evaluate_autoencoder
-from settings import settings
+from utils import utils
+
+
+def set_seed(seed=42):
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
 
 
 def load_params(path):
@@ -16,6 +29,7 @@ def load_params(path):
 
 
 def main(load_trained_model):
+    set_seed(1)
     params = load_params(settings.PATH_PARAMS_JSON)
 
     batch_size = params["batch_size"]
@@ -24,12 +38,13 @@ def main(load_trained_model):
     num_epochs = params["num_epochs"]
     learning_rate = params.get("learning_rate", 0.001)
     ae_type = params["ae_type"]
+    save_checkpoint = params["save_checkpoint"]
 
     # Calculate input_dim based on resolution
     input_dim = 3 * resolution * resolution
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    dataloader = get_dataloader(settings.DATA_PATH, batch_size)
+    dataloader = get_dataloader(settings.DATA_PATH, batch_size, resolution)
 
     if ae_type == 'ae':
        model = Autoencoder(input_dim, encoding_dim).to(device)
@@ -51,25 +66,35 @@ def main(load_trained_model):
         )
         print(f"Loaded checkpoint and continuing training from epoch {start_epoch}.")
 
-    if not load_trained_model:
-        train_autoencoder(
-            model,
-            dataloader,
-            num_epochs=num_epochs,
-            learning_rate=learning_rate,
-            device=device,
-            start_epoch=start_epoch,
-            optimizer=optimizer,
-            ae_type=ae_type
-        )
-        print(f"Training complete up to epoch {num_epochs}!")
-
-    valid_dataloader = get_dataloader(settings.VALID_DATA_PATH, batch_size)
+    try:
+        if not load_trained_model:
+            start_time = time.time()
+
+            train_autoencoder(
+                model,
+                dataloader,
+                num_epochs=num_epochs,
+                learning_rate=learning_rate,
+                device=device,
+                start_epoch=start_epoch,
+                optimizer=optimizer,
+                ae_type=ae_type,
+                save_checkpoint=save_checkpoint
+            )
+
+            elapsed_time = utils.format_time(time.time() - start_time)
+            print(f"\nTraining took {elapsed_time}")
+            print(f"Training complete up to epoch {num_epochs}!")
+
+    except KeyboardInterrupt:
+        print("\nTraining interrupted by user.")
+
+    valid_dataloader = get_dataloader(settings.VALID_DATA_PATH, batch_size, resolution)
     avg_valid_loss = evaluate_autoencoder(model, valid_dataloader, device, ae_type)
-    print(f"Average validation loss: {avg_valid_loss:.4f}")
+    print(f"\nAverage validation loss: {avg_valid_loss:.4f}\n")
 
     visualize_reconstructions(
-        model, valid_dataloader, num_samples=10, device=device, ae_type=ae_type
+        model, valid_dataloader, num_samples=10, device=device, ae_type=ae_type, resolution=resolution
     )
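
A quick, self-contained way to verify what the new set_seed buys (set_seed is copied verbatim from the hunk above; the check itself is illustrative):

import random

import numpy as np
import torch

def set_seed(seed=42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(1)
a = torch.randn(3)
set_seed(1)
b = torch.randn(3)
assert torch.equal(a, b)  # reseeding reproduces the exact same draws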

utils/dataloader.py

Lines changed: 4 additions & 6 deletions
@@ -1,13 +1,11 @@
 import os
-import torch
-from torchvision import datasets, transforms
 from torchvision.transforms import ToTensor, Resize, Compose
 from torch.utils.data import DataLoader, Dataset
 from PIL import Image
 
 
-def get_dataloader(data_path, batch_size):
-    dataset = CustomDataset(data_path)
+def get_dataloader(data_path, batch_size, resolution):
+    dataset = CustomDataset(data_path, resolution)
 
     dataloader = DataLoader(
         dataset,
@@ -19,12 +17,12 @@ def get_dataloader(data_path, batch_size):
 
 
 class CustomDataset(Dataset):
-    def __init__(self, data_path):
+    def __init__(self, data_path, resolution):
         self.data_path = data_path
         self.image_files = os.listdir(data_path)
 
         self.transforms = Compose([
-            Resize((64, 64)),
+            Resize((resolution, resolution)),
             ToTensor()
         ])
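
The hunks above omit CustomDataset.__len__ and __getitem__. A plausible completion, consistent with the PIL and transforms imports (both methods are assumptions, not shown in this commit):

import os

from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor, Resize, Compose


class CustomDataset(Dataset):
    def __init__(self, data_path, resolution):
        self.data_path = data_path
        self.image_files = os.listdir(data_path)
        self.transforms = Compose([
            Resize((resolution, resolution)),
            ToTensor()
        ])

    def __len__(self):
        # One sample per file in the data directory (assumed).
        return len(self.image_files)

    def __getitem__(self, idx):
        # Assumed: force RGB so every sample is 3-channel, then
        # resize + scale to a [3, resolution, resolution] tensor in [0, 1].
        path = os.path.join(self.data_path, self.image_files[idx])
        image = Image.open(path).convert("RGB")
        return self.transforms(image)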

utils/trainer.py

Lines changed: 36 additions & 32 deletions
@@ -7,7 +7,37 @@
 import matplotlib.pyplot as plt
 
 
-def train_autoencoder(model, dataloader, num_epochs, learning_rate, device, start_epoch, optimizer, ae_type):
+def save_model(model, path):
+    torch.save(model.state_dict(), path)
+
+
+def load_model(model, path, device):
+    model.load_state_dict(torch.load(path, map_location=device))
+    model.eval()
+    return model
+
+
+def save_checkpoint_file(model, optimizer, epoch, path):
+    checkpoint = {
+        'epoch': epoch,
+        'model_state_dict': model.state_dict(),
+        'optimizer_state_dict': optimizer.state_dict(),
+    }
+    torch.save(checkpoint, path)
+
+
+def load_checkpoint(model, optimizer, path, device):
+    checkpoint = torch.load(path, map_location=device)
+    model.load_state_dict(checkpoint['model_state_dict'])
+    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+    epoch = checkpoint['epoch']
+    return model, optimizer, epoch + 1
+
+
+def train_autoencoder(
+    model, dataloader, num_epochs, learning_rate, device,
+    start_epoch, optimizer, ae_type, save_checkpoint
+):
     criterion = nn.MSELoss()
     if optimizer is None:
         optimizer = optim.Adam(model.parameters(), lr=learning_rate)
@@ -31,7 +61,8 @@ def train_autoencoder(model, dataloader, num_epochs, learning_rate, device, star
             optimizer.step()
 
         print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')
-        save_checkpoint(model, optimizer, epoch, './autoencoder_checkpoint.pth')
+        if save_checkpoint:
+            save_checkpoint_file(model, optimizer, epoch, './autoencoder_checkpoint.pth')
 
     return model
@@ -63,7 +94,7 @@ def evaluate_autoencoder(model, dataloader, device, ae_type):
     return total_loss / len(dataloader)
 
 
-def visualize_reconstructions(model, dataloader, num_samples=10, device='cpu', save_path="./samples", ae_type='ae'):
+def visualize_reconstructions(model, dataloader, num_samples=10, device='cpu', save_path="./samples", ae_type='ae', resolution=64):
     model.eval()
     samples = next(iter(dataloader))
     samples = samples[:num_samples].to(device)
@@ -76,8 +107,8 @@ def visualize_reconstructions(model, dataloader, num_samples=10, device='cpu', s
     else:
         reconstructions = model(samples)
 
-    samples = samples.view(-1, 3, 64, 64)
-    reconstructions = reconstructions.view(-1, 3, 64, 64)
+    samples = samples.view(-1, 3, resolution, resolution)
+    reconstructions = reconstructions.view(-1, 3, resolution, resolution)
 
     combined = torch.cat([samples, reconstructions], dim=0)
     grid_img = make_grid(combined, nrow=num_samples)
@@ -89,30 +120,3 @@ def visualize_reconstructions(model, dataloader, num_samples=10, device='cpu', s
     if not os.path.exists(save_path):
         os.makedirs(save_path)
     save_image(grid_img, os.path.join(save_path, 'combined_samples.png'))
-
-
-def save_model(model, path):
-    torch.save(model.state_dict(), path)
-
-
-def load_model(model, path, device):
-    model.load_state_dict(torch.load(path, map_location=device))
-    model.eval()
-    return model
-
-
-def save_checkpoint(model, optimizer, epoch, path):
-    checkpoint = {
-        'epoch': epoch,
-        'model_state_dict': model.state_dict(),
-        'optimizer_state_dict': optimizer.state_dict(),
-    }
-    torch.save(checkpoint, path)
-
-
-def load_checkpoint(model, optimizer, path, device):
-    checkpoint = torch.load(path, map_location=device)
-    model.load_state_dict(checkpoint['model_state_dict'])
-    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
-    epoch = checkpoint['epoch']
-    return model, optimizer, epoch + 1
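
Note the rename: the old module-level save_checkpoint function becomes save_checkpoint_file, freeing the name save_checkpoint for the boolean flag read from params.json. A round-trip sketch of the relocated helpers (the nn.Linear stand-in and the epoch value are illustrative, not from the commit):

import torch
import torch.nn as nn
import torch.optim as optim

from utils.trainer import save_checkpoint_file, load_checkpoint

model = nn.Linear(8, 8)  # stand-in for one of the autoencoder models
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Persist epoch, weights, and optimizer state in a single file.
save_checkpoint_file(model, optimizer, epoch=41, path="./autoencoder_checkpoint.pth")

# Restore; load_checkpoint returns epoch + 1 so training resumes at 42.
model, optimizer, start_epoch = load_checkpoint(
    model, optimizer, "./autoencoder_checkpoint.pth", torch.device("cpu")
)
assert start_epoch == 42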

utils/utils.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+def format_time(elapsed_time):
+    if elapsed_time < 60:
+        return f"{elapsed_time:.2f} seconds"
+    elif elapsed_time < 3600:
+        minutes = elapsed_time // 60
+        seconds = elapsed_time % 60
+        return f"{minutes:.0f} minutes, {seconds:.2f} seconds"
+    else:
+        hours = elapsed_time // 3600
+        remainder = elapsed_time % 3600
+        minutes = remainder // 60
+        seconds = remainder % 60
+        return f"{hours:.0f} hours, {minutes:.0f} minutes, {seconds:.2f} seconds"
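
Since format_time is new in this commit, a few example calls covering its three branches (inputs chosen arbitrarily; outputs follow from the code above). The "1 minutes" singular/plural mismatch at exactly one unit is a quirk of the fixed suffix strings.

from utils.utils import format_time

print(format_time(42.3))    # 42.30 seconds
print(format_time(75.5))    # 1 minutes, 15.50 seconds
print(format_time(3725.0))  # 1 hours, 2 minutes, 5.00 seconds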
