Skip to content

Commit 754d32d

Browse files
authored
[enh] add array API enabled compute_class_weight function for use in classifier estimators (#2697)
* Create class_weight.py * Update class_weight.py * formatting * Update class_weight.py * Update class_weight.py * formatting * another fix * Update class_weight.py * add testing * swap skipif * fix imports * fix signature * fix issue * fix kwargs * fixes for sklearn 1.6 * Update test_class_weight.py * remove hashable dependence * change get_namespace call * fixes to unique checks * fix another mistake * Update test_class_weight.py * Update class_weight.py * Update class_weight.py * make review changes * swap some testing * change to a generator * Update test_class_weight.py * Update class_weight.py * formatting * fix len interaction with array_api_strict * make check for sklearn 1.7 in testing * another fix
1 parent a75e412 commit 754d32d

File tree

2 files changed

+172
-0
lines changed

2 files changed

+172
-0
lines changed

sklearnex/utils/class_weight.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# ==============================================================================
2+
# Copyright contributors to the oneDAL Project
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ==============================================================================
16+
17+
from sklearn.preprocessing import LabelEncoder as _sklearn_LabelEncoder
18+
19+
from daal4py.sklearn._utils import sklearn_check_version
20+
21+
from ._array_api import get_namespace
22+
from .validation import _check_sample_weight
23+
24+
if sklearn_check_version("1.7"):
    # Recent scikit-learn: re-export its implementation unchanged.
    from sklearn.utils.class_weight import compute_class_weight

else:
    from sklearn.utils.class_weight import (
        compute_class_weight as _sklearn_compute_class_weight,
    )

    def compute_class_weight(class_weight, *, classes, y, sample_weight=None):
        """Call scikit-learn's ``compute_class_weight`` on scikit-learn < 1.7.

        ``sample_weight`` is accepted so that callers can use one signature
        regardless of the installed scikit-learn version, but it is not
        forwarded to the wrapped function (presumably because older releases
        do not accept it -- confirm against the scikit-learn changelog).
        """
        return _sklearn_compute_class_weight(class_weight, classes=classes, y=y)
34+
35+
36+
def _compute_class_weight(class_weight, *, classes, y, sample_weight=None):
    """Estimate per-class weights, with a code path for array API inputs.

    This duplicates scikit-learn code in order to enable it for array API.
    Inputs that ``get_namespace`` does not report as array API compliant are
    forwarded unchanged to ``compute_class_weight``.

    Parameters
    ----------
    class_weight : dict, "balanced" or None
        ``None`` or an empty container gives every class a weight of one.
        ``"balanced"`` gives each class the weight
        ``sum(weighted_counts) / (n_classes * weighted_count_of_class)``.
        Any other value is used as a mapping from class label to weight;
        labels are looked up as ``float(label)`` and classes without an entry
        keep a weight of one.
    classes : array
        The class labels, one weight is returned per entry.
    y : array
        The labels of the samples.
    sample_weight : array or None, default=None
        Per-sample weights; only used on the ``"balanced"`` path (after
        ``_check_sample_weight``) and when delegating to
        ``compute_class_weight``.

    Returns
    -------
    weight : array
        Weight of each entry of ``classes``, in the same order.

    Raises
    ------
    ValueError
        If ``y`` contains a label that is not in ``classes``; if
        ``class_weight="balanced"`` and ``classes`` contains a label that is
        not in ``y``; or if a ``class_weight`` mapping has entries that do not
        match any class while some classes are left without a weight.
    RuntimeError
        If ``class_weight="balanced"`` is requested for array API inputs with
        scikit-learn < 1.6.
    """
    # Note for the use of LabelEncoder this is only valid for sklearn
    # versions >= 1.6.
    xp, is_array_api_compliant = get_namespace(classes, y, sample_weight)

    if not is_array_api_compliant:
        # use the sklearn version for standard use. ``classes`` and ``y`` are
        # keyword-only in its signature, so they must be passed by name.
        return compute_class_weight(
            class_weight, classes=classes, y=y, sample_weight=sample_weight
        )

    sety = xp.unique_values(y)
    setclasses = xp.unique_values(classes)
    # Every label of y is in classes exactly when merging y's labels into
    # classes does not add any new distinct value.
    n_union = xp.unique_values(xp.concat((sety, setclasses))).shape[0]
    if setclasses.shape[0] != n_union:
        raise ValueError("classes should include all valid labels that can be in y")

    if class_weight is None or len(class_weight) == 0:
        # uniform class weights
        weight = xp.ones((classes.shape[0],), dtype=xp.float64, device=classes.device)
    elif class_weight == "balanced":
        if not sklearn_check_version("1.6"):
            raise RuntimeError(
                "array API support with 'balanced' keyword not supported for sklearn <1.6"
            )
        # Find the weight of each class as present in y.
        le = _sklearn_LabelEncoder()
        y_ind = le.fit_transform(y)
        if not all(item in le.classes_ for item in classes):
            raise ValueError("classes should have valid labels that are in y")

        sample_weight = _check_sample_weight(sample_weight, y)
        # scikit-learn implementation uses numpy.bincount, which does a combined
        # min and max search, only erroring when a value < 0. Replicating this
        # exactly via array API would cause another O(n) evaluation (by doing
        # min and max separately). However this check can be removed due to the
        # nature of the LabelEncoder. Therefore only the maximum is found, and
        # then core logic of bincount is replicated:
        # https://github.com/numpy/numpy/blob/main/numpy/_core/src/multiarray/compiled_base.c
        weighted_class_counts = xp.zeros(
            (xp.max(y_ind) + 1,), dtype=sample_weight.dtype, device=y.device
        )

        # use a more GPU-friendly summation approach for collecting weighted_class_counts
        for w_idx in range(weighted_class_counts.shape[0]):
            weighted_class_counts[w_idx] = xp.sum(sample_weight[y_ind == w_idx])

        recip_freq = xp.sum(weighted_class_counts) / (
            le.classes_.shape[0] * weighted_class_counts
        )

        # reorder the per-encoded-label weights to follow the order of classes
        weight = xp.take(recip_freq, le.transform(classes))
    else:
        # user-defined dictionary
        weight = xp.ones((classes.shape[0],), dtype=xp.float64, device=classes.device)
        unweighted_classes = []
        for i, c in enumerate(classes):
            if (fc := float(c)) in class_weight:
                # array API has only numeric datatypes, convert to float for generality
                # complex values should never be observed by this function
                weight[i] = class_weight[fc]
            else:
                unweighted_classes.append(c)

        n_weighted_classes = classes.shape[0] - len(unweighted_classes)
        # Only complain about unweighted classes when class_weight also holds
        # entries that did not match any class.
        if unweighted_classes and n_weighted_classes != len(class_weight):
            raise ValueError(
                f"The classes, {unweighted_classes}, are not in class_weight"
            )

    return weight
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# ==============================================================================
2+
# Copyright contributors to the oneDAL Project
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ==============================================================================
16+
17+
import numpy as np
18+
import pytest
19+
from sklearn.datasets import load_iris
20+
21+
from daal4py.sklearn._utils import sklearn_check_version
22+
from onedal.tests.utils._dataframes_support import (
23+
_as_numpy,
24+
_convert_to_dataframe,
25+
get_dataframes_and_queues,
26+
)
27+
from sklearnex import config_context
28+
from sklearnex.utils.class_weight import _compute_class_weight
29+
from sklearnex.utils.class_weight import compute_class_weight as sk_compute_class_weight
30+
31+
32+
@pytest.mark.skipif(not sklearn_check_version("1.6"), reason="lacks array API support")
@pytest.mark.parametrize("class_weight", [None, "balanced", "ramp"])
@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues("array_api,dpctl"))
def test_compute_class_weight_array_api(class_weight, dataframe, queue):
    """Compare the array API class weights against the NumPy-input reference."""
    _, labels = load_iris(return_X_y=True)
    unique_labels = np.unique(labels)

    labels_xp = _convert_to_dataframe(labels, target_df=dataframe, device=queue)
    unique_labels_xp = _convert_to_dataframe(
        unique_labels, target_df=dataframe, device=queue
    )

    rng = np.random.default_rng(seed=42)

    # support of sample_weights added in sklearn 1.7, and they are only
    # exercised together with the "balanced" keyword
    use_sample_weight = class_weight == "balanced" and sklearn_check_version("1.7")
    if use_sample_weight:
        sample_weight = rng.random(labels.shape).astype(np.float64)
    else:
        sample_weight = None

    if class_weight == "ramp":
        # "ramp" is a placeholder for a dictionary mapping each label to itself
        class_weight = {int(label): int(label) for label in np.unique(labels)}

    # reference result computed from the NumPy inputs
    expected = sk_compute_class_weight(
        class_weight, classes=unique_labels, y=labels, sample_weight=sample_weight
    )

    if use_sample_weight:
        sample_weight = _convert_to_dataframe(
            sample_weight, target_df=dataframe, device=queue
        )

    # evaluate custom sklearnex array API functionality
    with config_context(array_api_dispatch=True):
        actual = _compute_class_weight(
            class_weight,
            classes=unique_labels_xp,
            y=labels_xp,
            sample_weight=sample_weight,
        )

    np.testing.assert_allclose(_as_numpy(actual), expected)

0 commit comments

Comments
 (0)