Skip to content

Commit 14849ee

Browse files
authored
Fix svm attributes: class_weight_, n_iter_ (#1100)
* Fix svm attributes: class_weight_, n_iter_ * Fix [nu]svc n_iter_
1 parent 4ac8f47 commit 14849ee

File tree

3 files changed

+8
-2
lines changed

3 files changed

+8
-2
lines changed

sklearnex/svm/_common.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def _save_attributes(self):
107107
self.dual_coef_ = self._onedal_estimator.dual_coef_
108108
self.shape_fit_ = self._onedal_estimator.class_weight_
109109
self.classes_ = self._onedal_estimator.classes_
110+
self.class_weight_ = self._onedal_estimator.class_weight_
110111
self.support_ = self._onedal_estimator.support_
111112

112113
self._intercept_ = self._onedal_estimator.intercept_
@@ -129,6 +130,10 @@ def _save_attributes(self):
129130
self.intercept_ = self._intercept_
130131
self._is_in_fit = False
131132

133+
if Version(sklearn_version) >= Version("1.1"):
134+
length = int(len(self.classes_) * (len(self.classes_) - 1) / 2)
135+
self.n_iter_ = np.full((length, ), self._onedal_estimator.n_iter_)
136+
132137

133138
class BaseSVR(ABC):
134139
def _save_attributes(self):
@@ -153,3 +158,6 @@ def _save_attributes(self):
153158
self._dual_coef_ = self.dual_coef_
154159
self.intercept_ = self._intercept_
155160
self._is_in_fit = False
161+
162+
if Version(sklearn_version) >= Version("1.1"):
163+
self.n_iter_ = self._onedal_estimator.n_iter_

sklearnex/svm/nusvc.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,6 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None):
195195

196196
self._onedal_estimator = onedal_NuSVC(**onedal_params)
197197
self._onedal_estimator.fit(X, y, sample_weight, queue=queue)
198-
self.class_weight_ = self._onedal_estimator.class_weight_
199198

200199
if self.probability:
201200
self._fit_proba(X, y, sample_weight, queue=queue)

sklearnex/svm/svc.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,6 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None):
209209

210210
self._onedal_estimator = onedal_SVC(**onedal_params)
211211
self._onedal_estimator.fit(X, y, sample_weight, queue=queue)
212-
self.class_weight_ = self._onedal_estimator.class_weight_
213212

214213
if self.probability:
215214
self._fit_proba(X, y, sample_weight, queue=queue)

0 commit comments

Comments
 (0)