Skip to content

Commit c903687

Browse files
ENH RobustWeightedEstimator : code + examples (#42)
1 parent 10103f8 commit c903687

17 files changed

+1566
-3
lines changed
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
"""
2+
==================================================================
3+
Plot of accuracy and time as sample_size and num_features increase
4+
==================================================================
5+
We show that the increase in computation time is linear when
6+
increasing the number of features or the sample size increases.
7+
"""
8+
9+
import matplotlib.pyplot as plt
10+
import numpy as np
11+
from time import time
12+
13+
from sklearn_extra.robust import RobustWeightedEstimator
14+
from sklearn.linear_model import SGDClassifier
15+
from sklearn.datasets import make_classification
16+
from sklearn.model_selection import cross_val_score
17+
18+
rng = np.random.RandomState(42)
19+
20+
dimensions = np.linspace(50, 5000, num=8).astype(int)
21+
sample_sizes = np.linspace(50, 5000, num=8).astype(int)
22+
accuracies = []
23+
times = []
24+
25+
# Get the accuracy and time of computations for a dataset with varying number
26+
# of features
27+
28+
for d in dimensions:
29+
# Make an example in dimension d. Use a scale factor for the problem to be
30+
# easy even in high dimension.
31+
X, y = make_classification(
32+
n_samples=200, n_features=d, scale=1 / np.sqrt(2 * d), random_state=rng
33+
)
34+
stime = time()
35+
clf = RobustWeightedEstimator(
36+
SGDClassifier(loss="hinge", penalty="l1"),
37+
loss="hinge",
38+
random_state=rng,
39+
)
40+
accuracies.append(np.mean(cross_val_score(clf, X, y, cv=10)))
41+
times.append(time() - stime)
42+
43+
fig, axs = plt.subplots(2, 2)
44+
axs[0, 0].plot(dimensions, accuracies)
45+
axs[0, 0].set_xlabel("Number of features")
46+
axs[0, 0].set_ylabel("accuracy")
47+
axs[0, 1].plot(dimensions, times)
48+
axs[0, 1].set_xlabel("Number of features")
49+
axs[0, 1].set_ylabel("Time to fit and predict (s)")
50+
51+
accuracies = []
52+
times = []
53+
54+
# Get the accuracy and time of computations for a dataset with varying number
55+
# of samples
56+
57+
for n in sample_sizes:
58+
X, y = make_classification(n_samples=n, n_features=5, random_state=rng)
59+
stime = time()
60+
clf = RobustWeightedEstimator(
61+
SGDClassifier(loss="hinge", penalty="l1"),
62+
loss="hinge",
63+
random_state=rng,
64+
)
65+
accuracies.append(np.mean(cross_val_score(clf, X, y, cv=10)))
66+
times.append(time() - stime)
67+
68+
axs[1, 0].plot(dimensions, accuracies)
69+
axs[1, 0].set_xlabel("Number of features")
70+
axs[1, 0].set_ylabel("accuracy")
71+
axs[1, 1].plot(dimensions, times)
72+
axs[1, 1].set_xlabel("Number of features")
73+
axs[1, 1].set_ylabel("Time to fit and predict (s)")
74+
75+
76+
plt.show()

doc/api.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,13 @@ Clustering
3131
:template: class.rst
3232

3333
cluster.KMedoids
34+
35+
Robust
36+
====================
37+
38+
.. autosummary::
39+
:toctree: generated/
40+
:template: class.rst
41+
42+
robust.RobustWeightedEstimator
3443

doc/images/robust_def_outliers.png

13.1 KB
Loading

doc/images/robust_plot_regression.png

36.3 KB
Loading

doc/modules/robust.rst

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
.. _robust:
2+
3+
===================================================
4+
Robust algorithms for Regression and Classification
5+
===================================================
6+
7+
.. currentmodule:: sklearn_extra.robust
8+
9+
Robust statistics are mostly about how to deal with data corrupted with
10+
outliers (i.e. abnormal data, unique data in some sense).
11+
The aim is to modify classical methods in order to deal with outliers while
12+
loosing as little as possible in efficiency compared to classical (non-robust)
13+
methods applied to non-corrupted datasets.
14+
In particular, in machine learning, we want to bound the
15+
influence that any minority of the dataset can have on the prediction, see
16+
the figure for an example in regression.
17+
18+
.. |robust_regression| image:: ../images/robust_plot_regression.png
19+
:target: ../examples/plot_robust_regression_toy.py
20+
:scale: 70
21+
22+
.. centered:: |robust_regression|
23+
24+
What is an outlier ?
25+
====================
26+
27+
The term "outlier" refers to a discordant minority of the dataset. It is
28+
generally assumed to be a set of points situated outside the bulk of the data
29+
but there exists more complex cases as illustrated in the figure below.
30+
31+
Formally, we define outliers for a given task by considering points for
32+
which the loss function takes unusually high values.
33+
In the case of classification, one can consider that in the following scatter
34+
plot the points in the up-right corner are outliers while the points in the
35+
bottom-left corner are not.
36+
37+
.. |outlier| image:: ../images/robust_def_outliers.png
38+
:scale: 80
39+
40+
.. centered:: |outlier|
41+
42+
Outliers can be caused by a lot of things, among them are human errors, captor
43+
errors or inherent causes. These are often found for example in biology,
44+
econometrics or datasets that describe some human relationships.
45+
46+
Here, we limit ourselves to linear estimators, but non-linear estimators are
47+
also plagued with the same non-robustness properties. See scikit-learn RANSAC
48+
documentation (`scikit-learn <https://scikit-learn.org/stable/modules/linear_model.html#ransac-random-sample-consensus>`__)
49+
for an example of outliers for non-linear estimators.
50+
51+
Robust estimation with robust weighting
52+
=======================================
53+
54+
A lot of learning algorithms are based on a paradigm known as empirical risk
55+
minimization (ERM) which consists in finding the estimator :math:`\widehat{f}`
56+
that minimizes an estimation of the risk.
57+
58+
.. math::
59+
60+
\widehat{f} = \text{argmin}_{f\in F}\frac{1}{n}\sum_{i=1}^n\ell(f(X_i),y_i),
61+
62+
where the :math:`ell` is a loss function (e.g. the squared distance in
63+
regression problems). Said in another way, we are trying to minimize an
64+
estimation of the expected risk and this estimation corresponds to an empirical
65+
mean. However, it is well known that the empirical mean is not robust to
66+
extreme data and these extreme values will have a big influence on the
67+
estimation of :math:`\widehat{f}`. The principle behind the robust weighting
68+
algorithm is to rely on a robust estimator (such as median-of-means (MOM) or
69+
Huber estimator) in place of the empirical mean in the equation above [1]_.
70+
71+
In practice, one can define weights :math:`w_i` that depends on the
72+
:math:`i^{th}` sample, with the weight :math:`w_i` being very small when
73+
the :math:`i^{th}` data is an outlier and large otherwise.
74+
This way, the problem is reduced to the following optimization :
75+
76+
.. math::
77+
78+
\min_{f}\, \frac{1}{n} \sum_{i=1}^n w_i\ell(f(X_i),y_i)
79+
80+
Remark that the weights :math:`w_i` depends on :math:`\widehat{f}`, and the
81+
resulting algorithm is then an alternate optimization scheme, iteratively
82+
doing one step to optimize with respect to :math:`f` while the weights stay
83+
fixed and then one step to estimate the weights while :math:`f` stays fixed.
84+
These two steps are then repeated until convergence.
85+
86+
Robust estimation in practice
87+
=============================
88+
89+
The algorithm
90+
-------------
91+
92+
The approach is implemented as a meta algorithm that takes as input a base
93+
estimator (e.g., SGDClassifier or SGDRegressor). To be compatible, the
94+
base estimator must support partial_fit and sample_weight
95+
partial_fit and sample_weight. Refer to the KMeans example for a template
96+
to adapt the method to other estimators.
97+
98+
At each step, the algorithm estimates sample weights that are meant to be small
99+
for outliers and large for inliers and then we do one optimization step using
100+
the base_estimator optimization algorithm.
101+
102+
There are two weighting scheme supported in this algorithm: Huber-like weights
103+
and median-of-means weights. These two types of weights both come with a
104+
parameter that will determine the robustness/efficiency trade-off of the
105+
estimation.
106+
107+
* Huber weights : the parameter "c" is a positive real number. For small
108+
values of c the estimator is more robust but less efficient than it is
109+
for large values of c.
110+
A good heuristic consists in choosing c as an estimate of the standard
111+
deviation of the losses of the inliers. In practice, if c=None, it is
112+
estimated with the inter-quartile range.
113+
114+
* Median-of-means weights : the parameter "k" is a non-negative integer,
115+
when k=0 the estimator is exactly the same as base_estimator and when
116+
k=sample_size/2 the estimator is very robust but less efficient on inliers.
117+
A good heuristic consists in choosing k as an estimate of the number of
118+
outliers. In practice, if k=None, it is estimated using the number of
119+
points distant from the median of more than a 1.45 times the inter-quartile
120+
range.
121+
122+
.. table:: Robustness/Efficiency tradeoff and choice of parameters
123+
:widths: auto
124+
:align: center
125+
126+
+-----------+----------------------+-----------------+-------------------+
127+
| weighting | Robustness parameter | Small parameter | Large parameter |
128+
+===========+======================+=================+===================+
129+
| mom | k | Non robust | Robust |
130+
+-----------+----------------------+-----------------+-------------------+
131+
| huber | c | Robust | Non robust |
132+
+-----------+----------------------+-----------------+-------------------+
133+
134+
135+
The choice of the optimization parameters max_iter and eta_0 are also very
136+
important for the efficiency of this estimator. It is recommended to use
137+
`cross-validation <https://scikit-learn.org/stable/modules/cross_validation.html>`__
138+
to fix these hyper-parameters. Choosing eta0 too large
139+
can have the effect of making the estimator non-robust. One should also take
140+
care that it can be important to rescale the data (the same way as it is
141+
important to do it for SGD). In the context of a corrupted dataset, please use
142+
`RobustScaler <https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.RobustScaler.html>`__.
143+
144+
This algorithm has been studied in the context of "mom" weights in the
145+
article [1]_, the context of "huber" weights has been mentioned in [2]_.
146+
Both weighting scheme can be seen as special cases of the algorithm in [3]_.
147+
148+
Comparison with other robust estimators
149+
---------------------------------------
150+
151+
There are already some robust algorithms in scikit-learn but one major
152+
difference is that robust algorithms in scikit-learn are primarily meant for
153+
Regression, see `robustness in regression <https://scikit-learn.org/stable/modules/linear_model.html#robustness-regression-outliers-and-modeling-errors>`__.
154+
Hence, we will not talk about classification algorithms in this comparison.
155+
156+
As such we only compare ourselves to TheilSenRegressor and RANSACRegressor as
157+
they both deal with outliers in X and in Y and are closer to
158+
RobustWeightedEstimator.
159+
160+
**Warning:** Huber weights used in our algorithm should not be confused with
161+
HuberRegressor or other regression with “robust losses”. Those types of
162+
regressions are robust only to outliers in the label Y but not in X.
163+
164+
Pro: RANSACRegressor and TheilSenRegressor both use a hard rejection of
165+
outlier. This can be interpreted as though there was an outlier detection
166+
step and then a regression step whereas RobustWeightedEstimator is directly
167+
robust to outliers. This often increase the performance on moderatly corrupted
168+
datasets.
169+
170+
Con: In general, this algorithm is slower than both TheilSenRegressor and
171+
RANSACRegressor.
172+
173+
One other advantage of RobustWeightedEstimator is that it can be used for a
174+
broad range of algorithms. For example, one can do robust unsupervised
175+
learning with RobustWeightedEstimator, see the example using KMeans algorithm.
176+
177+
Speed and limits of the algorithm
178+
---------------------------------
179+
180+
Most of the time, it is interesting to do robust statistics only when there
181+
are outliers and notice that a lot of dataset have previously been "cleaned"
182+
of an outliers in which case this algorithm is not better than base_estimator.
183+
184+
In high dimension, the algorithm is expected to be as good
185+
(or as bad) as base_estimator do in high dimension.
186+
187+
Complexity and limitation:
188+
189+
* weighting=”huber”: the complexity is larger than that of base_estimator but
190+
it is still of the same order of magnitude.
191+
* weighting=”mom”: the larger k is the faster the algorithm will perform if
192+
sample_size is large. This weighting scheme is advised only with
193+
sufficiently large dataset (thumb rule sample_size > 500 the specifics
194+
depend on the dataset).
195+
196+
**Warning:** On a real dataset, one should be aware that there can be outliers
197+
in the training set but also in the test set when the loss is not bounded. See
198+
the example with California housing real dataset, for further discussion.
199+
200+
.. topic:: References:
201+
202+
.. [1] Guillaume Lecué, Matthieu Lerasle and Timothée Mathieu.
203+
`"Robust classification via MOM minimization" <https://doi.org/10.1007/s10994-019-05863-6>`_, Machine Learning Journal (2020).
204+
205+
206+
.. [2] Christian Brownlees, Emilien Joly and Gábor Lugosi.
207+
`"Empirical risk minimization for heavy-tailed losses" <https://projecteuclid.org/euclid.aos/1444222083>`_, Ann. Statist.
208+
Volume 43, Number 6 (2015), 2507-2536.
209+
210+
.. [3] Stanislav Minsker and Timothée Mathieu.
211+
`"Excess risk bounds in robust empirical risk minimization" <https://arxiv.org/abs/1910.07485>`_
212+
arXiv preprint (2019). arXiv:1910.07485.

doc/user_guide.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ User guide
1111
:numbered:
1212

1313
modules/eigenpro.rst
14+
modules/robust.rst
1415

1516
.. _k_medoids:
1617

@@ -63,8 +64,7 @@ This version works as follows:
6364
maximum number of iterations ``max_iter`` is reached.
6465

6566
.. topic:: References:
66-
6767
* Maranzana, F.E., 1963. On the location of supply points to minimize
6868
transportation costs. IBM Systems Journal, 2(2), pp.129-135.
6969
* Park, H.S. and Jun, C.H., 2009. A simple and fast algorithm for K-medoids
70-
clustering. Expert systems with applications, 36(2), pp.3336-3341.
70+
clustering. Expert systems with applications, 36(2), pp.3336-3341.
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
======================================================================
4+
A demo of Robust Classification on real dataset "diabetes" from OpenML
5+
======================================================================
6+
In this example we compare the RobustWeightedEstimator using SGDClassifier
7+
for classification on the real dataset "diabetes".
8+
WARNING: running this example can take some time (<1hour).
9+
We only compare the estimator with SGDClassifier as there is no robust
10+
classification estimator in scikit-learn.
11+
"""
12+
import matplotlib.pyplot as plt
13+
import numpy as np
14+
from sklearn_extra.robust import RobustWeightedEstimator
15+
from sklearn.linear_model import SGDClassifier
16+
from sklearn.datasets import fetch_openml
17+
from sklearn.metrics import roc_auc_score, make_scorer
18+
from sklearn.model_selection import cross_val_score
19+
from sklearn.preprocessing import RobustScaler
20+
21+
22+
X, y = fetch_openml(name="diabetes", return_X_y=True)
23+
24+
# replace the label names with 0 or 1
25+
y = (y == "tested_positive").astype(int)
26+
27+
# Scale the dataset with sklearn RobustScaler (important for this algorithm)
28+
X = RobustScaler().fit_transform(X)
29+
30+
# Using GridSearchCV, to tune the parameters alpha, eta0, learning_rate, loss
31+
# and average of SGDClassifier, we get the following parameters.
32+
33+
clf_not_rob = SGDClassifier(average=10, learning_rate="optimal", loss="hinge")
34+
35+
# Then, we use this estimator as base_estimator of RobustWeightedEstimator.
36+
# Using GridSearchCV, we tuned the parameters c and eta0, with the
37+
# choice of "huber" weighting because the sample_size is not very large.
38+
39+
clf_rob = RobustWeightedEstimator(
40+
SGDClassifier(average=10, learning_rate="optimal", loss="hinge"),
41+
weighting="huber",
42+
loss="hinge",
43+
c=1.35,
44+
eta0=1e-3,
45+
max_iter=300,
46+
)
47+
48+
# We compute M times the cross validations in order to also have an estimate
49+
# of the variance of the loss of the estimators.
50+
M = 10
51+
res = []
52+
for f in range(M):
53+
print("\r Progress: %s / %s" % (f + 1, M), end="")
54+
clf = SGDClassifier(average=10, learning_rate="optimal", loss="hinge")
55+
56+
cv_not_rob = cross_val_score(
57+
clf_not_rob, X, y, cv=10, scoring=make_scorer(roc_auc_score)
58+
)
59+
60+
cv_rob = cross_val_score(
61+
clf_rob, X, y, cv=10, scoring=make_scorer(roc_auc_score)
62+
)
63+
64+
res += [[np.mean(cv_rob), np.mean(cv_not_rob)]]
65+
66+
67+
plt.boxplot(np.array(res), labels=["RobustWeightedEstimator", "SGDClassifier"])
68+
plt.ylabel("AUC")
69+
70+
plt.show()
71+
72+
# Remark : when using accuracy score, the optimal hyperparameters change and
73+
# for example the parameter c changes from 1.35 to 10.

0 commit comments

Comments
 (0)