[enh] add array API enabled compute_class_weight function for use in classifier estimators #2697
Merged
Changes from 28 commits
32 commits
3e0f1b0 Create class_weight.py (icfaust)
ad4dd21 Update class_weight.py (icfaust)
aa44be3 formatting (icfaust)
c4ab3f5 Update class_weight.py (icfaust)
cb4591c Update class_weight.py (icfaust)
3d926b7 formatting (icfaust)
dcd3367 another fix (icfaust)
3041510 Update class_weight.py (icfaust)
b550a9c add testing (icfaust)
744275b swap skipif (icfaust)
bede7f7 fix imports (icfaust)
fc17dd2 fix signature (icfaust)
41b8163 fix issue (icfaust)
c6ef421 fix kwargs (icfaust)
25ed9d1 fixes for sklearn 1.6 (icfaust)
d4d04ea Update test_class_weight.py (icfaust)
c1bfbde remove hashable dependence (icfaust)
f30b237 change get_namespace call (icfaust)
27bf1d0 fixes to unique checks (icfaust)
b3e20ce fix another mistake (icfaust)
65685ef Update test_class_weight.py (icfaust)
2e12ec1 Update class_weight.py (icfaust)
48fbde2 Update class_weight.py (icfaust)
310c53c make review changes (icfaust)
549901b swap some testing (icfaust)
797f89f change to a generator (icfaust)
8c11df1 Update test_class_weight.py (icfaust)
8428614 Update class_weight.py (icfaust)
abfd5fc formatting (icfaust)
6f342e7 fix len interaction with array_api_strict (icfaust)
5d92227 make check for sklearn 1.7 in testing (icfaust)
7882c6b another fix (icfaust)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
# ============================================================================== | ||
# Copyright contributors to the oneDAL Project | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
|
||
from sklearn.preprocessing import LabelEncoder as _sklearn_LabelEncoder | ||
|
||
from daal4py.sklearn._utils import sklearn_check_version | ||
|
||
from ._array_api import get_namespace | ||
from .validation import _check_sample_weight | ||
|
||
if not sklearn_check_version("1.7"): | ||
from sklearn.utils.class_weight import ( | ||
compute_class_weight as _sklearn_compute_class_weight, | ||
) | ||
|
||
def compute_class_weight(class_weight, *, classes, y, sample_weight=None): | ||
# sample_weight is accepted for signature compatibility but is not forwarded | ||
return _sklearn_compute_class_weight(class_weight, classes=classes, y=y) | ||
|
||
else: | ||
from sklearn.utils.class_weight import compute_class_weight | ||
|
||
|
||
def _compute_class_weight(class_weight, *, classes, y, sample_weight=None): | ||
# This duplicates scikit-learn code in order to enable it for array API inputs. | ||
# Because the "balanced" branch relies on LabelEncoder, it is only valid for | ||
# scikit-learn versions >= 1.6. | ||
xp, is_array_api_compliant = get_namespace(classes, y, sample_weight) | ||
|
||
if not is_array_api_compliant: | ||
# use the sklearn version for standard use. | ||
# classes and y are keyword-only in compute_class_weight | ||
return compute_class_weight( | ||
class_weight, classes=classes, y=y, sample_weight=sample_weight | ||
) | ||
|
||
sety = xp.unique_values(y) | ||
setclasses = xp.unique_values(classes) | ||
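# every label in y must appear in classes, so the union of both label sets | ||
# must be exactly as large as the set of classes | ||
if len(setclasses) != len(xp.unique_values(xp.concat((sety, setclasses)))): | ||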
raise ValueError("classes should include all valid labels that can be in y") | ||
if class_weight is None or len(class_weight) == 0: | ||
# uniform class weights | ||
weight = xp.ones((classes.shape[0],), dtype=xp.float64, device=classes.device) | ||
elif class_weight == "balanced": | ||
if not sklearn_check_version("1.6"): | ||
raise RuntimeError( | ||
"array API support with 'balanced' keyword not supported for sklearn <1.6" | ||
) | ||
# Find the weight of each class as present in y. | ||
le = _sklearn_LabelEncoder() | ||
y_ind = le.fit_transform(y) | ||
if not all([item in le.classes_ for item in classes]): | ||
raise ValueError("classes should have valid labels that are in y") | ||
|
||
sample_weight = _check_sample_weight(sample_weight, y) | ||
# scikit-learn implementation uses numpy.bincount, which does a combined | ||
# min and max search, only erroring when a value < 0. Replicating this | ||
# exactly via array API would cause another O(n) evaluation (by doing | ||
# min and max separately). However this check can be removed due to the | ||
# nature of the LabelEncoder. Therefore only the maximum is found, and | ||
# then core logic of bincount is replicated: | ||
# https://github.com/numpy/numpy/blob/main/numpy/_core/src/multiarray/compiled_base.c | ||
weighted_class_counts = xp.zeros( | ||
(xp.max(y_ind) + 1,), dtype=sample_weight.dtype, device=y.device | ||
) | ||
|
||
# use a more GPU-friendly summation approach for collecting weighted_class_counts | ||
for w_idx in range(len(weighted_class_counts)): | ||
weighted_class_counts[w_idx] = xp.sum(sample_weight[y_ind == w_idx]) | ||
|
||
recip_freq = xp.sum(weighted_class_counts) / ( | ||
len(le.classes_) * weighted_class_counts | ||
) | ||
|
||
weight = xp.take(recip_freq, le.transform(classes)) | ||
else: | ||
# user-defined dictionary | ||
weight = xp.ones((classes.shape[0],), dtype=xp.float64, device=classes.device) | ||
unweighted_classes = [] | ||
for i, c in enumerate(classes): | ||
if (fc := float(c)) in class_weight: | ||
# array API has only numeric datatypes, convert to float for generality | ||
# complex values should never be observed by this function | ||
weight[i] = class_weight[fc] | ||
else: | ||
unweighted_classes.append(c) | ||
|
||
n_weighted_classes = len(classes) - len(unweighted_classes) | ||
if unweighted_classes and n_weighted_classes != len(class_weight): | ||
raise ValueError( | ||
f"The classes, {unweighted_classes}, are not in class_weight" | ||
) | ||
|
||
return weight |
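The arithmetic behind the "balanced" branch above can be sketched in plain Python. This is a minimal illustration, not the PR's array API code: the helper name `balanced_class_weights` is hypothetical, it works on Python lists rather than array API arrays, and it returns a dictionary instead of an array aligned with `classes`. Each class receives the total sample weight divided by the number of classes times that class's weighted count.

```python
from collections import defaultdict


def balanced_class_weights(y, sample_weight=None):
    """Return {label: weight} for the 'balanced' heuristic (hypothetical helper).

    weight[c] = sum(all sample weights) / (n_classes * sum(weights of class c))
    """
    if sample_weight is None:
        # Unweighted case: every sample counts once.
        sample_weight = [1.0] * len(y)
    if len(sample_weight) != len(y):
        raise ValueError("sample_weight and y must have the same length")

    # Weighted count per class. The PR accumulates the same quantity with one
    # masked sum per encoded label instead of a Python loop over samples.
    weighted_counts = defaultdict(float)
    for label, w in zip(y, sample_weight):
        weighted_counts[label] += w

    total = sum(weighted_counts.values())
    n_classes = len(weighted_counts)
    return {
        label: total / (n_classes * count)
        for label, count in sorted(weighted_counts.items())
    }


# Class 0 appears twice, classes 1 and 2 once each, so class 0 gets the
# smallest weight.
weights = balanced_class_weights([0, 0, 1, 2])
print(weights)
```

With equal class frequencies every weight is 1, which is why a dataset with perfectly even classes only exercises this branch meaningfully when non-uniform sample weights are supplied.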
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
# ============================================================================== | ||
# Copyright contributors to the oneDAL Project | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
|
||
import numpy as np | ||
import pytest | ||
from sklearn.datasets import load_iris | ||
|
||
from daal4py.sklearn._utils import sklearn_check_version | ||
from onedal.tests.utils._dataframes_support import ( | ||
_as_numpy, | ||
_convert_to_dataframe, | ||
get_dataframes_and_queues, | ||
) | ||
from sklearnex import config_context | ||
from sklearnex.utils.class_weight import _compute_class_weight | ||
from sklearnex.utils.class_weight import compute_class_weight as sk_compute_class_weight | ||
|
||
|
||
@pytest.mark.skipif(not sklearn_check_version("1.6"), reason="lacks array API support") | ||
@pytest.mark.parametrize("class_weight", [None, "balanced", "ramp"]) | ||
@pytest.mark.parametrize( | ||
"dataframe,queue", get_dataframes_and_queues("array_api,dpctl") | ||
) | ||
def test_compute_class_weight_array_api(class_weight, dataframe, queue): | ||
# This verifies that array_api functionality matches sklearn | ||
|
||
_, y = load_iris(return_X_y=True) | ||
classes = np.unique(y) | ||
|
||
y_xp = _convert_to_dataframe(y, target_df=dataframe, device=queue) | ||
classes_xp = _convert_to_dataframe(classes, target_df=dataframe, device=queue) | ||
|
||
rng = np.random.default_rng(seed=42) | ||
|
||
sample_weight = ( | ||
rng.random(y.shape).astype(np.float64) if class_weight == "balanced" else None | ||
) | ||
|
||
if class_weight == "ramp": | ||
class_weight = {int(i): int(i) for i in np.unique(y)} | ||
|
||
weight_np = sk_compute_class_weight( | ||
class_weight, classes=classes, y=y, sample_weight=sample_weight | ||
) | ||
|
||
if class_weight == "balanced": | ||
sample_weight = _convert_to_dataframe( | ||
sample_weight, target_df=dataframe, device=queue | ||
) | ||
|
||
# evaluate custom sklearnex array API functionality | ||
with config_context(array_api_dispatch=True): | ||
weight_xp = _compute_class_weight( | ||
class_weight, classes=classes_xp, y=y_xp, sample_weight=sample_weight | ||
) | ||
|
||
np.testing.assert_allclose(_as_numpy(weight_xp), weight_np) |
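The "ramp" case in the test builds a dictionary with `int` keys, while the dictionary branch of `_compute_class_weight` looks labels up as `float(c)`. That works because Python treats `1` and `1.0` as equal with equal hashes, so either key type matches. The sketch below mirrors the lookup and the missing-class check using plain lists; the helper name `weights_from_mapping` is hypothetical and the code is an illustration under those assumptions, not the PR's implementation.

```python
def weights_from_mapping(classes, class_weight):
    """Hypothetical list-based mirror of the user-defined dictionary branch."""
    weights = [1.0] * len(classes)  # classes without an entry keep weight 1
    missing = []
    for i, c in enumerate(classes):
        key = float(c)  # int keys such as 1 still match, since 1 == 1.0
        if key in class_weight:
            weights[i] = float(class_weight[key])
        else:
            missing.append(c)

    # Only complain when some dictionary key matched none of the classes.
    n_weighted = len(classes) - len(missing)
    if missing and n_weighted != len(class_weight):
        raise ValueError(f"The classes, {missing}, are not in class_weight")
    return weights


ramp = {0: 0, 1: 1, 2: 2}
print(weights_from_mapping([0, 1, 2], ramp))  # [0.0, 1.0, 2.0]
```

A partial dictionary such as `{0: 2.0}` is accepted and leaves the other classes at weight 1, whereas a dictionary containing a key that is not among `classes` raises `ValueError`.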