
Commit 473c721

Merge pull request #501 from lipi17dpatnaik/patch-1
Add additional comments
Parents: ef7ef28 + 1d7fc9d

3 files changed: +118 additions, -32 deletions

MRNet-Single-Model/dataset/dataset.py

Lines changed: 41 additions & 13 deletions
@@ -19,52 +19,58 @@ class MRData():
 
     def __init__(self,task = 'acl', train = True, transform = None, weights = None):
         """Initialize the dataset
-
         Args:
             plane : along which plane to load the data
             task : for which task to load the labels
             train : whether to load the train or val data
             transform : which transforms to apply
             weights (Tensor) : Give weighted loss to positive class eg. `weights=torch.tensor([2.223])`
         """
+        # Define the three planes to use
         self.planes=['axial', 'coronal', 'sagittal']
+        # Initialize the records as None
         self.records = None
         # an empty dictionary
         self.image_path={}
 
+        # If we are in the training loop
         if train:
+            # Read data about patient records
             self.records = pd.read_csv('./images/train-{}.csv'.format(task),header=None, names=['id', 'label'])
 
-            '''
-            self.image_path[<plane>]= dictionary {<plane>: path to folder containing
-            image for that plane}
-            '''
             for plane in self.planes:
+                # For each plane, specify the image path
                 self.image_path[plane] = './images/train/{}/'.format(plane)
         else:
+            # If we are in the testing loop,
+            # don't use any transformation
             transform = None
+            # Read testing/validation data (patient records)
             self.records = pd.read_csv('./images/valid-{}.csv'.format(task),header=None, names=['id', 'label'])
-            '''
-            self.image_path[<plane>]= dictionary {<plane>: path to folder containing
-            image for that plane}
-            '''
+
             for plane in self.planes:
+                # Read the path of images for each plane
                 self.image_path[plane] = './images/valid/{}/'.format(plane)
 
-
+        # Initialize the transformation to apply on images
         self.transform = transform
 
+        # Pad the patient record id with leading 0s
         self.records['id'] = self.records['id'].map(
             lambda i: '0' * (4 - len(str(i))) + str(i))
         # empty dictionary
-        self.paths={}
+        self.paths={}
         for plane in self.planes:
+            # Get the paths of the numpy data files for each plane
             self.paths[plane] = [self.image_path[plane] + filename +
                                  '.npy' for filename in self.records['id'].tolist()]
 
+        # Convert labels from a Pandas Series to a list
         self.labels = self.records['label'].tolist()
 
+        # Total positive cases
         pos = sum(self.labels)
+        # Total negative cases
         neg = len(self.labels) - pos
 
         # Find the weights of pos and neg classes
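
The tail of this hunk counts the positive and negative cases that are later used to weight the classes, and the docstring suggests passing something like `weights=torch.tensor([2.223])`. As a hedged illustration of how such counts are commonly turned into a positive-class weight for a binary loss (a sketch only, not necessarily what this repository does elsewhere), one could write:

# Hypothetical sketch: derive a positive-class weight from label counts as above.
import torch
import torch.nn as nn

labels = [1, 0, 0, 1, 0, 0, 0, 1]      # stand-in for self.labels
pos = sum(labels)                      # total positive cases
neg = len(labels) - pos                # total negative cases

# Up-weight the rarer positive class, in the spirit of weights=torch.tensor([2.223])
weights = torch.tensor([neg / pos])
criterion = nn.BCEWithLogitsLoss(pos_weight=weights)
print(weights)
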
@@ -90,53 +96,75 @@ def __getitem__(self, index):
         img_raw = {}
 
         for plane in self.planes:
+            # Load the raw image data for each plane
             img_raw[plane] = np.load(self.paths[plane][index])
+            # Resize the image loaded in the previous step
             img_raw[plane] = self._resize_image(img_raw[plane])
 
         label = self.labels[index]
+        # Convert the label to a FloatTensor of 0 or 1
         if label == 1:
             label = torch.FloatTensor([1])
         elif label == 0:
             label = torch.FloatTensor([0])
 
+        # Return a list of three images for the three planes and the label of the record
         return [img_raw[plane] for plane in self.planes], label
 
     def _resize_image(self, image):
         """Resize the image to `(3,224,224)` and apply
         transforms if possible.
         """
         # Resize the image
+        # Calculate the extra padding present in the image
+        # which needs to be removed
         pad = int((image.shape[2] - INPUT_DIM)/2)
+        # This is equivalent to center cropping the image
         image = image[:,pad:-pad,pad:-pad]
+        # Normalize the image: scale the pixel values, then subtract
+        # the mean and divide by the standard deviation
         image = (image-np.min(image))/(np.max(image)-np.min(image))*MAX_PIXEL_VAL
         image = (image - MEAN) / STDDEV
-
+
+        # If the transformation is not None
         if self.transform:
+            # Transform the image based on the specified transformation
             image = self.transform(image)
         else:
+            # Else, just stack the image with itself in order to match the required
+            # dimensions
             image = np.stack((image,)*3, axis=1)
-
+        # Convert the image to a FloatTensor and return it
         image = torch.FloatTensor(image)
         return image
 
 def load_data(task : str):
 
     # Define the Augmentation here only
     augments = Compose([
+        # Convert the image to a Tensor
         transforms.Lambda(lambda x: torch.Tensor(x)),
+        # Randomly rotate the image by an angle
+        # between -25 and 25 degrees
         RandomRotate(25),
+        # Randomly translate the image by up to 11% of
+        # the image height and width
         RandomTranslate([0.11, 0.11]),
+        # Randomly flip the image
         RandomFlip(),
+        # Repeat the grayscale image to 3 channels and reorder to [slices, 3, H, W]
         transforms.Lambda(lambda x: x.repeat(3, 1, 1, 1).permute(1, 0, 2, 3)),
     ])
 
     print('Loading Train Dataset of {} task...'.format(task))
+    # Load the training dataset
     train_data = MRData(task, train=True, transform=augments)
     train_loader = data.DataLoader(
         train_data, batch_size=1, num_workers=11, shuffle=True
     )
 
     print('Loading Validation Dataset of {} task...'.format(task))
+    # Load the validation dataset
     val_data = MRData(task, train=False)
     val_loader = data.DataLoader(
         val_data, batch_size=1, num_workers=11, shuffle=False
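
The dataset diff above fully defines how a single record is produced, so a short usage sketch can illustrate it. The snippet below is illustrative only and not part of the commit: the import path and the number of workers are assumptions, and it presumes the images/ CSVs and per-plane .npy files referenced by the class are present.

# Minimal usage sketch (assumptions noted above): load one validation sample.
from torch.utils import data
from dataset.dataset import MRData   # import path assumed from the file layout

# train=False forces transform=None, so no augmentation is applied
val_data = MRData(task='acl', train=False)
val_loader = data.DataLoader(val_data, batch_size=1, num_workers=4, shuffle=False)

# Each item is a list of three plane tensors plus a FloatTensor label
images, label = next(iter(val_loader))
for name, img in zip(['axial', 'coronal', 'sagittal'], images):
    print(name, img.shape)   # roughly [1, num_slices, 3, 224, 224] after _resize_image
print('label:', label)
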

MRNet-Single-Model/models/MRnet.py

Lines changed: 32 additions & 7 deletions
@@ -7,19 +7,31 @@ class MRnet(nn.Module):
     """MRnet uses a pretrained AlexNet backbone to extract features
     """
 
-    def __init__(self): # add conf file
-
+    def __init__(self):
+        """This function will be used to initialize the
+        MRnet instance."""
+        # Initialize the nn.Module instance
         super(MRnet,self).__init__()
 
-        # init three backbones for three axis
+        # Initialize three backbones for the three axes
+        # All three axes will use a pretrained AlexNet model
+        # The models will be used for extracting features from
+        # the input images
         self.axial = models.alexnet(pretrained=True).features
         self.coronal = models.alexnet(pretrained=True).features
         self.saggital = models.alexnet(pretrained=True).features
-
+
+        # Initialize 2D Adaptive Average Pooling layers
+        # The pooling layers will reduce the size of the
+        # feature maps extracted by the backbones above
         self.pool_axial = nn.AdaptiveAvgPool2d(1)
         self.pool_coronal = nn.AdaptiveAvgPool2d(1)
         self.pool_saggital = nn.AdaptiveAvgPool2d(1)
-
+
+        # Initialize a sequential neural network with
+        # a single fully connected linear layer
+        # The network outputs a single score for the disease,
+        # which is later turned into a probability with a sigmoid
         self.fc = nn.Sequential(
             nn.Linear(in_features=3*256,out_features=1)
         )
@@ -33,24 +45,37 @@ def forward(self,x):
         # squeeze the first dimension as there
         # is only one patient in each batch
         images = [torch.squeeze(img, dim=0) for img in x]
-
+
+        # Extract features from each of the three planes
+        # using the three pre-trained AlexNet models defined earlier
         image1 = self.axial(images[0])
         image2 = self.coronal(images[1])
         image3 = self.saggital(images[2])
 
+        # Convert the image dimensions from [slices, 256, 1, 1] to
+        # [slices, 256]
         image1 = self.pool_axial(image1).view(image1.size(0), -1)
         image2 = self.pool_coronal(image2).view(image2.size(0), -1)
         image3 = self.pool_saggital(image3).view(image3.size(0), -1)
 
+        # Find the maximum value of each feature across slices
+        # This reduces the dimensions of each image to [1, 256]
+        # and keeps only the most prevalent feature across the
+        # slices of each plane
         image1 = torch.max(image1,dim=0,keepdim=True)[0]
         image2 = torch.max(image2,dim=0,keepdim=True)[0]
         image3 = torch.max(image3,dim=0,keepdim=True)[0]
 
+        # Concatenate the three plane features to create the output
+        # of size [1, 256*3]
         output = torch.cat([image1,image2,image3], dim=1)
 
+        # Feed the output to the sequential network created earlier
+        # The network returns a single score for the disease, which is
+        # converted to a probability later with a sigmoid
         output = self.fc(output)
         return output
 
     def _load_wieghts(self):
         """load pretrained weights"""
-        pass
+        pass
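
Since the forward pass above is mostly about tensor shapes, the following standalone sketch (an illustration, not repository code) walks one plane through the same steps with random data. The 224x224 input size comes from the dataset code; the slice count is arbitrary.

# Shape walk-through of one plane, mirroring MRnet.forward with dummy data.
import torch
import torch.nn as nn
from torchvision import models

slices = 24                                        # arbitrary number of MRI slices
x = torch.randn(1, slices, 3, 224, 224)            # one patient, one plane

backbone = models.alexnet(pretrained=True).features
pool = nn.AdaptiveAvgPool2d(1)

feats = backbone(torch.squeeze(x, dim=0))          # -> [slices, 256, 6, 6]
feats = pool(feats).view(feats.size(0), -1)        # -> [slices, 256]
feats = torch.max(feats, dim=0, keepdim=True)[0]   # -> [1, 256], max over slices
print(feats.shape)
# The full model concatenates the three planes ([1, 3*256]) and applies nn.Linear(3*256, 1).
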

MRNet-Single-Model/utils/utils.py

Lines changed: 45 additions & 12 deletions
@@ -15,38 +15,49 @@ def _evaluate_model(model, val_loader, criterion, epoch, num_epochs, writer, cur
 
     # Set to eval mode
     model.eval()
-
+    # List of probabilities obtained from the model
     y_probs = []
+    # List of groundtruth labels
     y_gt = []
+    # List of losses obtained
     losses = []
 
+    # Iterate over the validation dataset
     for i, (images, label) in enumerate(val_loader):
-
+        # If a GPU is available, load the images and label
+        # on the GPU
         if torch.cuda.is_available():
             images = [image.cuda() for image in images]
             label = label.cuda()
 
+        # Obtain the model output by passing the images as input
         output = model(images)
-
+        # Evaluate the loss by comparing the output and the groundtruth label
         loss = criterion(output, label)
-
+        # Add the loss to the list of losses
         loss_value = loss.item()
         losses.append(loss_value)
-
+        # Find the probability of the positive class by applying
+        # the sigmoid function to the model output
         probas = torch.sigmoid(output)
-
+        # Add the groundtruth to the list of groundtruths
         y_gt.append(int(label.item()))
+        # Add the predicted probability to the list
         y_probs.append(probas.item())
 
         try:
+            # Evaluate the area under the ROC curve based on the groundtruth labels
+            # and predicted probabilities
             auc = metrics.roc_auc_score(y_gt, y_probs)
         except:
+            # Default area under the ROC curve
             auc = 0.5
-
+        # Log the validation loss and area under the ROC curve to the writer
         writer.add_scalar('Val/Loss', loss_value, epoch * len(val_loader) + i)
         writer.add_scalar('Val/AUC', auc, epoch * len(val_loader) + i)
 
         if (i % log_every == 0) & (i > 0):
+            # Display the average validation loss and area under the ROC curve
             print('''[Epoch: {0} / {1} | Batch : {2} / {3} ]| Avg Val Loss {4} | Val AUC : {5} | lr : {6}'''.
                 format(
                     epoch + 1,
@@ -58,9 +69,9 @@ def _evaluate_model(model, val_loader, criterion, epoch, num_epochs, writer, cur
                     current_lr
                 )
             )
-
+    # Log the epoch-level area under the ROC curve to the writer
     writer.add_scalar('Val/AUC_epoch', auc, epoch + i)
-
+    # Find the mean validation loss and round the area under the ROC curve
     val_loss_epoch = np.round(np.mean(losses), 4)
     val_auc_epoch = np.round(auc, 4)
 
@@ -71,41 +82,62 @@ def _train_model(model, train_loader, epoch, num_epochs, optimizer, criterion, w
     # Set to train mode
     model.train()
 
+    # Initialize the list of predicted probabilities
     y_probs = []
+    # Initialize the list of groundtruth labels
     y_gt = []
+    # Initialize the list of losses between the predictions
+    # and the groundtruth labels
     losses = []
 
+    # Iterate over the training dataset
     for i, (images, label) in enumerate(train_loader):
+        # Reset the gradients by zeroing them
         optimizer.zero_grad()
-
+
+        # If a GPU is available, transfer the images and label
+        # to the GPU
         if torch.cuda.is_available():
             images = [image.cuda() for image in images]
             label = label.cuda()
 
+        # Obtain the prediction using the model
         output = model(images)
 
+        # Evaluate the loss by comparing the prediction
+        # and the groundtruth label
         loss = criterion(output, label)
+        # Perform backward propagation
         loss.backward()
+        # Update the weights based on the error gradient
         optimizer.step()
 
+        # Add the current loss to the list of losses
         loss_value = loss.item()
         losses.append(loss_value)
 
+        # Find the probability from the output using the sigmoid function
         probas = torch.sigmoid(output)
 
+        # Add the current groundtruth label to the list of groundtruths
         y_gt.append(int(label.item()))
+        # Add the current probability to the list of probabilities
         y_probs.append(probas.item())
 
         try:
+            # Try finding the area under the ROC curve
             auc = metrics.roc_auc_score(y_gt, y_probs)
         except:
+            # Use the default value of 0.5 for the area under the ROC curve
             auc = 0.5
-
+
+        # Log the training loss and area under the ROC curve to the writer
         writer.add_scalar('Train/Loss', loss_value,
                           epoch * len(train_loader) + i)
         writer.add_scalar('Train/AUC', auc, epoch * len(train_loader) + i)
 
         if (i % log_every == 0) & (i > 0):
+            # Display the average training loss and area under the ROC curve
             print('''[Epoch: {0} / {1} | Batch : {2} / {3} ]| Avg Train Loss {4} | Train AUC : {5} | lr : {6}'''.
                 format(
                     epoch + 1,
@@ -117,9 +149,10 @@ def _train_model(model, train_loader, epoch, num_epochs, optimizer, criterion, w
                     current_lr
                 )
             )
-
+    # Log the epoch-level area under the ROC curve to the writer
     writer.add_scalar('Train/AUC_epoch', auc, epoch + i)
 
+    # Find the mean training loss and round the area under the ROC curve
     train_loss_epoch = np.round(np.mean(losses), 4)
    train_auc_epoch = np.round(auc, 4)
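
Both helpers above follow the same bookkeeping pattern: apply a sigmoid to the raw model output, accumulate labels and probabilities, and fall back to an AUC of 0.5 when roc_auc_score cannot be computed (for example, when only one class has been seen so far). The toy sketch below reproduces just that pattern with made-up values; it is an illustration, not code from the repository.

# Toy reproduction of the probability/AUC bookkeeping in _train_model and _evaluate_model.
import numpy as np
import torch
from sklearn import metrics

y_gt, y_probs, losses = [], [], []

# Pretend three batches were seen: (raw model output, groundtruth label, loss value)
for logit, label, loss_value in [(1.2, 1, 0.31), (-0.4, 0, 0.55), (0.9, 1, 0.20)]:
    probas = torch.sigmoid(torch.tensor([logit]))   # raw score -> probability
    y_gt.append(int(label))
    y_probs.append(probas.item())
    losses.append(loss_value)

try:
    # AUC over everything accumulated so far in the epoch
    auc = metrics.roc_auc_score(y_gt, y_probs)
except ValueError:
    # Undefined AUC (e.g. only one class so far): use the 0.5 default
    auc = 0.5

print('mean loss:', np.round(np.mean(losses), 4), '| AUC:', np.round(auc, 4))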
