|
| 1 | +# ============================================================================== |
| 2 | +# Copyright contributors to the oneDAL Project |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | +# ============================================================================== |
| 16 | + |
| 17 | +from sklearn.preprocessing import LabelEncoder as _sklearn_LabelEncoder |
| 18 | + |
| 19 | +from daal4py.sklearn._utils import sklearn_check_version |
| 20 | + |
| 21 | +from ._array_api import get_namespace |
| 22 | +from .validation import _check_sample_weight |
| 23 | + |
| 24 | +if not sklearn_check_version("1.7"): |
| 25 | + from sklearn.utils.class_weight import ( |
| 26 | + compute_class_weight as _sklearn_compute_class_weight, |
| 27 | + ) |
| 28 | + |
| 29 | + def compute_class_weight(class_weight, *, classes, y, sample_weight=None): |
| 30 | + return _sklearn_compute_class_weight(class_weight, classes=classes, y=y) |
| 31 | + |
| 32 | +else: |
| 33 | + from sklearn.utils.class_weight import compute_class_weight |
| 34 | + |
| 35 | + |
| 36 | +def _compute_class_weight(class_weight, *, classes, y, sample_weight=None): |
| 37 | + # this duplicates sklearn code in order to enable it for array API. |
| 38 | + # Note for the use of LabelEncoder this is only valid for sklearn |
| 39 | + # versions >= 1.6. |
| 40 | + xp, is_array_api_compliant = get_namespace(classes, y, sample_weight) |
| 41 | + |
| 42 | + if not is_array_api_compliant: |
| 43 | + # use the sklearn version for standard use. |
| 44 | + return compute_class_weight(class_weight, classes, y, sample_weight=sample_weight) |
| 45 | + |
| 46 | + sety = xp.unique_values(y) |
| 47 | + setclasses = xp.unique_values(classes) |
| 48 | + if sety.shape[0] != xp.unique_values(xp.concat((sety, setclasses))).shape[0]: |
| 49 | + raise ValueError("classes should include all valid labels that can be in y") |
| 50 | + if class_weight is None or len(class_weight) == 0: |
| 51 | + # uniform class weights |
| 52 | + weight = xp.ones((classes.shape[0],), dtype=xp.float64, device=classes.device) |
| 53 | + elif class_weight == "balanced": |
| 54 | + if not sklearn_check_version("1.6"): |
| 55 | + raise RuntimeError( |
| 56 | + "array API support with 'balanced' keyword not supported for sklearn <1.6" |
| 57 | + ) |
| 58 | + # Find the weight of each class as present in y. |
| 59 | + le = _sklearn_LabelEncoder() |
| 60 | + y_ind = le.fit_transform(y) |
| 61 | + if not all([item in le.classes_ for item in classes]): |
| 62 | + raise ValueError("classes should have valid labels that are in y") |
| 63 | + |
| 64 | + sample_weight = _check_sample_weight(sample_weight, y) |
| 65 | + # scikit-learn implementation uses numpy.bincount, which does a combined |
| 66 | + # min and max search, only erroring when a value < 0. Replicating this |
| 67 | + # exactly via array API would cause another O(n) evaluation (by doing |
| 68 | + # min and max separately). However this check can be removed due to the |
| 69 | + # nature of the LabelEncoder. Therefore only the maximum is found, and |
| 70 | + # then core logic of bincount is replicated: |
| 71 | + # https://github.com/numpy/numpy/blob/main/numpy/_core/src/multiarray/compiled_base.c |
| 72 | + weighted_class_counts = xp.zeros( |
| 73 | + (xp.max(y_ind) + 1,), dtype=sample_weight.dtype, device=y.device |
| 74 | + ) |
| 75 | + |
| 76 | + # use a more GPU-friendly summation approach for collecting weighted_class_counts |
| 77 | + for w_idx in range(weighted_class_counts.shape[0]): |
| 78 | + weighted_class_counts[w_idx] = xp.sum(sample_weight[y_ind == w_idx]) |
| 79 | + |
| 80 | + recip_freq = xp.sum(weighted_class_counts) / ( |
| 81 | + le.classes_.shape[0] * weighted_class_counts |
| 82 | + ) |
| 83 | + |
| 84 | + weight = xp.take(recip_freq, le.transform(classes)) |
| 85 | + else: |
| 86 | + # user-defined dictionary |
| 87 | + weight = xp.ones((classes.shape[0],), dtype=xp.float64, device=classes.device) |
| 88 | + unweighted_classes = [] |
| 89 | + for i, c in enumerate(classes): |
| 90 | + if (fc := float(c)) in class_weight: |
| 91 | + # array API has only numeric datatypes, convert to float for generality |
| 92 | + # complex values should never be observed by this function |
| 93 | + weight[i] = class_weight[fc] |
| 94 | + else: |
| 95 | + unweighted_classes.append(c) |
| 96 | + |
| 97 | + n_weighted_classes = classes.shape[0] - len(unweighted_classes) |
| 98 | + if unweighted_classes and n_weighted_classes != len(class_weight): |
| 99 | + raise ValueError( |
| 100 | + f"The classes, {unweighted_classes}, are not in" " class_weight" |
| 101 | + ) |
| 102 | + |
| 103 | + return weight |
0 commit comments