@@ -19,52 +19,58 @@ class MRData():
19
19
20
20
def __init__ (self ,task = 'acl' , train = True , transform = None , weights = None ):
21
21
"""Initialize the dataset
22
-
23
22
Args:
24
23
plane : along which plane to load the data
25
24
task : for which task to load the labels
26
25
train : whether to load the train or val data
27
26
transform : which transforms to apply
28
27
weights (Tensor) : Give weighted loss to the positive class, e.g. `weights=torch.tensor([2.223])`
29
28
"""
29
+ # Define the three planes to use
30
30
self .planes = ['axial' , 'coronal' , 'sagittal' ]
31
+ # Initialize the records as None
31
32
self .records = None
32
33
# an empty dictionary
33
34
self .image_path = {}
34
35
36
+ # If we are in training loop
35
37
if train :
38
+ # Read data about patient records
36
39
self .records = pd .read_csv ('./images/train-{}.csv' .format (task ),header = None , names = ['id' , 'label' ])
37
40
38
- '''
39
- self.image_path[<plane>]= dictionary {<plane>: path to folder containing
40
- image for that plane}
41
- '''
42
41
for plane in self .planes :
42
+ # For each plane, specify the image path
43
43
self .image_path [plane ] = './images/train/{}/' .format (plane )
44
44
else :
45
+ # If we are in testing loop
46
+ # don't use any transformation
45
47
transform = None
48
+ # Read testing/validation data (patients records)
46
49
self .records = pd .read_csv ('./images/valid-{}.csv' .format (task ),header = None , names = ['id' , 'label' ])
47
- '''
48
- self.image_path[<plane>]= dictionary {<plane>: path to folder containing
49
- image for that plane}
50
- '''
50
+
51
51
for plane in self .planes :
52
+ # Read path of images for each plane
52
53
self .image_path [plane ] = './images/valid/{}/' .format (plane )
53
54
54
-
55
+ # Initialize the transformation to apply on images
55
56
self .transform = transform
56
57
58
+ # Append 0s to the patient record id
57
59
self .records ['id' ] = self .records ['id' ].map (
58
60
lambda i : '0' * (4 - len (str (i ))) + str (i ))
59
61
# empty dictionary
60
- self .paths = {}
62
+ self .paths = {}
61
63
for plane in self .planes :
64
+ # Get paths of numpy data files for each plane
62
65
self .paths [plane ] = [self .image_path [plane ] + filename +
63
66
'.npy' for filename in self .records ['id' ].tolist ()]
64
67
68
+ # Convert labels from Pandas Series to a list
65
69
self .labels = self .records ['label' ].tolist ()
66
70
71
+ # Total positive cases
67
72
pos = sum (self .labels )
73
+ # Total negative cases
68
74
neg = len (self .labels ) - pos
69
75
70
76
# Find the weights of pos and neg classes
@@ -90,53 +96,75 @@ def __getitem__(self, index):
90
96
img_raw = {}
91
97
92
98
for plane in self .planes :
99
+ # Load raw image data for each plane
93
100
img_raw [plane ] = np .load (self .paths [plane ][index ])
101
+ # Resize the image loaded in the previous step
94
102
img_raw [plane ] = self ._resize_image (img_raw [plane ])
95
103
96
104
label = self .labels [index ]
105
+ # Wrap the 0/1 label in a single-element FloatTensor
97
106
if label == 1 :
98
107
label = torch .FloatTensor ([1 ])
99
108
elif label == 0 :
100
109
label = torch .FloatTensor ([0 ])
101
110
111
+ # Return a list of three images for three planes and the label of the record
102
112
return [img_raw [plane ] for plane in self .planes ], label
103
113
104
114
def _resize_image(self, image):
    """Center-crop, normalize and convert one image volume to a tensor.

    Steps, in order:
      1. Center-crop the two trailing axes to at most ``INPUT_DIM``.
      2. Min-max rescale intensities to ``[0, MAX_PIXEL_VAL]``.
      3. Standardize with the module-level ``MEAN`` and ``STDDEV``.
      4. Apply ``self.transform`` if one is set; otherwise stack the
         volume three times along axis 1.
      5. Wrap the result in ``torch.FloatTensor``.

    Args:
        image: array indexed as ``[:, rows, cols]`` (presumably a
            ``(slices, H, W)`` numpy volume loaded from a ``.npy``
            file -- TODO confirm against the caller).

    Returns:
        The processed volume as a ``torch.FloatTensor``.
    """
    # Center-crop offsets, computed per axis. An explicit end index is
    # used instead of ``pad:-pad`` because ``0:-0`` is an empty slice
    # (breaks inputs that are already INPUT_DIM wide) and because an odd
    # surplus would otherwise leave INPUT_DIM + 1 pixels. Offsets are
    # clamped at 0 so an axis smaller than INPUT_DIM is left uncropped.
    top = max((image.shape[1] - INPUT_DIM) // 2, 0)
    left = max((image.shape[2] - INPUT_DIM) // 2, 0)
    image = image[:, top:top + INPUT_DIM, left:left + INPUT_DIM]

    # Min-max rescale to [0, MAX_PIXEL_VAL].
    lowest = np.min(image)
    span = np.max(image) - lowest
    if span == 0:
        # Constant image: avoid 0/0 -> NaN; every pixel maps to 0.
        image = np.zeros(image.shape, dtype=float)
    else:
        image = (image - lowest) / span * MAX_PIXEL_VAL

    # Standardize with the dataset-level mean and standard deviation.
    image = (image - MEAN) / STDDEV

    if self.transform:
        # Caller-supplied augmentation pipeline.
        image = self.transform(image)
    else:
        # No transform: replicate the volume 3x along axis 1.
        image = np.stack((image,) * 3, axis=1)

    # Convert to a FloatTensor and return it.
    image = torch.FloatTensor(image)
    return image
121
140
122
141
def load_data (task : str ):
123
142
124
143
# Define the Augmentation here only
125
144
augments = Compose ([
145
+ # Convert the image to Tensor
126
146
transforms .Lambda (lambda x : torch .Tensor (x )),
147
+ # Randomly rotate the image with an angle
148
+ # between -25 degrees to 25 degrees
127
149
RandomRotate (25 ),
150
+ # Randomly translate the image by 11% of
151
+ # image height and width
128
152
RandomTranslate ([0.11 , 0.11 ]),
153
+ # Randomly flip the image
129
154
RandomFlip (),
155
+ # Tile the image 3 times, then swap the first two axes (channel replication plus reorder)
130
156
transforms .Lambda (lambda x : x .repeat (3 , 1 , 1 , 1 ).permute (1 , 0 , 2 , 3 )),
131
157
])
132
158
133
159
print ('Loading Train Dataset of {} task...' .format (task ))
160
+ # Load training dataset
134
161
train_data = MRData (task , train = True , transform = augments )
135
162
train_loader = data .DataLoader (
136
163
train_data , batch_size = 1 , num_workers = 11 , shuffle = True
137
164
)
138
165
139
166
print ('Loading Validation Dataset of {} task...' .format (task ))
167
+ # Load validation dataset
140
168
val_data = MRData (task , train = False )
141
169
val_loader = data .DataLoader (
142
170
val_data , batch_size = 1 , num_workers = 11 , shuffle = False
0 commit comments