Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
3e0f1b0
Create class_weight.py
icfaust Sep 23, 2025
ad4dd21
Update class_weight.py
icfaust Sep 23, 2025
aa44be3
formatting
icfaust Sep 23, 2025
c4ab3f5
Update class_weight.py
icfaust Sep 23, 2025
cb4591c
Update class_weight.py
icfaust Sep 23, 2025
3d926b7
formatting
icfaust Sep 23, 2025
dcd3367
another fix
icfaust Sep 23, 2025
3041510
Update class_weight.py
icfaust Sep 24, 2025
b550a9c
add testing
icfaust Sep 24, 2025
744275b
swap skipif
icfaust Sep 24, 2025
bede7f7
fix imports
icfaust Sep 24, 2025
fc17dd2
fix signature
icfaust Sep 24, 2025
41b8163
fix issue
icfaust Sep 24, 2025
c6ef421
fix kwargs
icfaust Sep 24, 2025
25ed9d1
fixes for sklearn 1.6
icfaust Sep 24, 2025
d4d04ea
Update test_class_weight.py
icfaust Sep 24, 2025
c1bfbde
remove hashable dependence
icfaust Sep 24, 2025
f30b237
change get_namespace call
icfaust Sep 24, 2025
27bf1d0
fixes to unique checks
icfaust Sep 24, 2025
b3e20ce
fix another mistake
icfaust Sep 24, 2025
65685ef
Update test_class_weight.py
icfaust Sep 24, 2025
2e12ec1
Update class_weight.py
icfaust Sep 24, 2025
48fbde2
Update class_weight.py
icfaust Sep 24, 2025
310c53c
make review changes
icfaust Sep 30, 2025
549901b
swap some testing
icfaust Sep 30, 2025
797f89f
change to a generator
icfaust Sep 30, 2025
8c11df1
Update test_class_weight.py
icfaust Sep 30, 2025
8428614
Update class_weight.py
icfaust Oct 1, 2025
abfd5fc
formatting
icfaust Oct 1, 2025
6f342e7
fix len interaction with array_api_strict
icfaust Oct 1, 2025
5d92227
make check for sklearn 1.7 in testing
icfaust Oct 1, 2025
7882c6b
another fix
icfaust Oct 1, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 103 additions & 0 deletions sklearnex/utils/class_weight.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# ==============================================================================
# Copyright contributors to the oneDAL Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from sklearn.preprocessing import LabelEncoder as _sklearn_LabelEncoder

from daal4py.sklearn._utils import sklearn_check_version

from ._array_api import get_namespace
from .validation import _check_sample_weight

if not sklearn_check_version("1.7"):
from sklearn.utils.class_weight import (
compute_class_weight as _sklearn_compute_class_weight,
)

def compute_class_weight(class_weight, *, classes, y, sample_weight=None):
return _sklearn_compute_class_weight(class_weight, classes=classes, y=y)

else:
from sklearn.utils.class_weight import compute_class_weight


def _compute_class_weight(class_weight, *, classes, y, sample_weight=None):
# this duplicates sklearn code in order to enable it for array API.
# Note for the use of LabelEncoder this is only valid for sklearn
# versions >= 1.6.
xp, is_array_api_compliant = get_namespace(classes, y, sample_weight)

if not is_array_api_compliant:
# use the sklearn version for standard use.
return compute_class_weight(class_weight, classes, y, sample_weight=sample_weight)

sety = xp.unique_values(y)
setclasses = xp.unique_values(classes)
if sety.shape[0] != xp.unique_values(xp.concat((sety, setclasses))).shape[0]:
raise ValueError("classes should include all valid labels that can be in y")
if class_weight is None or len(class_weight) == 0:
# uniform class weights
weight = xp.ones((classes.shape[0],), dtype=xp.float64, device=classes.device)
elif class_weight == "balanced":
if not sklearn_check_version("1.6"):
raise RuntimeError(
"array API support with 'balanced' keyword not supported for sklearn <1.6"
)
# Find the weight of each class as present in y.
le = _sklearn_LabelEncoder()
y_ind = le.fit_transform(y)
if not all([item in le.classes_ for item in classes]):
raise ValueError("classes should have valid labels that are in y")

