Skip to content

Commit b5f0b6e

Browse files
authored
Fix and simplify SEMEION dataset (#332)
1 parent 6497852 commit b5f0b6e

File tree

1 file changed

+6
-36
lines changed

1 file changed

+6
-36
lines changed

torchvision/datasets/semeion.py

Lines changed: 6 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -46,39 +46,12 @@ def __init__(self, root, transform=None, target_transform=None, download=True):
4646
self.data = []
4747
self.labels = []
4848
fp = os.path.join(root, self.filename)
49-
file = open(fp, 'r')
50-
data = file.read()
51-
file.close()
52-
dataSplitted = data.split("\n")[:-1]
53-
datasetLength = len(dataSplitted)
54-
i = 0
55-
while i < datasetLength:
56-
# Get the 'i-th' row
57-
strings = dataSplitted[i]
58-
59-
# Split row into numbers(string), and avoid blank at the end
60-
stringsSplitted = (strings[:-1]).split(" ")
61-
62-
# Get data (which ends at column 256th), then in a numpy array.
63-
rawData = stringsSplitted[:256]
64-
dataFloat = [float(j) for j in rawData]
65-
img = np.array(dataFloat[:16])
66-
j = 16
67-
k = 0
68-
while j < len(dataFloat):
69-
temp = np.array(dataFloat[k:j])
70-
img = np.vstack((img, temp))
71-
72-
k = j
73-
j += 16
74-
75-
self.data.append(img)
76-
77-
# Get label and convert it into numbers, then in a numpy array.
78-
labelString = stringsSplitted[256:]
79-
labelInt = [int(index) for index in labelString]
80-
self.labels.append(np.array(labelInt))
81-
i += 1
49+
data = np.loadtxt(fp)
50+
# convert value to 8 bit unsigned integer
51+
# color (white #255) the pixels
52+
self.data = (data[:, :256] * 255).astype('uint8')
53+
self.data = np.reshape(self.data, (-1, 16, 16))
54+
self.labels = np.nonzero(data[:, 256:])[1]
8255

8356
def __getitem__(self, index):
8457
"""
@@ -91,9 +64,6 @@ def __getitem__(self, index):
9164

9265
# doing this so that it is consistent with all other datasets
9366
# to return a PIL Image
94-
# convert value to 8 bit unsigned integer
95-
# color (white #255) the pixels
96-
img = img.astype('uint8') * 255
9767
img = Image.fromarray(img, mode='L')
9868

9969
if self.transform is not None:

0 commit comments

Comments
 (0)