|
24 | 24 |
|
25 | 25 | # Import datasets, classifiers and performance metrics
|
26 | 26 | from sklearn import datasets, svm, metrics
|
27 |
| -from sklearn.kernel_approximation import Fastfood |
| 27 | + |
| 28 | +from sklearn_extra.kernel_approximation import Fastfood |
28 | 29 |
|
29 | 30 | # The digits dataset
|
30 | 31 | digits = datasets.load_digits()
|
|
34 | 35 | # attribute of the dataset. If we were working from image files, we
|
35 | 36 | # could load them using pylab.imread. For these images know which
|
36 | 37 | # digit they represent: it is given in the 'target' of the dataset.
|
37 |
| -for index, (image, label) in enumerate(zip(digits.images, digits.target)[:4]): |
| 38 | +for index, (image, label) in enumerate(zip(digits.images, digits.target)): |
38 | 39 | pl.subplot(2, 4, index + 1)
|
39 | 40 | pl.axis('off')
|
40 | 41 | pl.imshow(image, cmap=pl.cm.gray_r, interpolation='nearest')
|
41 | 42 | pl.title('Training: %i' % label)
|
| 43 | + if index > 3: |
| 44 | + break |
42 | 45 |
|
43 | 46 | # To apply an classifier on this data, we need to flatten the image, to
|
44 | 47 | # turn the data in a (samples, feature) matrix:
|
|
47 | 50 | gamma = .001
|
48 | 51 | sigma = np.sqrt(1 / (2 * gamma))
|
49 | 52 | number_of_features_to_generate = 1000
|
50 |
| -train__idx = range(n_samples / 2) |
51 |
| -test__idx = range(n_samples / 2, n_samples) |
| 53 | +train__idx = range(n_samples // 2) |
| 54 | +test__idx = range(n_samples // 2, n_samples) |
52 | 55 |
|
53 | 56 | # map data into featurespace
|
54 | 57 | rbf_transform = Fastfood(
|
|
94 | 97 | % metrics.confusion_matrix(expected, predicted_linear_transformed))
|
95 | 98 |
|
96 | 99 | for index, (image, prediction) in enumerate(
|
97 |
| - zip(digits.images[test__idx], predicted)[:4]): |
98 |
| - pl.subplot(2, 4, index + 5) |
| 100 | + zip(digits.images[test__idx], predicted)): |
| 101 | + pl.subplot(2, 4, index + 4) |
99 | 102 | pl.axis('off')
|
100 | 103 | pl.imshow(image, cmap=pl.cm.gray_r, interpolation='nearest')
|
101 | 104 | pl.title('Prediction: %i' % prediction)
|
| 105 | + if index > 3: |
| 106 | + break |
102 | 107 |
|
103 | 108 | pl.show()
|
0 commit comments