Skip to content

Commit 1d7fc9d

Browse files
Update utils.py
1 parent 3e07dd9 commit 1d7fc9d

File tree

1 file changed

+45
-12
lines changed

1 file changed

+45
-12
lines changed

MRNet-Single-Model/utils/utils.py

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,38 +15,49 @@ def _evaluate_model(model, val_loader, criterion, epoch, num_epochs, writer, cur
1515

1616
# Set to eval mode
1717
model.eval()
18-
18+
# List of probabilities obtained from the model
1919
y_probs = []
20+
# List of groundtruth labels
2021
y_gt = []
22+
# List of losses obtained
2123
losses = []
2224

25+
# Iterate over the validation dataset
2326
for i, (images, label) in enumerate(val_loader):
24-
27+
# If GPU is available, load the images and label
28+
# on GPU
2529
if torch.cuda.is_available():
2630
images = [image.cuda() for image in images]
2731
label = label.cuda()
2832

33+
# Obtain the model output by passing the images as input
2934
output = model(images)
30-
35+
# Evaluate the loss by comparing the output and groundtruth label
3136
loss = criterion(output, label)
32-
37+
# Add loss to the list of losses
3338
loss_value = loss.item()
3439
losses.append(loss_value)
35-
40+
# Find probability for each class by applying
41+
# sigmoid function on model output
3642
probas = torch.sigmoid(output)
37-
43+
# Add the groundtruth to the list of groundtruths
3844
y_gt.append(int(label.item()))
45+
# Add predicted probability to the list
3946
y_probs.append(probas.item())
4047

4148
try:
49+
# Evaluate area under ROC curve based on the groundtruth label
50+
# and predicted probability
4251
auc = metrics.roc_auc_score(y_gt, y_probs)
4352
except:
53+
# Default area under ROC curve
4454
auc = 0.5
45-
55+
# Add information to the writer about validation loss and Area under ROC curve
4656
writer.add_scalar('Val/Loss', loss_value, epoch * len(val_loader) + i)
4757
writer.add_scalar('Val/AUC', auc, epoch * len(val_loader) + i)
4858

4959
if (i % log_every == 0) & (i > 0):
60+
# Display the information about average validation loss and area under ROC curve
5061
print('''[Epoch: {0} / {1} | Batch : {2} / {3} ]| Avg Val Loss {4} | Val AUC : {5} | lr : {6}'''.
5162
format(
5263
epoch + 1,
@@ -58,9 +69,9 @@ def _evaluate_model(model, val_loader, criterion, epoch, num_epochs, writer, cur
5869
current_lr
5970
)
6071
)
61-
72+
# Add information to the writer about total epochs and Area under ROC curve
6273
writer.add_scalar('Val/AUC_epoch', auc, epoch + i)
63-
74+
# Find mean area under ROC curve and validation loss
6475
val_loss_epoch = np.round(np.mean(losses), 4)
6576
val_auc_epoch = np.round(auc, 4)
6677

@@ -71,41 +82,62 @@ def _train_model(model, train_loader, epoch, num_epochs, optimizer, criterion, w
7182
# Set to train mode
7283
model.train()
7384

85+
# Initialize the predicted probabilities
7486
y_probs = []
87+
# Initialize the groundtruth labels
7588
y_gt = []
89+
# Initialize the loss between the groundtruth label
90+
# and the predicted probability
7691
losses = []
7792

93+
# Iterate over the training dataset
7894
for i, (images, label) in enumerate(train_loader):
95+
# Reset the gradient by zeroing it
7996
optimizer.zero_grad()
80-
97+
98+
# If GPU is available, transfer the images and label
99+
# to the GPU
81100
if torch.cuda.is_available():
82101
images = [image.cuda() for image in images]
83102
label = label.cuda()
84103

104+
# Obtain the prediction using the model
85105
output = model(images)
86106

107+
# Evaluate the loss by comparing the prediction
108+
# and groundtruth label
87109
loss = criterion(output, label)
110+
# Perform a backward propagation
88111
loss.backward()
112+
# Modify the weights based on the error gradient
89113
optimizer.step()
90114

115+
# Add current loss to the list of losses
91116
loss_value = loss.item()
92117
losses.append(loss_value)
93118

119+
# Find probabilities from output using sigmoid function
94120
probas = torch.sigmoid(output)
95121

122+
# Add current groundtruth label to the list of groundtruths
96123
y_gt.append(int(label.item()))
124+
# Add current probabilities to the list of probabilities
97125
y_probs.append(probas.item())
98126

99127
try:
128+
# Try finding the area under ROC curve
100129
auc = metrics.roc_auc_score(y_gt, y_probs)
101130
except:
131+
# Use default value of area under ROC curve as 0.5
102132
auc = 0.5
103-
133+
134+
# Add information to the writer about training loss and Area under ROC curve
104135
writer.add_scalar('Train/Loss', loss_value,
105136
epoch * len(train_loader) + i)
106137
writer.add_scalar('Train/AUC', auc, epoch * len(train_loader) + i)
107138

108139
if (i % log_every == 0) & (i > 0):
140+
# Display the information about average training loss and area under ROC curve
109141
print('''[Epoch: {0} / {1} | Batch : {2} / {3} ]| Avg Train Loss {4} | Train AUC : {5} | lr : {6}'''.
110142
format(
111143
epoch + 1,
@@ -117,9 +149,10 @@ def _train_model(model, train_loader, epoch, num_epochs, optimizer, criterion, w
117149
current_lr
118150
)
119151
)
120-
152+
# Add information to the writer about total epochs and Area under ROC curve
121153
writer.add_scalar('Train/AUC_epoch', auc, epoch + i)
122154

155+
# Find mean area under ROC curve and training loss
123156
train_loss_epoch = np.round(np.mean(losses), 4)
124157
train_auc_epoch = np.round(auc, 4)
125158

0 commit comments

Comments
 (0)