Skip to content

Commit 3d17530

Browse files
committed
Fix examples and remove redundant code with scikit-learn
1 parent d0283cc commit 3d17530

File tree

3 files changed

+17
-501
lines changed

3 files changed

+17
-501
lines changed

examples/plot_digits_classification_fastfood.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424

2525
# Import datasets, classifiers and performance metrics
2626
from sklearn import datasets, svm, metrics
27-
from sklearn.kernel_approximation import Fastfood
27+
28+
from sklearn_extra.kernel_approximation import Fastfood
2829

2930
# The digits dataset
3031
digits = datasets.load_digits()
@@ -34,11 +35,13 @@
3435
# attribute of the dataset. If we were working from image files, we
3536
# could load them using pylab.imread. For these images know which
3637
# digit they represent: it is given in the 'target' of the dataset.
37-
for index, (image, label) in enumerate(zip(digits.images, digits.target)[:4]):
38+
for index, (image, label) in enumerate(zip(digits.images, digits.target)):
3839
pl.subplot(2, 4, index + 1)
3940
pl.axis('off')
4041
pl.imshow(image, cmap=pl.cm.gray_r, interpolation='nearest')
4142
pl.title('Training: %i' % label)
43+
if index > 3:
44+
break
4245

4346
# To apply an classifier on this data, we need to flatten the image, to
4447
# turn the data in a (samples, feature) matrix:
@@ -47,8 +50,8 @@
4750
gamma = .001
4851
sigma = np.sqrt(1 / (2 * gamma))
4952
number_of_features_to_generate = 1000
50-
train__idx = range(n_samples / 2)
51-
test__idx = range(n_samples / 2, n_samples)
53+
train__idx = range(n_samples // 2)
54+
test__idx = range(n_samples // 2, n_samples)
5255

5356
# map data into featurespace
5457
rbf_transform = Fastfood(
@@ -94,10 +97,12 @@
9497
% metrics.confusion_matrix(expected, predicted_linear_transformed))
9598

9699
for index, (image, prediction) in enumerate(
97-
zip(digits.images[test__idx], predicted)[:4]):
98-
pl.subplot(2, 4, index + 5)
100+
zip(digits.images[test__idx], predicted)):
101+
pl.subplot(2, 4, index + 4)
99102
pl.axis('off')
100103
pl.imshow(image, cmap=pl.cm.gray_r, interpolation='nearest')
101104
pl.title('Prediction: %i' % prediction)
105+
if index > 3:
106+
break
102107

103108
pl.show()

examples/plot_kernel_approximation.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,11 @@
5454

5555
# Import datasets, classifiers and performance metrics
5656
from sklearn import datasets, svm, pipeline
57-
from sklearn.kernel_approximation import (RBFSampler,
58-
Nystroem, Fastfood)
57+
from sklearn.kernel_approximation import Nystroem, RBFSampler
5958
from sklearn.decomposition import PCA
6059

60+
from sklearn_extra.kernel_approximation import Fastfood
61+
6162
# The digits dataset
6263
digits = datasets.load_digits(n_class=9)
6364

@@ -68,11 +69,11 @@
6869
data -= data.mean(axis=0)
6970

7071
# We learn the digits on the first half of the digits
71-
data_train, targets_train = data[:n_samples / 2], digits.target[:n_samples / 2]
72+
data_train, targets_train = data[:n_samples // 2], digits.target[:n_samples // 2]
7273

7374

7475
# Now predict the value of the digit on the second half:
75-
data_test, targets_test = data[n_samples / 2:], digits.target[n_samples / 2:]
76+
data_test, targets_test = data[n_samples // 2:], digits.target[n_samples // 2:]
7677
#data_test = scaler.transform(data_test)
7778

7879
# fix model parameters:

0 commit comments

Comments
 (0)