-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpredict.py
More file actions
107 lines (80 loc) · 3.36 KB
/
predict.py
File metadata and controls
107 lines (80 loc) · 3.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import torch
import matplotlib.pyplot as plt
import model as md
import prepare_data as data
# Checkpoint file to evaluate; it is looked up in ./models/.
model_name = 'model_0.9798385305191156.pth'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = md.get_model()
# map_location remaps the stored tensors onto the selected device, so a
# checkpoint saved on a GPU can still be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
model.load_state_dict(torch.load('./models/' + model_name, map_location=device))
model.to(device)
model.eval()  # evaluation mode (affects layers such as dropout / batch-norm)

# Prepare the validation data used for prediction
_, valid_set = data.get_data()
valid_loader = data.get_validloader(valid_set)
def predict(max_samples=100, *, loader=None, net=None, dev=None):
    """Evaluate the model on validation samples and sort them by outcome.

    Each sample is classified by taking the argmax of the model outputs along
    dim 1, then filed as correct or wrong. Counts and accuracy are printed.

    Args:
        max_samples: maximum number of individual samples to evaluate
            (samples, not batches, so loaders with batch size > 1 work).
        loader: iterable of ``(images, labels)`` batches. Defaults to the
            module-level ``valid_loader``.
        net: model called on a batch of images; dim 1 of its output is
            treated as the class axis. Defaults to the module-level ``model``.
        dev: device each batch is moved to. Defaults to the module-level
            ``device``.

    Returns:
        Tuple ``(images_valid_list, predicted_valid_list, labels_valid_list,
        images_error_list, predicted_error_list, labels_error_list)``.
        Images are per-sample CPU tensors; predictions and labels are Python
        scalars obtained with ``.item()``.
    """
    if loader is None:
        loader = valid_loader
    if net is None:
        net = model
    if dev is None:
        dev = device

    images_valid_list, predicted_valid_list, labels_valid_list = [], [], []
    images_error_list, predicted_error_list, labels_error_list = [], [], []
    seen = 0

    with torch.no_grad():
        for images, labels in loader:
            # Stop before running the model once the quota is reached
            # (the old `i > max_samples` check ran one batch too many).
            if seen >= max_samples:
                break
            images, labels = images.to(dev), labels.to(dev)
            outputs = net(images)
            _, predicted = torch.max(outputs, 1)
            # Compare sample by sample: comparing whole batch tensors and
            # calling .item() on them only works for a batch size of 1.
            for img, pred, lab in zip(images, predicted, labels):
                if seen >= max_samples:
                    break
                seen += 1
                pred_value, label_value = pred.item(), lab.item()
                if pred_value != label_value:
                    images_error_list.append(img.cpu())
                    predicted_error_list.append(pred_value)
                    labels_error_list.append(label_value)
                else:
                    images_valid_list.append(img.cpu())
                    predicted_valid_list.append(pred_value)
                    labels_valid_list.append(label_value)

    nb_error = len(images_error_list)
    nb_correct = len(images_valid_list)
    total = nb_correct + nb_error
    # Accuracy is undefined when nothing was evaluated; avoid ZeroDivisionError.
    accuracy = nb_correct / total if total else float("nan")
    print(f"Number of errors: {nb_error}")
    print(f"Number of correct: {nb_correct}")
    print(f"Accuracy: {accuracy:.5f}")
    return (images_valid_list, predicted_valid_list, labels_valid_list,
            images_error_list, predicted_error_list, labels_error_list)
def imshow(img, predicted, label):
    """Show one image in grayscale, titled with its predicted and true label."""
    pixels = img.cpu().numpy()
    pixels = pixels.squeeze()  # drop singleton dims (e.g. a channel of size 1)
    plt.imshow(pixels, cmap='gray')
    caption = f'Predicted: {predicted}, Label: {label}'
    plt.title(caption)
    plt.show()
def _save_image_grid(images, predicted, labels, path, cols):
    """Draw images in a grid titled with prediction/label and save it to path."""
    rows = (len(images) + cols - 1) // cols  # ceiling division: rows needed
    fig = plt.figure(figsize=(15, 3 * rows))
    try:
        for i, (img, pred, lab) in enumerate(zip(images, predicted, labels)):
            ax = fig.add_subplot(rows, cols, i + 1)
            ax.imshow(img.numpy().squeeze(), cmap='gray')
            ax.set_title(f'Pred: {pred}, Label: {lab}')
            ax.axis('off')
        fig.savefig(path)  # save instead of displaying
    finally:
        # Always release the figure; previously the "valid" figure leaked
        # whenever an error figure was also created.
        plt.close(fig)


def plot_images(images, predicted, labels, images_error, predicted_error,
                labels_error, *, cols=5):
    """Save grids of correctly and wrongly predicted images to ./output/.

    Args:
        images, predicted, labels: CPU image tensors, predicted classes and
            true labels of the correctly classified samples.
        images_error, predicted_error, labels_error: same for the
            misclassified samples.
        cols: number of images per grid row.

    Files written: ``./output/predictions_valid_{model_name}.png`` and
    ``./output/predictions_errors_{model_name}.png``. A grid is skipped, with a
    message printed, when its list is empty.
    """
    import os

    # savefig does not create missing directories.
    os.makedirs('./output', exist_ok=True)

    if len(images) > 0:
        _save_image_grid(images, predicted, labels,
                         f'./output/predictions_valid_{model_name}.png', cols)
    else:
        # A zero-row grid would produce a zero-height figure.
        print("No valid images")

    if len(images_error) > 0:
        _save_image_grid(images_error, predicted_error, labels_error,
                         f'./output/predictions_errors_{model_name}.png', cols)
    else:
        print("No images error")
# Evaluate the validation samples, then write both prediction grids to disk.
(images_valid_list, predicted_valid_list, labels_valid_list,
 images_error_list, predicted_error_list, labels_error_list) = predict()

plot_images(
    images=images_valid_list,
    predicted=predicted_valid_list,
    labels=labels_valid_list,
    images_error=images_error_list,
    predicted_error=predicted_error_list,
    labels_error=labels_error_list,
)