Skip to content

Commit 3e07dd9

Browse files
Update dataset.py
1 parent b0f546b commit 3e07dd9

File tree

1 file changed

+41
-13
lines changed

1 file changed

+41
-13
lines changed

MRNet-Single-Model/dataset/dataset.py

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,52 +19,58 @@ class MRData():
1919

2020
def __init__(self,task = 'acl', train = True, transform = None, weights = None):
2121
"""Initialize the dataset
22-
2322
Args:
2423
plane : along which plane to load the data
2524
task : for which task to load the labels
2625
train : whether to load the train or val data
2726
transform : which transforms to apply
2827
weights (Tensor) : Give wieghted loss to postive class eg. `weights=torch.tensor([2.223])`
2928
"""
29+
# Define the three planes to use
3030
self.planes=['axial', 'coronal', 'sagittal']
31+
# Initialize the records as None
3132
self.records = None
3233
# an empty dictionary
3334
self.image_path={}
3435

36+
# If we are in training loop
3537
if train:
38+
# Read data about patient records
3639
self.records = pd.read_csv('./images/train-{}.csv'.format(task),header=None, names=['id', 'label'])
3740

38-
'''
39-
self.image_path[<plane>]= dictionary {<plane>: path to folder containing
40-
image for that plane}
41-
'''
4241
for plane in self.planes:
42+
# For each plane, specify the image path
4343
self.image_path[plane] = './images/train/{}/'.format(plane)
4444
else:
45+
# If we are in testing loop
46+
# don't use any transformation
4547
transform = None
48+
# Read testing/validation data (patients records)
4649
self.records = pd.read_csv('./images/valid-{}.csv'.format(task),header=None, names=['id', 'label'])
47-
'''
48-
self.image_path[<plane>]= dictionary {<plane>: path to folder containing
49-
image for that plane}
50-
'''
50+
5151
for plane in self.planes:
52+
# Read path of images for each plane
5253
self.image_path[plane] = './images/valid/{}/'.format(plane)
5354

54-
55+
# Initialize the transformation to apply on images
5556
self.transform = transform
5657

58+
# Append 0s to the patient record id
5759
self.records['id'] = self.records['id'].map(
5860
lambda i: '0' * (4 - len(str(i))) + str(i))
5961
# empty dictionary
60-
self.paths={}
62+
self.paths={}
6163
for plane in self.planes:
64+
# Get paths of numpy data files for each plane
6265
self.paths[plane] = [self.image_path[plane] + filename +
6366
'.npy' for filename in self.records['id'].tolist()]
6467

68+
# Convert labels from Pandas Series to a list
6569
self.labels = self.records['label'].tolist()
6670

71+
# Total positive cases
6772
pos = sum(self.labels)
73+
# Total negative cases
6874
neg = len(self.labels) - pos
6975

7076
# Find the wieghts of pos and neg classes
@@ -90,53 +96,75 @@ def __getitem__(self, index):
9096
img_raw = {}
9197

9298
for plane in self.planes:
99+
# Load raw image data for each plane
93100
img_raw[plane] = np.load(self.paths[plane][index])
101+
# Resize the image loaded in the previous step
94102
img_raw[plane] = self._resize_image(img_raw[plane])
95103

96104
label = self.labels[index]
105+
# Convert label to 0 and 1
97106
if label == 1:
98107
label = torch.FloatTensor([1])
99108
elif label == 0:
100109
label = torch.FloatTensor([0])
101110

111+
# Return a list of three images for three planes and the label of the record
102112
return [img_raw[plane] for plane in self.planes], label
103113

104114
def _resize_image(self, image):
105115
"""Resize the image to `(3,224,224)` and apply
106116
transforms if possible.
107117
"""
108118
# Resize the image
119+
# Calculate extra padding present in the image
120+
# which needs to be removed
109121
pad = int((image.shape[2] - INPUT_DIM)/2)
122+
# This is equivalent to center cropping the image
110123
image = image[:,pad:-pad,pad:-pad]
124+
# Normalize the image by subtracting it by mean and dividing by standard
125+
# deviation
111126
image = (image-np.min(image))/(np.max(image)-np.min(image))*MAX_PIXEL_VAL
112127
image = (image - MEAN) / STDDEV
113-
128+
129+
# If the transformation is not None
114130
if self.transform:
131+
# Transform the image based on the specified transformation
115132
image = self.transform(image)
116133
else:
134+
# Else, just stack the image with itself in order to match the required
135+
# dimensions
117136
image = np.stack((image,)*3, axis=1)
118-
137+
# Convert the image to a FloatTensor and return it
119138
image = torch.FloatTensor(image)
120139
return image
121140

122141
def load_data(task : str):
123142

124143
# Define the Augmentation here only
125144
augments = Compose([
145+
# Convert the image to Tensor
126146
transforms.Lambda(lambda x: torch.Tensor(x)),
147+
# Randomly rotate the image with an angle
148+
# between -25 degrees to 25 degrees
127149
RandomRotate(25),
150+
# Randomly translate the image by 11% of
151+
# image height and width
128152
RandomTranslate([0.11, 0.11]),
153+
# Randomly flip the image
129154
RandomFlip(),
155+
# Change the order of image channels
130156
transforms.Lambda(lambda x: x.repeat(3, 1, 1, 1).permute(1, 0, 2, 3)),
131157
])
132158

133159
print('Loading Train Dataset of {} task...'.format(task))
160+
# Load training dataset
134161
train_data = MRData(task, train=True, transform=augments)
135162
train_loader = data.DataLoader(
136163
train_data, batch_size=1, num_workers=11, shuffle=True
137164
)
138165

139166
print('Loading Validation Dataset of {} task...'.format(task))
167+
# Load validation dataset
140168
val_data = MRData(task, train=False)
141169
val_loader = data.DataLoader(
142170
val_data, batch_size=1, num_workers=11, shuffle=False

0 commit comments

Comments
 (0)