Skip to content
This repository was archived by the owner on Dec 6, 2023. It is now read-only.

Commit 8f11fc7

Browse files
authored
Merge pull request #98 from kmike/predict_proba_exception
raise AttributeError if predict_proba is not available
2 parents 7829f92 + daa0569 commit 8f11fc7

File tree

9 files changed

+58
-8
lines changed

9 files changed

+58
-8
lines changed

lightning/impl/base.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,18 @@ def n_nonzero(self, percentage=False):
3838

3939
class BaseClassifier(BaseEstimator, ClassifierMixin):
4040

41-
def predict_proba(self, X):
41+
@property
42+
def predict_proba(self):
43+
if self.loss not in ("log", "modified_huber"):
44+
raise AttributeError("predict_proba only supported when"
45+
" loss='log' or loss='modified_huber' "
46+
"(%s given)" % self.loss)
47+
return self._predict_proba
48+
49+
def _predict_proba(self, X):
4250
if len(self.classes_) != 2:
43-
raise NotImplementedError("predict_(log_)proba only supported"
51+
raise NotImplementedError("predict_proba only supported"
4452
" for binary classification")
45-
4653
if self.loss == "log":
4754
df = self.decision_function(X).ravel()
4855
prob = 1.0 / (1.0 + np.exp(-df))
@@ -52,7 +59,7 @@ def predict_proba(self, X):
5259
prob += 1
5360
prob /= 2
5461
else:
55-
raise NotImplementedError("predict_(log_)proba only supported when"
62+
raise NotImplementedError("predict_proba only supported when"
5663
" loss='log' or loss='modified_huber' "
5764
"(%s given)" % self.loss)
5865

lightning/impl/tests/test_adagrad.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from lightning.classification import AdaGradClassifier
88
from lightning.regression import AdaGradRegressor
99
from lightning.impl.adagrad_fast import _proj_elastic_all
10+
from lightning.impl.tests.utils import check_predict_proba
1011

1112
iris = load_iris()
1213
X, y = iris.data, iris.target
@@ -18,13 +19,15 @@
1819
def test_adagrad_elastic_hinge():
1920
clf = AdaGradClassifier(alpha=0.5, l1_ratio=0.85, n_iter=10, random_state=0)
2021
clf.fit(X_bin, y_bin)
22+
assert not hasattr(clf, "predict_proba")
2123
assert_equal(clf.score(X_bin, y_bin), 1.0)
2224

2325

2426
def test_adagrad_elastic_smooth_hinge():
2527
clf = AdaGradClassifier(alpha=0.5, l1_ratio=0.85, loss="smooth_hinge",
2628
n_iter=10, random_state=0)
2729
clf.fit(X_bin, y_bin)
30+
assert not hasattr(clf, "predict_proba")
2831
assert_equal(clf.score(X_bin, y_bin), 1.0)
2932

3033

@@ -33,11 +36,13 @@ def test_adagrad_elastic_log():
3336
random_state=0)
3437
clf.fit(X_bin, y_bin)
3538
assert_equal(clf.score(X_bin, y_bin), 1.0)
39+
check_predict_proba(clf, X_bin)
3640

3741

3842
def test_adagrad_hinge_multiclass():
3943
clf = AdaGradClassifier(alpha=1e-2, n_iter=100, loss="hinge", random_state=0)
4044
clf.fit(X, y)
45+
assert not hasattr(clf, "predict_proba")
4146
assert_almost_equal(clf.score(X, y), 0.960, 3)
4247

4348

lightning/impl/tests/test_dual_cd.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from sklearn.externals.six.moves import xrange
77

88
from sklearn.utils.testing import assert_equal
9-
from sklearn.utils.testing import assert_almost_equal
109
from sklearn.utils.testing import assert_greater
1110
from sklearn.utils.testing import assert_array_almost_equal
1211

lightning/impl/tests/test_primal_cd.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
from lightning.impl.datasets.samples_generator import make_classification
1717
from lightning.impl.primal_cd import CDClassifier, CDRegressor
18+
from lightning.impl.tests.utils import check_predict_proba
19+
1820

1921
bin_dense, bin_target = make_classification(n_samples=200, n_features=100,
2022
n_informative=5,
@@ -31,6 +33,7 @@
3133
def test_fit_linear_binary_l1r():
3234
clf = CDClassifier(C=1.0, random_state=0, penalty="l1")
3335
clf.fit(bin_dense, bin_target)
36+
assert not hasattr(clf, 'predict_proba')
3437
acc = clf.score(bin_dense, bin_target)
3538
assert_almost_equal(acc, 1.0)
3639
n_nz = clf.n_nonzero()
@@ -51,6 +54,7 @@ def test_fit_linear_binary_l1r():
5154
def test_fit_linear_binary_l1r_smooth_hinge():
5255
clf = CDClassifier(C=1.0, loss="smooth_hinge", random_state=0, penalty="l1")
5356
clf.fit(bin_dense, bin_target)
57+
assert not hasattr(clf, 'predict_proba')
5458
acc = clf.score(bin_dense, bin_target)
5559
assert_almost_equal(acc, 1.0)
5660

@@ -102,6 +106,7 @@ def test_warm_start_l1r_regression():
102106
def test_fit_linear_binary_l1r_log_loss():
103107
clf = CDClassifier(C=1.0, random_state=0, penalty="l1", loss="log")
104108
clf.fit(bin_dense, bin_target)
109+
check_predict_proba(clf, bin_dense)
105110
acc = clf.score(bin_dense, bin_target)
106111
assert_almost_equal(acc, 0.995)
107112

@@ -133,6 +138,7 @@ def test_fit_linear_binary_l2r_modified_huber():
133138
clf = CDClassifier(C=1.0, random_state=0, penalty="l2",
134139
loss="modified_huber")
135140
clf.fit(bin_dense, bin_target)
141+
check_predict_proba(clf, bin_dense)
136142
acc = clf.score(bin_dense, bin_target)
137143
assert_almost_equal(acc, 1.0)
138144

lightning/impl/tests/test_sag.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from lightning.impl.sgd_fast import Log
2020
from lightning.impl.sgd_fast import SquaredLoss
2121
from lightning.impl.sag import get_auto_step_size
22+
from lightning.impl.tests.utils import check_predict_proba
2223

2324

2425
iris = load_iris()
@@ -211,6 +212,7 @@ def test_sag():
211212
PySAGClassifier(eta=1e-3, max_iter=20, random_state=0)
212213
):
213214
clf.fit(X_bin, y_bin)
215+
assert not hasattr(clf, 'predict_proba')
214216
assert_equal(clf.score(X_bin, y_bin), 1.0)
215217
assert_equal(list(clf.classes_), [-1, 1])
216218

@@ -244,8 +246,7 @@ def test_sag_proba():
244246
sag = SAGClassifier(eta=1e-3, alpha=0.0, beta=0.0, max_iter=10,
245247
loss='log', random_state=0)
246248
sag.fit(X, y)
247-
probas = sag.predict_proba(X)
248-
assert_equal(probas.sum(), n_samples)
249+
check_predict_proba(sag, X)
249250

250251

251252
def test_sag_multiclass_classes():

lightning/impl/tests/test_sdca.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
from lightning.classification import SDCAClassifier
88
from lightning.regression import SDCARegressor
9+
from lightning.impl.tests.utils import check_predict_proba
10+
911

1012
iris = load_iris()
1113
X, y = iris.data, iris.target
@@ -17,6 +19,7 @@
1719
def test_sdca_hinge():
1820
clf = SDCAClassifier(loss="hinge", random_state=0)
1921
clf.fit(X_bin, y_bin)
22+
assert not hasattr(clf, 'predict_proba')
2023
assert_equal(clf.score(X_bin, y_bin), 1.0)
2124

2225

@@ -30,12 +33,14 @@ def test_sdca_hinge_multiclass():
3033
def test_sdca_squared():
3134
clf = SDCAClassifier(loss="squared", random_state=0)
3235
clf.fit(X_bin, y_bin)
36+
assert not hasattr(clf, 'predict_proba')
3337
assert_equal(clf.score(X_bin, y_bin), 1.0)
3438

3539

3640
def test_sdca_absolute():
3741
clf = SDCAClassifier(loss="absolute", random_state=0)
3842
clf.fit(X_bin, y_bin)
43+
assert not hasattr(clf, 'predict_proba')
3944
assert_equal(clf.score(X_bin, y_bin), 1.0)
4045

4146

@@ -50,6 +55,7 @@ def test_sdca_smooth_hinge_elastic():
5055
clf = SDCAClassifier(alpha=0.5, l1_ratio=0.85, loss="smooth_hinge",
5156
random_state=0)
5257
clf.fit(X_bin, y_bin)
58+
assert not hasattr(clf, 'predict_proba')
5359
assert_equal(clf.score(X_bin, y_bin), 1.0)
5460

5561

lightning/impl/tests/test_sgd.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from lightning.impl.datasets.samples_generator import make_nn_regression
1313
from lightning.impl.sgd import SGDClassifier
1414
from lightning.impl.sgd import SGDRegressor
15+
from lightning.impl.tests.utils import check_predict_proba
1516

1617

1718
bin_dense, bin_target = make_classification(n_samples=200, n_features=100,
@@ -25,6 +26,7 @@
2526
bin_csr = sp.csr_matrix(bin_dense)
2627
mult_csr = sp.csr_matrix(mult_dense)
2728

29+
2830
def test_binary_linear_sgd():
2931
for data in (bin_dense, bin_csr):
3032
for clf in (SGDClassifier(random_state=0, loss="hinge",
@@ -43,10 +45,13 @@ def test_binary_linear_sgd():
4345
SGDClassifier(random_state=0, loss="modified_huber",
4446
fit_intercept=True, learning_rate="constant"),
4547
):
46-
4748
clf.fit(data, bin_target)
4849
assert_greater(clf.score(data, bin_target), 0.934)
4950
assert_equal(list(clf.classes_), [0, 1])
51+
if clf.loss in ('log', 'modified_huber'):
52+
check_predict_proba(clf, data)
53+
else:
54+
assert not hasattr(clf, 'predict_proba')
5055

5156

5257
def test_multiclass_sgd():

lightning/impl/tests/test_svrg.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
def test_svrg():
2020
clf = SVRGClassifier(eta=1e-3, max_iter=20, random_state=0, verbose=0)
2121
clf.fit(X_bin, y_bin)
22+
assert not hasattr(clf, 'predict_proba')
2223
assert_equal(clf.score(X_bin, y_bin), 1.0)
2324

2425

lightning/impl/tests/utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# -*- coding: utf-8 -*-
2+
from sklearn.utils.testing import assert_array_equal, assert_equal
3+
4+
5+
def check_predict_proba(clf, X):
6+
y_pred = clf.predict(X)
7+
n_samples = y_pred.shape[0]
8+
# normalize negative class to 0 (it is sometimes 0, sometimes -1)
9+
y_pred = (y_pred == 1)
10+
11+
# check that predict_proba result agree with y_true
12+
y_proba = clf.predict_proba(X)
13+
assert_equal(y_proba.shape, (n_samples, 2))
14+
y_proba_best = (y_proba.argmax(axis=1) == 1)
15+
assert_array_equal(y_proba_best, y_pred)
16+
17+
# check that y_proba looks like probability
18+
assert not (y_proba > 1).any()
19+
assert not (y_proba < 0).any()
20+
assert_array_equal(y_proba.sum(axis=1), 1.0)

0 commit comments

Comments
 (0)