Skip to content

Commit 12fb111

Browse files
Alex7Lirth
authored and committed
Change mnist example to a smaller synthetic example (#33)
1 parent 3694133 commit 12fb111

File tree

1 file changed

+46
-45
lines changed

1 file changed

+46
-45
lines changed
Lines changed: 46 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
"""
2-
===============================================
3-
Comparison of EigenPro and SVC on Fashion-MNIST
4-
===============================================
2+
======================================================
3+
Comparison of EigenPro and SVC on Digit Classification
4+
======================================================
55
66
Here we train a EigenPro Classifier and a Support
7-
Vector Classifier (SVC) on subsets of MNIST of various sizes.
8-
We halt the training of EigenPro after two epochs.
9-
Experimental results on MNIST demonstrate more than 3 times
10-
speedup of EigenPro over SVC in training time. EigenPro also
11-
shows consistently lower classification error on test set.
7+
Vector Classifier (SVC) on a synthetically generated
8+
binary classification problem. We halt the training
9+
of EigenPro after two epochs.
10+
While EigenPro is slower on low dimensional datasets, as
11+
the number of features exceeds 500, it begins to outperform
12+
SVM in terms of both time and training error.
1213
"""
1314
print(__doc__)
1415

@@ -17,22 +18,14 @@
1718
import numpy as np
1819
from time import time
1920

21+
from sklearn.datasets import make_classification
2022
from sklearn_extra.kernel_methods import EigenProClassifier
2123
from sklearn.svm import SVC
22-
from sklearn.datasets import fetch_openml
2324

2425
rng = np.random.RandomState(1)
2526

26-
# Generate sample data from mnist
27-
mnist = fetch_openml("Fashion-MNIST")
28-
mnist.data = mnist.data / 255.0
29-
print("Data has loaded")
30-
31-
p = rng.permutation(60000)
32-
x_train = mnist.data[p]
33-
y_train = np.int32(mnist.target[p])
34-
x_test = mnist.data[60000:]
35-
y_test = np.int32(mnist.target[60000:])
27+
train_size = 2000
28+
test_size = 1000
3629

3730
# Run tests comparing eig to svc
3831
eig_fit_times = []
@@ -42,30 +35,38 @@
4235
svc_pred_times = []
4336
svc_err = []
4437

45-
train_sizes = [500, 1000, 2000]
46-
47-
print("Train Sizes: " + str(train_sizes))
48-
49-
bandwidth = 5.0
38+
feature_counts = [15, 50, 150, 500, 1500]
39+
bandwidth = 8.0
5040

5141
# Fit models to data
52-
for train_size in train_sizes:
42+
for n_features in feature_counts:
43+
x, y = make_classification(
44+
n_samples=train_size + test_size,
45+
n_features=n_features,
46+
random_state=rng,
47+
)
48+
49+
x_train = x[:train_size]
50+
y_train = y[:train_size]
51+
x_test = x[train_size:]
52+
y_test = y[train_size:]
5353
for name, estimator in [
5454
(
5555
"EigenPro",
5656
EigenProClassifier(
57-
n_epoch=2, bandwidth=bandwidth, random_state=rng
57+
n_epoch=2,
58+
bandwidth=bandwidth,
59+
n_components=400,
60+
random_state=rng,
5861
),
5962
),
6063
(
6164
"SupportVector",
62-
SVC(
63-
C=5, gamma=1.0 / (2 * bandwidth * bandwidth), random_state=rng
64-
),
65+
SVC(gamma=1.0 / (2 * bandwidth * bandwidth), random_state=rng),
6566
),
6667
]:
6768
stime = time()
68-
estimator.fit(x_train[:train_size], y_train[:train_size])
69+
estimator.fit(x_train, y_train)
6970
fit_t = time() - stime
7071

7172
stime = time()
@@ -82,35 +83,35 @@
8283
svc_pred_times.append(pred_t)
8384
svc_err.append(err)
8485
print(
85-
"%s Classification with %i training samples in %0.2f seconds."
86-
% (name, train_size, fit_t + pred_t)
86+
"%s Classification with %i features in %0.2f seconds. Error: %0.1f"
87+
% (name, n_features, fit_t + pred_t, err)
8788
)
8889

8990
# set up grid for figures
9091
fig = plt.figure(num=None, figsize=(6, 4), dpi=160)
9192
ax = plt.subplot2grid((2, 2), (0, 0), rowspan=2)
9293

9394
# Graph fit(train) time
94-
train_size_labels = [str(s) for s in train_sizes]
95-
ax.plot(train_sizes, svc_fit_times, "o--", color="g", label="SVC")
95+
feature_number_labels = [str(s) for s in feature_counts]
96+
ax.plot(feature_counts, svc_fit_times, "o--", color="g", label="SVC")
9697
ax.plot(
97-
train_sizes, eig_fit_times, "o-", color="r", label="EigenPro Classifier"
98+
feature_counts, eig_fit_times, "o-", color="r", label="EigenPro Classifier"
9899
)
99100
ax.set_xscale("log")
100101
ax.set_yscale("log", nonposy="clip")
101-
ax.set_xlabel("train size")
102+
ax.set_xlabel("Number of features")
102103
ax.set_ylabel("time (seconds)")
103104
ax.legend()
104105
ax.set_title("Training Time")
105-
ax.set_xticks(train_sizes)
106-
ax.set_xticklabels(train_size_labels)
106+
ax.set_xticks(feature_counts)
107+
ax.set_xticklabels(feature_number_labels)
107108
ax.set_xticks([], minor=True)
108109
ax.get_xaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter())
109110

110111
# Graph prediction(test) time
111112
ax = plt.subplot2grid((2, 2), (0, 1), rowspan=1)
112-
ax.plot(train_sizes, eig_pred_times, "o-", color="r")
113-
ax.plot(train_sizes, svc_pred_times, "o--", color="g")
113+
ax.plot(feature_counts, eig_pred_times, "o-", color="r")
114+
ax.plot(feature_counts, svc_pred_times, "o--", color="g")
114115
ax.set_xscale("log")
115116
ax.set_yscale("log", nonposy="clip")
116117
ax.set_ylabel("time (seconds)")
@@ -120,13 +121,13 @@
120121

121122
# Graph training error
122123
ax = plt.subplot2grid((2, 2), (1, 1), rowspan=1)
123-
ax.plot(train_sizes, eig_err, "o-", color="r")
124-
ax.plot(train_sizes, svc_err, "o-", color="g")
124+
ax.plot(feature_counts, eig_err, "o-", color="r")
125+
ax.plot(feature_counts, svc_err, "o-", color="g")
125126
ax.set_xscale("log")
126-
ax.set_xticks(train_sizes)
127-
ax.set_xticklabels(train_size_labels)
127+
ax.set_xticks(feature_counts)
128+
ax.set_xticklabels(feature_number_labels)
128129
ax.set_xticks([], minor=True)
129-
ax.set_xlabel("train size")
130+
ax.set_xlabel("Number of features")
130131
ax.set_ylabel("Classification error %")
131132
plt.tight_layout()
132133
plt.show()

0 commit comments

Comments (0)