Skip to content

Commit 4872a8e

Browse files
Alexsandrussnapetrov
authored andcommitted
Fix for balanced class weight (#1080)
* add balanced branch of _compute_class_weight * Remove extra computation of weights
1 parent 958c519 commit 4872a8e

File tree

3 files changed

+12
-11
lines changed

3 files changed

+12
-11
lines changed

onedal/datatypes/validation.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import warnings
1919
from scipy import sparse as sp
2020
from scipy.sparse import issparse, dok_matrix, lil_matrix
21+
from sklearn.preprocessing import LabelEncoder
2122
from collections.abc import Sequence
2223
from numbers import Integral
2324

@@ -57,7 +58,15 @@ def _compute_class_weight(class_weight, classes, y):
5758
if class_weight is None or len(class_weight) == 0:
5859
weight = np.ones(classes.shape[0], dtype=np.float64, order='C')
5960
elif class_weight == 'balanced':
60-
weight = None
61+
y_ = _column_or_1d(y)
62+
classes, _ = np.unique(y_, return_inverse=True)
63+
64+
le = LabelEncoder()
65+
y_ind = le.fit_transform(y_)
66+
if not all(np.in1d(classes, le.classes_)):
67+
raise ValueError("classes should have valid labels that are in y")
68+
69+
weight = len(y_) / (len(le.classes_) * np.bincount(y_ind).astype(np.float64))
6170
else:
6271
# user-defined dictionary
6372
weight = np.ones(classes.shape[0], dtype=np.float64, order='C')

sklearnex/svm/nusvc.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -195,11 +195,7 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None):
195195

196196
self._onedal_estimator = onedal_NuSVC(**onedal_params)
197197
self._onedal_estimator.fit(X, y, sample_weight, queue=queue)
198-
199-
if self.class_weight == 'balanced':
200-
self.class_weight_ = self._compute_balanced_class_weight(y)
201-
else:
202-
self.class_weight_ = self._onedal_estimator.class_weight_
198+
self.class_weight_ = self._onedal_estimator.class_weight_
203199

204200
if self.probability:
205201
self._fit_proba(X, y, sample_weight, queue=queue)

sklearnex/svm/svc.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -209,11 +209,7 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None):
209209

210210
self._onedal_estimator = onedal_SVC(**onedal_params)
211211
self._onedal_estimator.fit(X, y, sample_weight, queue=queue)
212-
213-
if self.class_weight == 'balanced':
214-
self.class_weight_ = self._compute_balanced_class_weight(y)
215-
else:
216-
self.class_weight_ = self._onedal_estimator.class_weight_
212+
self.class_weight_ = self._onedal_estimator.class_weight_
217213

218214
if self.probability:
219215
self._fit_proba(X, y, sample_weight, queue=queue)

0 commit comments

Comments
 (0)