Skip to content

Commit 89e6b39

Browse files
Change loss names to match new scikit-learn losses and replace the deprecated if_delegate_has_method with available_if
1 parent a857f67 commit 89e6b39

File tree

2 files changed

+20
-13
lines changed

2 files changed

+20
-13
lines changed

sklearn_extra/robust/robust_weighted_estimator.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from sklearn.cluster import MiniBatchKMeans
2727
from sklearn.metrics.pairwise import euclidean_distances
2828
from sklearn.exceptions import ConvergenceWarning
29-
from sklearn.utils.metaestimators import if_delegate_has_method
29+
from sklearn.utils.metaestimators import available_if
3030

3131
# Tool library in which we get robust mean estimators.
3232
from .mean_estimators import median_of_means_blocked, block_mom, huber
@@ -48,7 +48,7 @@
4848

4949
LOSS_FUNCTIONS = {
5050
"hinge": (Hinge,),
51-
"log": (Log,),
51+
"log_loss": (Log,),
5252
"squared_error": (SquaredLoss,),
5353
"squared_loss": (SquaredLoss,),
5454
"squared_hinge": (SquaredHinge,),
@@ -114,8 +114,8 @@ class _RobustWeightedEstimator(BaseEstimator):
114114
loss : string or callable, mandatory
115115
Name of the loss used, must be the same loss as the one optimized in
116116
base_estimator.
117-
Classification losses supported : 'log', 'hinge', 'squared_hinge',
118-
'modified_huber'. If 'log', then the base_estimator must support
117+
Classification losses supported : 'log_loss', 'hinge', 'squared_hinge',
118+
'modified_huber'. If 'log_loss', then the base_estimator must support
119119
predict_proba. Regression losses supported : 'squared_error', 'huber'.
120120
If callable, the function is used as loss function to construct
121121
the weights.
@@ -501,7 +501,7 @@ def predict(self, X):
501501
return self.base_estimator_.predict(X)
502502

503503
def _check_proba(self):
504-
if self.loss != "log":
504+
if self.loss != "log_loss":
505505
raise AttributeError(
506506
"Probability estimates are not available for"
507507
" loss=%r" % self.loss
@@ -538,7 +538,14 @@ def score(self, X, y=None):
538538
check_is_fitted(self, attributes=["base_estimator_"])
539539
return self.base_estimator_.score(X, y)
540540

541-
@if_delegate_has_method(delegate="base_estimator")
541+
542+
def _estimator_has(attr):
543+
def check(self):
544+
return hasattr(self.base_estimator_, attr)
545+
546+
return check
547+
548+
@available_if(_estimator_has("decision_function"))
542549
def decision_function(self, X):
543550
"""Predict using the linear model. For classifiers only.
544551
@@ -607,7 +614,7 @@ class RobustWeightedClassifier(BaseEstimator, ClassifierMixin):
607614
(using the inter-quartile range), this tends to be conservative
608615
(robust).
609616
610-
loss : string, None or callable, default="log"
617+
loss : string, None or callable, default="log_loss"
611618
Classification losses supported : 'log', 'hinge', 'modified_huber'.
612619
If 'log', then the base_estimator must support predict_proba.
613620
@@ -709,7 +716,7 @@ def __init__(
709716
max_iter=100,
710717
c=None,
711718
k=0,
712-
loss="log",
719+
loss="log_loss",
713720
sgd_args=None,
714721
multi_class="ovr",
715722
n_jobs=1,
@@ -809,7 +816,7 @@ def predict(self, X):
809816
return self.base_estimator_.predict(X)
810817

811818
def _check_proba(self):
812-
if self.loss != "log":
819+
if self.loss != "log_loss":
813820
raise AttributeError(
814821
"Probability estimates are not available for"
815822
" loss=%r" % self.loss

sklearn_extra/robust/tests/test_robust_weighted_estimator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
X_cc[f] = [10, 5] + rng.normal(size=2) * 0.1
3939
y_cc[f] = 0
4040

41-
classif_losses = ["log", "hinge"]
41+
classif_losses = ["log_loss", "hinge"]
4242
weightings = ["huber", "mom"]
4343
multi_class = ["ovr", "ovo"]
4444

@@ -167,7 +167,7 @@ def test_classif_binary(weighting):
167167
multi_class="binary",
168168
random_state=rng,
169169
)
170-
clf_not_rob = SGDClassifier(loss="log", random_state=rng)
170+
clf_not_rob = SGDClassifier(loss="log_loss", random_state=rng)
171171
clf.fit(X_cb, y_cb)
172172
clf_not_rob.fit(X_cb, y_cb)
173173
norm_coef1 = np.linalg.norm(np.hstack([clf.coef_.ravel(), clf.intercept_]))
@@ -201,7 +201,7 @@ def test_classif_corrupted_weights(weighting):
201201
assert np.mean(clf.weights_[:3]) < np.mean(clf.weights_[3:])
202202

203203

204-
# Case "log" loss, test predict_proba
204+
# Case "log_loss" loss, test predict_proba
205205
@pytest.mark.parametrize("weighting", weightings)
206206
def test_predict_proba(weighting):
207207
clf = RobustWeightedClassifier(
@@ -211,7 +211,7 @@ def test_predict_proba(weighting):
211211
c=1e7,
212212
random_state=rng,
213213
)
214-
clf_not_rob = SGDClassifier(loss="log", random_state=rng)
214+
clf_not_rob = SGDClassifier(loss="log_loss", random_state=rng)
215215
clf.fit(X_c, y_c)
216216
clf_not_rob.fit(X_c, y_c)
217217
pred1 = clf.base_estimator_.predict_proba(X_c)[:, 1]

0 commit comments

Comments
 (0)