Skip to content

Commit 7eb9e65

Browse files
neoglezsoumith
authored andcommitted
Add SEMEION Dataset (#324)
* First commit for semeion dataset SEMEION Handwritten Digits Data Set http://archive.ics.uci.edu/ml/datasets/semeion+handwritten+digit * SEMEION Class * Fix Lint errors * More linting errors
1 parent 2091365 commit 7eb9e65

File tree

2 files changed

+125
-1
lines changed

2 files changed

+125
-1
lines changed

torchvision/datasets/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
from .svhn import SVHN
88
from .phototour import PhotoTour
99
from .fakedata import FakeData
10+
from .semeion import SEMEION
1011

1112
__all__ = ('LSUN', 'LSUNClass',
1213
'ImageFolder', 'FakeData',
1314
'CocoCaptions', 'CocoDetection',
1415
'CIFAR10', 'CIFAR100', 'FashionMNIST',
15-
'MNIST', 'STL10', 'SVHN', 'PhotoTour')
16+
'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION')

torchvision/datasets/semeion.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
from __future__ import print_function
2+
from PIL import Image
3+
import os
4+
import os.path
5+
import errno
6+
import numpy as np
7+
import sys
8+
if sys.version_info[0] == 2:
9+
import cPickle as pickle
10+
else:
11+
import pickle
12+
13+
import torch.utils.data as data
14+
from .utils import download_url, check_integrity
15+
16+
17+
class SEMEION(data.Dataset):
18+
"""`SEMEION <http://archive.ics.uci.edu/ml/datasets/semeion+handwritten+digit>`_ Dataset.
19+
Args:
20+
root (string): Root directory of dataset where directory
21+
``semeion.py`` exists.
22+
transform (callable, optional): A function/transform that takes in an PIL image
23+
and returns a transformed version. E.g, ``transforms.RandomCrop``
24+
target_transform (callable, optional): A function/transform that takes in the
25+
target and transforms it.
26+
download (bool, optional): If true, downloads the dataset from the internet and
27+
puts it in root directory. If dataset is already downloaded, it is not
28+
downloaded again.
29+
"""
30+
url = "http://archive.ics.uci.edu/ml/machine-learning-databases/semeion/semeion.data"
31+
filename = "semeion.data"
32+
md5_checksum = 'cb545d371d2ce14ec121470795a77432'
33+
34+
def __init__(self, root, transform=None, target_transform=None, download=True):
35+
self.root = os.path.expanduser(root)
36+
self.transform = transform
37+
self.target_transform = target_transform
38+
39+
if download:
40+
self.download()
41+
42+
if not self._check_integrity():
43+
raise RuntimeError('Dataset not found or corrupted.' +
44+
' You can use download=True to download it')
45+
46+
self.data = []
47+
self.labels = []
48+
fp = os.path.join(root, self.filename)
49+
file = open(fp, 'r')
50+
data = file.read()
51+
file.close()
52+
dataSplitted = data.split("\n")[:-1]
53+
datasetLength = len(dataSplitted)
54+
i = 0
55+
while i < datasetLength:
56+
# Get the 'i-th' row
57+
strings = dataSplitted[i]
58+
59+
# Split row into numbers(string), and avoid blank at the end
60+
stringsSplitted = (strings[:-1]).split(" ")
61+
62+
# Get data (which ends at column 256th), then in a numpy array.
63+
rawData = stringsSplitted[:256]
64+
dataFloat = [float(j) for j in rawData]
65+
img = np.array(dataFloat[:16])
66+
j = 16
67+
k = 0
68+
while j < len(dataFloat):
69+
temp = np.array(dataFloat[k:j])
70+
img = np.vstack((img, temp))
71+
72+
k = j
73+
j += 16
74+
75+
self.data.append(img)
76+
77+
# Get label and convert it into numbers, then in a numpy array.
78+
labelString = stringsSplitted[256:]
79+
labelInt = [int(index) for index in labelString]
80+
self.labels.append(np.array(labelInt))
81+
i += 1
82+
83+
def __getitem__(self, index):
84+
"""
85+
Args:
86+
index (int): Index
87+
Returns:
88+
tuple: (image, target) where target is index of the target class.
89+
"""
90+
img, target = self.data[index], self.labels[index]
91+
92+
# doing this so that it is consistent with all other datasets
93+
# to return a PIL Image
94+
# convert value to 8 bit unsigned integer
95+
# color (white #255) the pixels
96+
img = img.astype('uint8') * 255
97+
img = Image.fromarray(img, mode='L')
98+
99+
if self.transform is not None:
100+
img = self.transform(img)
101+
102+
if self.target_transform is not None:
103+
target = self.target_transform(target)
104+
105+
return img, target
106+
107+
def __len__(self):
108+
return len(self.data)
109+
110+
def _check_integrity(self):
111+
root = self.root
112+
fpath = os.path.join(root, self.filename)
113+
if not check_integrity(fpath, self.md5_checksum):
114+
return False
115+
return True
116+
117+
def download(self):
118+
if self._check_integrity():
119+
print('Files already downloaded and verified')
120+
return
121+
122+
root = self.root
123+
download_url(self.url, root, self.filename, self.md5_checksum)

0 commit comments

Comments
 (0)