@@ -7,19 +7,31 @@ class MRnet(nn.Module):
7
7
"""MRnet uses pretrained resnet50 as a backbone to extract features
8
8
"""
9
9
10
- def __init__ (self ): # add conf file
11
-
10
+ def __init__ (self ):
11
+ """This function will be used to initialize the
12
+ MRnet instance."""
13
+ # Initialize nn.Module instance
12
14
super (MRnet ,self ).__init__ ()
13
15
14
- # init three backbones for three axis
16
+ # Initialize three backbones for three axis
17
+ # All the three axes will use pretrained AlexNet model
18
+ # The models will be used for extracting features from
19
+ # the input images
15
20
self .axial = models .alexnet (pretrained = True ).features
16
21
self .coronal = models .alexnet (pretrained = True ).features
17
22
self .saggital = models .alexnet (pretrained = True ).features
18
-
23
+
24
+ # Initialize 2D Adaptive Average Pooling layers
25
+ # The pooling layers will reduce the size of
26
+ # feature maps extracted from the previous axes
19
27
self .pool_axial = nn .AdaptiveAvgPool2d (1 )
20
28
self .pool_coronal = nn .AdaptiveAvgPool2d (1 )
21
29
self .pool_saggital = nn .AdaptiveAvgPool2d (1 )
22
-
30
+
31
+ # Initialize a sequential neural network with
32
+ # a single fully connected linear layer
33
+ # The network will output the probability of
34
+ # having a particular disease
23
35
self .fc = nn .Sequential (
24
36
nn .Linear (in_features = 3 * 256 ,out_features = 1 )
25
37
)
@@ -33,24 +45,37 @@ def forward(self,x):
33
45
# squeeze the first dimension as there
34
46
# is only one patient in each batch
35
47
images = [torch .squeeze (img , dim = 0 ) for img in x ]
36
-
48
+
49
+ # Extract features across each of the three plane
50
+ # using the three pre-trained AlexNet models defined earlier
37
51
image1 = self .axial (images [0 ])
38
52
image2 = self .coronal (images [1 ])
39
53
image3 = self .saggital (images [2 ])
40
54
55
+ # Convert the image dimesnsions from [slices, 256, 1, 1] to
56
+ # [slices,256]
41
57
image1 = self .pool_axial (image1 ).view (image1 .size (0 ), - 1 )
42
58
image2 = self .pool_coronal (image2 ).view (image2 .size (0 ), - 1 )
43
59
image3 = self .pool_saggital (image3 ).view (image3 .size (0 ), - 1 )
44
60
61
+ # Find maximum value in each slice
62
+ # This will reduce the dimensions of image to [1,256]
63
+ # This is done in order to keep only the most prevalent
64
+ # features for each slice
45
65
image1 = torch .max (image1 ,dim = 0 ,keepdim = True )[0 ]
46
66
image2 = torch .max (image2 ,dim = 0 ,keepdim = True )[0 ]
47
67
image3 = torch .max (image3 ,dim = 0 ,keepdim = True )[0 ]
48
68
69
+ # Stack the 3 images together to create the output
70
+ # of size [1, 256*3]
49
71
output = torch .cat ([image1 ,image2 ,image3 ], dim = 1 )
50
72
73
+ # Feed the output to the sequential network created earlier
74
+ # The network will return a probability of having a specific
75
+ # disease
51
76
output = self .fc (output )
52
77
return output
53
78
54
79
def _load_wieghts (self ):
55
80
"""load pretrained weights"""
56
- pass
81
+ pass
0 commit comments