Skip to content

Commit b0f546b

Browse files
Update MRnet.py
1 parent ef7ef28 commit b0f546b

File tree

1 file changed

+32
-7
lines changed

1 file changed

+32
-7
lines changed

MRNet-Single-Model/models/MRnet.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,31 @@ class MRnet(nn.Module):
77
"""MRnet uses pretrained resnet50 as a backbone to extract features
88
"""
99

10-
def __init__(self): # add conf file
11-
10+
def __init__(self):
11+
"""This function will be used to initialize the
12+
MRnet instance."""
13+
# Initialize nn.Module instance
1214
super(MRnet,self).__init__()
1315

14-
# init three backbones for three axis
16+
# Initialize three backbones for three axis
17+
# All the three axes will use pretrained AlexNet model
18+
# The models will be used for extracting features from
19+
# the input images
1520
self.axial = models.alexnet(pretrained=True).features
1621
self.coronal = models.alexnet(pretrained=True).features
1722
self.saggital = models.alexnet(pretrained=True).features
18-
23+
24+
# Initialize 2D Adaptive Average Pooling layers
25+
# The pooling layers will reduce the size of
26+
# feature maps extracted from the previous axes
1927
self.pool_axial = nn.AdaptiveAvgPool2d(1)
2028
self.pool_coronal = nn.AdaptiveAvgPool2d(1)
2129
self.pool_saggital = nn.AdaptiveAvgPool2d(1)
22-
30+
31+
# Initialize a sequential neural network with
32+
# a single fully connected linear layer
33+
# The network will output the probability of
34+
# having a particular disease
2335
self.fc = nn.Sequential(
2436
nn.Linear(in_features=3*256,out_features=1)
2537
)
@@ -33,24 +45,37 @@ def forward(self,x):
3345
# squeeze the first dimension as there
3446
# is only one patient in each batch
3547
images = [torch.squeeze(img, dim=0) for img in x]
36-
48+
49+
# Extract features across each of the three plane
50+
# using the three pre-trained AlexNet models defined earlier
3751
image1 = self.axial(images[0])
3852
image2 = self.coronal(images[1])
3953
image3 = self.saggital(images[2])
4054

55+
# Convert the image dimesnsions from [slices, 256, 1, 1] to
56+
# [slices,256]
4157
image1 = self.pool_axial(image1).view(image1.size(0), -1)
4258
image2 = self.pool_coronal(image2).view(image2.size(0), -1)
4359
image3 = self.pool_saggital(image3).view(image3.size(0), -1)
4460

61+
# Find maximum value in each slice
62+
# This will reduce the dimensions of image to [1,256]
63+
# This is done in order to keep only the most prevalent
64+
# features for each slice
4565
image1 = torch.max(image1,dim=0,keepdim=True)[0]
4666
image2 = torch.max(image2,dim=0,keepdim=True)[0]
4767
image3 = torch.max(image3,dim=0,keepdim=True)[0]
4868

69+
# Stack the 3 images together to create the output
70+
# of size [1, 256*3]
4971
output = torch.cat([image1,image2,image3], dim=1)
5072

73+
# Feed the output to the sequential network created earlier
74+
# The network will return a probability of having a specific
75+
# disease
5176
output = self.fc(output)
5277
return output
5378

5479
def _load_wieghts(self):
5580
"""load pretrained weights"""
56-
pass
81+
pass

0 commit comments

Comments
 (0)