sample_weight = _check_sample_weight(sample_weight, y)
# scikit-learn implementation uses numpy.bincount, which does a combined
# min and max search, only erroring when a value < 0. Replicating this
# exactly via array API would cause another O(n) evaluation (by doing
# min and max separately). However this check can be removed due to the
# nature of the LabelEncoder. Therefore only the maximum is found, and
# then core logic of bincount is replicated:
# https://github.com/numpy/numpy/blob/main/numpy/_core/src/multiarray/compiled_base.c
weighted_class_counts = xp.zeros(
(xp.max(y_ind) + 1,), dtype=sample_weight.dtype, device=y.device
)

# use a more GPU-friendly summation approach for collecting weighted_class_counts
for w_idx in range(weighted_class_counts.shape[0]):
weighted_class_counts[w_idx] = xp.sum(sample_weight[y_ind == w_idx])

recip_freq = xp.sum(weighted_class_counts) / (
le.classes_.shape[0] * weighted_class_counts
)

weight = xp.take(recip_freq, le.transform(classes))
else:
# user-defined dictionary
weight = xp.ones((classes.shape[0],), dtype=xp.float64, device=classes.device)
unweighted_classes = []
for i, c in enumerate(classes):
if (fc := float(c)) in class_weight:
# array API has only numeric datatypes, convert to float for generality
# complex values should never be observed by this function
weight[i] = class_weight[fc]
else:
unweighted_classes.append(c)

n_weighted_classes = classes.shape[0] - len(unweighted_classes)
if unweighted_classes and n_weighted_classes != len(class_weight):
raise ValueError(
f"The classes, {unweighted_classes}, are not in" " class_weight"
)

return weight
69 changes: 69 additions & 0 deletions sklearnex/utils/tests/test_class_weight.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# ==============================================================================
# Copyright contributors to the oneDAL Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import numpy as np
import pytest
from sklearn.datasets import load_iris

from daal4py.sklearn._utils import sklearn_check_version
from onedal.tests.utils._dataframes_support import (
_as_numpy,
_convert_to_dataframe,
get_dataframes_and_queues,
)
from sklearnex import config_context
from sklearnex.utils.class_weight import _compute_class_weight
from sklearnex.utils.class_weight import compute_class_weight as sk_compute_class_weight


@pytest.mark.skipif(not sklearn_check_version("1.6"), reason="lacks array API support")
@pytest.mark.parametrize("class_weight", [None, "balanced", "ramp"])
@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues("array_api,dpctl"))
def test_compute_class_weight_array_api(class_weight, dataframe, queue):
# This verifies that array_api functionality matches sklearn

_, y = load_iris(return_X_y=True)
classes = np.unique(y)

y_xp = _convert_to_dataframe(y, target_df=dataframe, device=queue)
classes_xp = _convert_to_dataframe(classes, target_df=dataframe, device=queue)

rng = np.random.default_rng(seed=42)

# support of sample_weights added in sklearn 1.7
set_sample_weight = class_weight == "balanced" and sklearn_check_version("1.7")

sample_weight = rng.random(y.shape).astype(np.float64) if set_sample_weight else None

if class_weight == "ramp":
class_weight = {int(i): int(i) for i in np.unique(y)}

weight_np = sk_compute_class_weight(
class_weight, classes=classes, y=y, sample_weight=sample_weight
)

if set_sample_weight:
sample_weight = _convert_to_dataframe(
sample_weight, target_df=dataframe, device=queue
)

# evaluate custom sklearnex array API functionality
with config_context(array_api_dispatch=True):
weight_xp = _compute_class_weight(
class_weight, classes=classes_xp, y=y_xp, sample_weight=sample_weight
)

np.testing.assert_allclose(_as_numpy(weight_xp), weight_np)
Loading