Skip to content

Commit 9523006

Browse files
snath-xoc, ogrisel, jeremiedbb
authored
Fix LinearSVC handling of sample weights under class_weight="balanced" (scikit-learn#30057)
Co-authored-by: Olivier Grisel <[email protected]> Co-authored-by: Jérémie du Boisberranger <[email protected]>
1 parent c565029 commit 9523006

File tree

8 files changed

+84
-15
lines changed

8 files changed

+84
-15
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
- :class:`linear_model.LogisticRegression` and
2+
:class:`linear_model.LogisticRegressionCV` now properly pass sample weights to
3+
:func:`utils.class_weight.compute_class_weight` when fit with
4+
`class_weight="balanced"`.
5+
By :user:`Shruti Nath <snath-xoc>` and :user:`Olivier Grisel <ogrisel>`
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
- :class:`svm.LinearSVC` now properly passes sample weights to
2+
:func:`utils.class_weight.compute_class_weight` when fit with
3+
`class_weight="balanced"`.
4+
By :user:`Shruti Nath <snath-xoc>`
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
- :func:`utils.class_weight.compute_class_weight` now properly accounts for
2+
sample weights when using strategy "balanced" to calculate class weights.
3+
By :user:`Shruti Nath <snath-xoc>`

sklearn/linear_model/_logistic.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,9 @@ def _logistic_regression_path(
305305
if isinstance(class_weight, dict) or (
306306
multi_class == "multinomial" and class_weight is not None
307307
):
308-
class_weight_ = compute_class_weight(class_weight, classes=classes, y=y)
308+
class_weight_ = compute_class_weight(
309+
class_weight, classes=classes, y=y, sample_weight=sample_weight
310+
)
309311
sample_weight *= class_weight_[le.fit_transform(y)]
310312

311313
# For doing a ovr, we need to mask the labels first. For the
@@ -326,7 +328,10 @@ def _logistic_regression_path(
326328
# for compute_class_weight
327329
if class_weight == "balanced":
328330
class_weight_ = compute_class_weight(
329-
class_weight, classes=mask_classes, y=y_bin
331+
class_weight,
332+
classes=mask_classes,
333+
y=y_bin,
334+
sample_weight=sample_weight,
330335
)
331336
sample_weight *= class_weight_[le.fit_transform(y_bin)]
332337

@@ -1981,7 +1986,10 @@ def fit(self, X, y, sample_weight=None, **params):
19811986
# compute the class weights for the entire dataset y
19821987
if class_weight == "balanced":
19831988
class_weight = compute_class_weight(
1984-
class_weight, classes=np.arange(len(self.classes_)), y=y
1989+
class_weight,
1990+
classes=np.arange(len(self.classes_)),
1991+
y=y,
1992+
sample_weight=sample_weight,
19851993
)
19861994
class_weight = dict(enumerate(class_weight))
19871995

sklearn/svm/_base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1189,8 +1189,9 @@ def _fit_liblinear(
11891189
" in the data, but the data contains only one"
11901190
" class: %r" % classes_[0]
11911191
)
1192-
1193-
class_weight_ = compute_class_weight(class_weight, classes=classes_, y=y)
1192+
class_weight_ = compute_class_weight(
1193+
class_weight, classes=classes_, y=y, sample_weight=sample_weight
1194+
)
11941195
else:
11951196
class_weight_ = np.empty(0, dtype=np.float64)
11961197
y_ind = y

sklearn/utils/_test_common/instance_generator.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,17 @@
600600
"check_dict_unchanged": dict(batch_size=10, max_iter=5, n_components=1)
601601
},
602602
LinearDiscriminantAnalysis: {"check_dict_unchanged": dict(n_components=1)},
603+
LinearSVC: {
604+
"check_sample_weight_equivalence": [
605+
# TODO: dual=True is a stochastic solver: we cannot rely on
606+
# check_sample_weight_equivalence to check the correct handling of
607+
# sample_weight and we would need a statistical test instead, see
608+
# meta-issue #16298.
609+
# dict(max_iter=20, dual=True, tol=1e-12),
610+
dict(dual=False, tol=1e-12),
611+
dict(dual=False, tol=1e-12, class_weight="balanced"),
612+
]
613+
},
603614
LinearRegression: {
604615
"check_estimator_sparse_tag": [dict(positive=False), dict(positive=True)],
605616
"check_sample_weight_equivalence_on_dense_data": [
@@ -615,6 +626,14 @@
615626
dict(solver="liblinear"),
616627
dict(solver="newton-cg"),
617628
dict(solver="newton-cholesky"),
629+
dict(solver="newton-cholesky", class_weight="balanced"),
630+
]
631+
},
632+
LogisticRegressionCV: {
633+
"check_sample_weight_equivalence": [
634+
dict(solver="lbfgs"),
635+
dict(solver="newton-cholesky"),
636+
dict(solver="newton-cholesky", class_weight="balanced"),
618637
],
619638
"check_sample_weight_equivalence_on_sparse_data": [
620639
dict(solver="liblinear"),

sklearn/utils/class_weight.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,24 +7,27 @@
77
from scipy import sparse
88

99
from ._param_validation import StrOptions, validate_params
10+
from .validation import _check_sample_weight
1011

1112

1213
@validate_params(
1314
{
1415
"class_weight": [dict, StrOptions({"balanced"}), None],
1516
"classes": [np.ndarray],
1617
"y": ["array-like"],
18+
"sample_weight": ["array-like", None],
1719
},
1820
prefer_skip_nested_validation=True,
1921
)
20-
def compute_class_weight(class_weight, *, classes, y):
22+
def compute_class_weight(class_weight, *, classes, y, sample_weight=None):
2123
"""Estimate class weights for unbalanced datasets.
2224
2325
Parameters
2426
----------
2527
class_weight : dict, "balanced" or None
2628
If "balanced", class weights will be given by
27-
`n_samples / (n_classes * np.bincount(y))`.
29+
`n_samples / (n_classes * np.bincount(y))` or their weighted equivalent if
30+
`sample_weight` is provided.
2831
If a dictionary is given, keys are classes and values are corresponding class
2932
weights.
3033
If `None` is given, the class weights will be uniform.
@@ -36,6 +39,10 @@ def compute_class_weight(class_weight, *, classes, y):
3639
y : array-like of shape (n_samples,)
3740
Array of original class labels per sample.
3841
42+
sample_weight : array-like of shape (n_samples,), default=None
43+
Array of weights that are assigned to individual samples. Only used when
44+
`class_weight='balanced'`.
45+
3946
Returns
4047
-------
4148
class_weight_vect : ndarray of shape (n_classes,)
@@ -69,7 +76,11 @@ def compute_class_weight(class_weight, *, classes, y):
6976
if not all(np.isin(classes, le.classes_)):
7077
raise ValueError("classes should have valid labels that are in y")
7178

72-
recip_freq = len(y) / (len(le.classes_) * np.bincount(y_ind).astype(np.float64))
79+
sample_weight = _check_sample_weight(sample_weight, y)
80+
weighted_class_counts = np.bincount(y_ind, weights=sample_weight)
81+
recip_freq = weighted_class_counts.sum() / (
82+
len(le.classes_) * weighted_class_counts
83+
)
7384
weight = recip_freq[le.transform(classes)]
7485
else:
7586
# user-defined dictionary

sklearn/utils/tests/test_class_weight.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -129,14 +129,32 @@ def test_compute_class_weight_balanced_negative():
129129
assert len(cw) == len(classes)
130130
assert_array_almost_equal(cw, np.array([1.0, 1.0, 1.0]))
131131

132-
# Test with unbalanced class labels.
133-
y = np.asarray([-1, 0, 0, -2, -2, -2])
134132

135-
cw = compute_class_weight("balanced", classes=classes, y=y)
136-
assert len(cw) == len(classes)
137-
class_counts = np.bincount(y + 2)
138-
assert_almost_equal(np.dot(cw, class_counts), y.shape[0])
139-
assert_array_almost_equal(cw, [2.0 / 3, 2.0, 1.0])
133+
def test_compute_class_weight_balanced_sample_weight_equivalence():
134+
# Test with unbalanced and negative class labels for
135+
# equivalence between repeated and weighted samples
136+
137+
classes = np.array([-2, -1, 0])
138+
y = np.asarray([-1, -1, 0, 0, -2, -2])
139+
sw = np.asarray([1, 0, 1, 1, 1, 2])
140+
141+
y_rep = np.repeat(y, sw, axis=0)
142+
143+
class_weights_weighted = compute_class_weight(
144+
"balanced", classes=classes, y=y, sample_weight=sw
145+
)
146+
class_weights_repeated = compute_class_weight("balanced", classes=classes, y=y_rep)
147+
assert len(class_weights_weighted) == len(classes)
148+
assert len(class_weights_repeated) == len(classes)
149+
150+
class_counts_weighted = np.bincount(y + 2, weights=sw)
151+
class_counts_repeated = np.bincount(y_rep + 2)
152+
153+
assert np.dot(class_weights_weighted, class_counts_weighted) == pytest.approx(
154+
np.dot(class_weights_repeated, class_counts_repeated)
155+
)
156+
157+
assert_allclose(class_weights_weighted, class_weights_repeated)
140158

141159

142160
def test_compute_class_weight_balanced_unordered():

0 commit comments

Comments
 (0)