-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patheval.py
More file actions
123 lines (92 loc) · 3.62 KB
/
eval.py
File metadata and controls
123 lines (92 loc) · 3.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from ImageDataset import *
from torch import nn
from torchmetrics import StructuralSimilarityIndexMeasure
from main import UNet
import csv
import torch
# +++++ ===== +++++ ===== +++++ ===== +++++ =====
# Script configuration, read by main() and by the graph() call at the bottom of the file.
#test_path = "dataset/test"
# Dataset location handed to load_dataset() in main().
eval_path = "dataset/val/"
# CSV file main() writes its per-model results to; also the file passed to graph().
write_file_path = "examples/write_file.csv"
# Batch size handed to load_dataset(); main() also reuses it as its progress-print interval.
batch_size = 64
# Number of checkpoints main() iterates over
# (output/model_epoch_1.pth ... output/model_epoch_<num_models>.pth).
num_models = 8
# Batch-count limit used by main()'s evaluation loop and shown in its progress output.
eval_size = 64
# +++++ ===== +++++ ===== +++++ ===== +++++ =====
def main():
    """Evaluate every saved checkpoint on the validation data and log the results.

    For each ``j`` in ``1..num_models`` this loads ``output/model_epoch_{j}.pth``
    into a fresh ``UNet``, feeds at most ``eval_size`` batches from
    ``eval_path`` through it under ``torch.no_grad()``, and computes the mean
    per-batch MSE between the model output and channels ``1:`` of the batch's
    images.  One ``[model name, average loss]`` row per model is then written
    to ``write_file_path`` as CSV.

    The average is taken over the number of batches actually evaluated, so it
    stays correct when the loader yields fewer than ``eval_size`` batches.  If
    the loader yields no batches at all, the average is reported as NaN.

    Raises:
        FileNotFoundError: If a checkpoint file is missing (re-raised after
            printing which file could not be found).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    criterion = nn.MSELoss()
    eval_loader = load_dataset(eval_path, batch_size)

    results = []
    for j in range(num_models):
        model_name = f'model_epoch_{j+1}'
        model = UNet()
        print(f'loading model: {model_name}')
        try:
            # NOTE(review): weights_only=False lets torch.load unpickle
            # arbitrary objects. Only load checkpoints from a trusted source;
            # consider weights_only=True if the files hold a plain state dict.
            weights = torch.load(
                f'output/{model_name}.pth',
                map_location=device,
                weights_only=False,
            )
        except FileNotFoundError:
            print(f"Model: output/{model_name}.pth does not exist.")
            raise

        model.load_state_dict(weights)
        model.to(device)
        model.eval()

        # Release cached GPU memory left over from the previous model.
        torch.cuda.empty_cache()

        total_loss = 0.0
        batches_seen = 0
        with torch.no_grad():
            for i, batch in enumerate(eval_loader):
                # Stop *before* processing so that exactly eval_size batches
                # (at most) contribute to the total.
                if i >= eval_size:
                    break

                images = batch['images'].to(device)
                hints = batch['hints'].to(device)

                # Channel 0 (kept as a size-1 channel axis by slicing) is the
                # model input alongside the hints; the remaining channels are
                # the regression target. Slicing instead of reshaping avoids
                # hard-coding the spatial size.
                output = model(images[:, :1], hints)
                loss = criterion(output, images[:, 1:])

                total_loss += loss.item()
                batches_seen += 1

                # NOTE: the label reads "Total loss" but the value printed is
                # this batch's loss; the text is kept for log compatibility.
                if i % batch_size == 0:
                    print(f'Batch [{i} / {eval_size}], Total loss: {loss.item():.4f}')

        # Divide by the number of batches actually evaluated, not by the
        # configured limit, and avoid 0/0 on an empty loader.
        average_loss = total_loss / batches_seen if batches_seen else float('nan')
        print(f'Average loss: {model_name}: {average_loss}')
        results.append([model_name, average_loss])

    # Write one row per model to the CSV file.
    with open(write_file_path, 'w', newline='') as csvfile:
        csv.writer(csvfile).writerows(results)
def graph(read_file_path, save_path="examples/eval_fig.png"):
    """Plot the per-model losses stored in a CSV file and save the figure.

    Every row with at least two columns contributes one point: column 0 is
    used as the x label and column 1, parsed as a float, as the y value.
    Rows that are too short or whose second column is not a number are
    reported on stdout and skipped.  Problems opening or reading the file are
    reported on stdout as well; the figure is still produced from whatever
    rows were read up to that point (possibly none).

    Args:
        read_file_path: Path of the CSV file to read.
        save_path: Destination of the saved figure.  Defaults to
            ``"examples/eval_fig.png"``.
    """
    x_data = []
    y_data = []
    try:
        with open(read_file_path, 'r', newline='') as csvfile:
            for row in csv.reader(csvfile):
                if len(row) < 2:
                    print(f"Skipping row: {row}")
                    continue
                # Parse the value first so a bad row adds nothing to either
                # list; otherwise x_data and y_data end up with different
                # lengths and plotting fails.
                try:
                    value = float(row[1])
                except ValueError:
                    print(f"Value doesn't exist, skipping row {row}")
                    continue
                x_data.append(row[0])
                y_data.append(value)
    except FileNotFoundError:
        print(f'File not found: {read_file_path}')
    except Exception as e:
        # Best-effort: report any other read/parse failure and carry on.
        print(f'An error occured: {e}')

    import matplotlib.pyplot as plt

    # Use an explicit figure and close it afterwards so repeated calls do not
    # draw on top of each other or leak open figures.
    fig, ax = plt.subplots()
    try:
        ax.plot(x_data, y_data, 'r')
        ax.plot(x_data, y_data, 'o')
        ax.set_xlabel('Model epoch')
        ax.set_ylabel('MSE error')
        ax.set_title('Error vs. Epoch')
        ax.set_ylim(0, 0.5)
        fig.savefig(save_path)
    finally:
        plt.close(fig)
# Run the evaluation and plot its results only when this file is executed as
# a script, so that importing the module does not start an evaluation run.
if __name__ == "__main__":
    main()
    graph(write_file_path)