@@ -46,39 +46,12 @@ def __init__(self, root, transform=None, target_transform=None, download=True):
46
46
self .data = []
47
47
self .labels = []
48
48
fp = os .path .join (root , self .filename )
49
- file = open (fp , 'r' )
50
- data = file .read ()
51
- file .close ()
52
- dataSplitted = data .split ("\n " )[:- 1 ]
53
- datasetLength = len (dataSplitted )
54
- i = 0
55
- while i < datasetLength :
56
- # Get the 'i-th' row
57
- strings = dataSplitted [i ]
58
-
59
- # Split row into numbers(string), and avoid blank at the end
60
- stringsSplitted = (strings [:- 1 ]).split (" " )
61
-
62
- # Get data (which ends at column 256th), then in a numpy array.
63
- rawData = stringsSplitted [:256 ]
64
- dataFloat = [float (j ) for j in rawData ]
65
- img = np .array (dataFloat [:16 ])
66
- j = 16
67
- k = 0
68
- while j < len (dataFloat ):
69
- temp = np .array (dataFloat [k :j ])
70
- img = np .vstack ((img , temp ))
71
-
72
- k = j
73
- j += 16
74
-
75
- self .data .append (img )
76
-
77
- # Get label and convert it into numbers, then in a numpy array.
78
- labelString = stringsSplitted [256 :]
79
- labelInt = [int (index ) for index in labelString ]
80
- self .labels .append (np .array (labelInt ))
81
- i += 1
49
+ data = np .loadtxt (fp )
50
+ # convert value to 8 bit unsigned integer
51
+ # color (white #255) the pixels
52
+ self .data = (data [:, :256 ] * 255 ).astype ('uint8' )
53
+ self .data = np .reshape (self .data , (- 1 , 16 , 16 ))
54
+ self .labels = np .nonzero (data [:, 256 :])[1 ]
82
55
83
56
def __getitem__ (self , index ):
84
57
"""
@@ -91,9 +64,6 @@ def __getitem__(self, index):
91
64
92
65
# doing this so that it is consistent with all other datasets
93
66
# to return a PIL Image
94
- # convert value to 8 bit unsigned integer
95
- # color (white #255) the pixels
96
- img = img .astype ('uint8' ) * 255
97
67
img = Image .fromarray (img , mode = 'L' )
98
68
99
69
if self .transform is not None :
0 commit comments