Commit 81c1c4a

Rewrite example/classifier_comparison.rb
1 parent c2be235 · commit 81c1c4a

1 file changed: +52 -52 lines changed

examples/classifier_comparison.rb

Lines changed: 52 additions & 52 deletions
@@ -34,102 +34,102 @@
 ]
 
 classifiers = [
-  KNeighborsClassifier.(3),
-  SVC.(kernel: 'linear', C: 0.025),
-  SVC.(gamma: 2, C: 1),
-  DecisionTreeClassifier.(max_depth: 5),
-  RandomForestClassifier.(max_depth: 5, n_estimators: 10, max_features: 1),
-  AdaBoostClassifier.(),
-  GaussianNB.(),
-  LinearDiscriminantAnalysis.(),
-  QuadraticDiscriminantAnalysis.()
+  KNeighborsClassifier.new(3),
+  SVC.new(kernel: 'linear', C: 0.025),
+  SVC.new(gamma: 2, C: 1),
+  DecisionTreeClassifier.new(max_depth: 5),
+  RandomForestClassifier.new(max_depth: 5, n_estimators: 10, max_features: 1),
+  AdaBoostClassifier.new(),
+  GaussianNB.new(),
+  LinearDiscriminantAnalysis.new(),
+  QuadraticDiscriminantAnalysis.new()
 ]
 
-x, y = make_classification.(
+x, y = *make_classification(
   n_features: 2,
   n_redundant: 0,
   n_informative: 2,
   random_state: 1,
   n_clusters_per_class: 1
 )
 
-np.random.seed.(42)
-x += 2 * np.random.random_sample.(x.shape)
-linearly_separable = PyCall.tuple(x, y)
+np.random.seed(42)
+x += 2 * np.random.random_sample(x.shape)
+linearly_separable = PyCall.tuple([x, y]) # FIXME: allow PyCall.tuple(x, y)
 
 datasets = [
-  make_moons.(noise: 0.3, random_state: 0),
-  make_circles.(noise: 0.2, factor: 0.5, random_state: 1),
+  make_moons(noise: 0.3, random_state: 0),
+  make_circles(noise: 0.2, factor: 0.5, random_state: 1),
   linearly_separable
 ]
 
-fig = plt.figure.(figsize: PyCall.tuple(27, 9))
+fig = plt.figure(figsize: [27, 9])
 i = 1
-all = PyCall.slice(nil)
+all = PyCall::Slice.all
 datasets.each do |ds|
-  x, y = ds
-  x = StandardScaler.().fit_transform.(x)
-  x_train, x_test, y_train, y_test = train_test_split.(x, y, test_size: 0.4)
+  x, y = *ds
+  x = StandardScaler.new.fit_transform(x)
+  x_train, x_test, y_train, y_test = train_test_split(x, y, test_size: 0.4)
 
-  x_min, x_max = np.min.(x[all, 0]) - 0.5, np.max.(x[all, 0]) + 0.5
-  y_min, y_max = np.min.(x[all, 1]) - 0.5, np.max.(x[all, 1]) + 0.5
+  x_min, x_max = np.min(x[all, 0]) - 0.5, np.max(x[all, 0]) + 0.5
+  y_min, y_max = np.min(x[all, 1]) - 0.5, np.max(x[all, 1]) + 0.5
 
-  xx, yy = np.meshgrid.(
-    np.linspace.(x_min, x_max, ((x_max - x_min)/h).round),
-    np.linspace.(y_min, y_max, ((y_max - y_min)/h).round),
+  xx, yy = np.meshgrid(
+    np.linspace(x_min, x_max, ((x_max - x_min)/h).round),
+    np.linspace(y_min, y_max, ((y_max - y_min)/h).round),
   )
-  mesh_points = np.dstack.(PyCall.tuple(xx.ravel.(), yy.ravel.()))[0, all, all]
+  mesh_points = np.dstack(PyCall.tuple([xx.ravel(), yy.ravel()]))[0, all, all]
 
   # just plot the dataset first
-  cm = plt.cm.RdBu
-  cm_bright = mplc.ListedColormap.(["#FF0000", "#0000FF"])
-  ax = plt.subplot.(datasets.length, classifiers.length + 1, i)
+  cm = plt.cm.__dict__[:RdBu]
+  cm_bright = mplc.ListedColormap.new(["#FF0000", "#0000FF"])
+  ax = plt.subplot(datasets.length, classifiers.length + 1, i)
   # plot the training points
-  ax.scatter.(x_train[all, 0], x_train[all, 1], c: y_train, cmap: cm_bright)
+  ax.scatter(x_train[all, 0], x_train[all, 1], c: y_train, cmap: cm_bright)
   # and testing points
-  ax.scatter.(x_test[all, 0], x_test[all, 1], c: y_test, cmap: cm_bright, alpha: 0.6)
+  ax.scatter(x_test[all, 0], x_test[all, 1], c: y_test, cmap: cm_bright, alpha: 0.6)
 
-  ax.set_xlim.(np.min.(xx), np.max.(xx))
-  ax.set_ylim.(np.min.(yy), np.max.(yy))
-  ax.set_xticks.(PyCall.tuple())
-  ax.set_yticks.(PyCall.tuple())
+  ax.set_xlim(np.min(xx), np.max(xx))
+  ax.set_ylim(np.min(yy), np.max(yy))
+  ax.set_xticks(PyCall.tuple())
+  ax.set_yticks(PyCall.tuple())
   i += 1
 
   # iterate over classifiers
   names.zip(classifiers).each do |name, clf|
-    ax = plt.subplot.(datasets.length, classifiers.length + 1, i)
-    clf.fit.(x_train, y_train)
-    scor = clf.score.(x_test, y_test)
+    ax = plt.subplot(datasets.length, classifiers.length + 1, i)
+    clf.fit(x_train, y_train)
+    scor = clf.score(x_test, y_test)
 
     # Plot the decision boundary. For that, we will assign a color to each
     # point in the mesh [x_min, x_max]x[y_min, y_max]
     begin
       # not implemented for some
-      z = clf.decision_function.(mesh_points)
+      z = clf.decision_function(mesh_points)
     rescue
-      z = clf.predict_proba.(mesh_points)[all, 1]
+      z = clf.predict_proba(mesh_points)[all, 1]
     end
 
     # Put the result into a color plot
-    z = z.reshape.(xx.shape)
-    ax.contourf.(xx, yy, z, cmap: cm, alpha: 0.8)
+    z = z.reshape(xx.shape)
+    ax.contourf(xx, yy, z, cmap: cm, alpha: 0.8)
 
     # Plot also the training points
-    ax.scatter.(x_train[all, 0], x_train[all, 1], c: y_train, cmap: cm_bright)
+    ax.scatter(x_train[all, 0], x_train[all, 1], c: y_train, cmap: cm_bright)
     # and testing points
-    ax.scatter.(x_test[all, 0], x_test[all, 1], c: y_test, cmap: cm_bright, alpha: 0.6)
+    ax.scatter(x_test[all, 0], x_test[all, 1], c: y_test, cmap: cm_bright, alpha: 0.6)
 
-    ax.set_xlim.(np.min.(xx), np.max.(xx))
-    ax.set_ylim.(np.min.(yy), np.max.(yy))
-    ax.set_xticks.(PyCall.tuple())
-    ax.set_yticks.(PyCall.tuple())
-    ax.set_title.(name)
+    ax.set_xlim(np.min(xx), np.max(xx))
+    ax.set_ylim(np.min(yy), np.max(yy))
+    ax.set_xticks(PyCall.tuple())
+    ax.set_yticks(PyCall.tuple())
+    ax.set_title(name)
 
-    ax.text.(np.max.(xx) - 0.3, np.min.(yy) + 0.3, "%.2f" % scor, size: 15, horizontalalignment: 'right')
+    ax.text(np.max(xx) - 0.3, np.min(yy) + 0.3, "%.2f" % scor, size: 15, horizontalalignment: 'right')
 
     i += 1
   end
 end
 
-fig.subplots_adjust.(left: 0.02, right: 0.98)
-plt.show.()
+fig.subplots_adjust(left: 0.02, right: 0.98)
+plt.show()
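
What changed, in short: the example drops PyCall's older `.()` call syntax (`SVC.(...)`, `clf.fit.(...)`, `plt.show.()`) and calls Python objects like ordinary Ruby methods, instantiating Python classes with `.new`. A few helpers change with it: `PyCall.slice(nil)` becomes `PyCall::Slice.all`, `PyCall.tuple(x, y)` becomes `PyCall.tuple([x, y])` (the FIXME notes the variadic form should come back), and keyword tuples such as `figsize:` now take plain Ruby arrays. Below is a minimal before/after sketch of the new style, not part of the commit; the `pyfrom` import and the toy data are assumptions for illustration, since the real example sets up its imports above the hunk shown here.

require 'pycall/import'
include PyCall::Import

# Assumed import for this sketch; the real example imports its classifiers
# above line 34, outside the hunk shown in this commit.
pyfrom 'sklearn.svm', import: :SVC

# Before this commit: every Python call went through Ruby's `.()` syntax.
#   clf = SVC.(kernel: 'linear', C: 0.025)
#   clf.fit.(x_train, y_train)

# After this commit: Python classes are instantiated with `.new` and Python
# callables are invoked like ordinary Ruby methods.
clf = SVC.new(kernel: 'linear', C: 0.025)

# Toy data (not from the example) so the sketch runs end to end; PyCall
# converts the nested Ruby arrays to Python lists for scikit-learn.
x_train = [[0.0, 0.0], [0.0, 2.0], [1.0, 1.0], [2.0, 0.0]]
y_train = [0, 0, 1, 1]

clf.fit(x_train, y_train)
puts clf.score(x_train, y_train)

The same mechanical rewrite covers the NumPy and matplotlib calls in the diff (`np.random.seed(42)`, `ax.contourf(...)`, `plt.show()`); only the indexing helpers keep an explicit PyCall spelling, e.g. `x[PyCall::Slice.all, 0]` for NumPy's `x[:, 0]`.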
