Skip to content

Commit 174ee31

Browse files
authored
Fix svm for scikit-learn 1.0.1 (#837)
1 parent ae4a1ca commit 174ee31

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

.ci/pipeline/release.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
variables:
1818
DESCRIPTION: ReleaseTesting
19-
DESELECT: --deselect ::test_svc_clone_with_callable_kernel --deselect ::test_precomputed --deselect ::test_tweak_params --deselect ::test_probability --deselect ::test_custom_kernel_not_array_input --deselect ::test_unicode_kernel --deselect ::test_consistent_proba
19+
DESELECT: --deselect ::test_svc_clone_with_callable_kernel --deselect ::test_precomputed --deselect ::test_tweak_params --deselect ::test_probability --deselect ::test_custom_kernel_not_array_input --deselect ::test_unicode_kernel --deselect ::test_consistent_proba --deselect ::test_svc_raises_error_internal_representation
2020
TEST_COMMAND: python -m sklearnex -m pytest -ra --disable-warnings --pyargs sklearn.svm.tests.test_svm $(DESELECT)
2121

2222
jobs:

onedal/svm/svm.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,12 @@ def _predict(self, X, module, queue):
247247
raise ValueError("break_ties must be False when "
248248
"decision_function_shape is 'ovo'")
249249

250+
if (module in [_backend.svm.classification, _backend.svm.nu_classification]):
251+
sv = self.support_vectors_
252+
if not self._sparse and sv.size > 0 and self._n_support.sum() != sv.shape[0]:
253+
raise ValueError("The internal representation "
254+
f"of {self.__class__.__name__} was altered")
255+
250256
if self.break_ties and self.decision_function_shape == 'ovr' and \
251257
len(self.classes_) > 2:
252258
y = np.argmax(self.decision_function(X), axis=1)
@@ -310,6 +316,12 @@ def _decision_function(self, X, module, queue):
310316
"cannot use sparse input in %r trained on dense data"
311317
% type(self).__name__)
312318

319+
if (module in [_backend.svm.classification, _backend.svm.nu_classification]):
320+
sv = self.support_vectors_
321+
if not self._sparse and sv.size > 0 and self._n_support.sum() != sv.shape[0]:
322+
raise ValueError("The internal representation "
323+
f"of {self.__class__.__name__} was altered")
324+
313325
policy = _get_policy(queue, X)
314326
params = self._get_onedal_params(X)
315327

0 commit comments

Comments
 (0)