Skip to content

Commit ad55af6

Browse files
authored
Fix test patching for SVC predict proba (#808) (#815)
1 parent ecba8c3 commit ad55af6

File tree

2 files changed

+29
-7
lines changed

2 files changed

+29
-7
lines changed

sklearnex/svm/nusvc.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from sklearn.svm import NuSVC as sklearn_NuSVC
2121
from sklearn.utils.validation import _deprecate_positional_args
2222
from sklearn.exceptions import NotFittedError
23+
from sklearn import __version__ as sklearn_version
24+
from distutils.version import LooseVersion
2325

2426
from onedal.svm import NuSVC as onedal_NuSVC
2527

@@ -53,11 +55,20 @@ def predict(self, X):
5355
'sklearn': sklearn_NuSVC.predict,
5456
}, X)
5557

58+
@property
59+
def predict_proba(self):
60+
self._check_proba()
61+
return self._predict_proba
62+
5663
@wrap_output_data
5764
def _predict_proba(self, X):
58-
return dispatch(self, 'svm.NuSVC._predict_proba', {
65+
sklearn_pred_proba = (sklearn_NuSVC.predict_proba
66+
if LooseVersion(sklearn_version) >= LooseVersion("1.0")
67+
else sklearn_NuSVC._predict_proba)
68+
69+
return dispatch(self, 'svm.NuSVC.predict_proba', {
5970
'onedal': self.__class__._onedal_predict_proba,
60-
'sklearn': sklearn_NuSVC._predict_proba,
71+
'sklearn': sklearn_pred_proba,
6172
}, X)
6273

6374
@wrap_output_data
@@ -74,7 +85,7 @@ def _onedal_cpu_supported(self, method_name, *data):
7485
if method_name == 'svm.NuSVC.fit':
7586
return self.kernel in ['linear', 'rbf', 'poly', 'sigmoid']
7687
if method_name in ['svm.NuSVC.predict',
77-
'svm.NuSVC._predict_proba',
88+
'svm.NuSVC.predict_proba',
7889
'svm.NuSVC.decision_function']:
7990
return hasattr(self, '_onedal_estimator')
8091

sklearnex/svm/svc.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from sklearn.svm import SVC as sklearn_SVC
2121
from sklearn.utils.validation import _deprecate_positional_args
2222
from sklearn.exceptions import NotFittedError
23+
from sklearn import __version__ as sklearn_version
24+
from distutils.version import LooseVersion
2325

2426
from onedal.svm import SVC as onedal_SVC
2527

@@ -52,11 +54,20 @@ def predict(self, X):
5254
'sklearn': sklearn_SVC.predict,
5355
}, X)
5456

57+
@property
58+
def predict_proba(self):
59+
self._check_proba()
60+
return self._predict_proba
61+
5562
@wrap_output_data
5663
def _predict_proba(self, X):
57-
return dispatch(self, 'svm.SVC._predict_proba', {
64+
sklearn_pred_proba = (sklearn_SVC.predict_proba
65+
if LooseVersion(sklearn_version) >= LooseVersion("1.0")
66+
else sklearn_SVC._predict_proba)
67+
68+
return dispatch(self, 'svm.SVC.predict_proba', {
5869
'onedal': self.__class__._onedal_predict_proba,
59-
'sklearn': sklearn_SVC._predict_proba,
70+
'sklearn': sklearn_pred_proba,
6071
}, X)
6172

6273
@wrap_output_data
@@ -79,7 +90,7 @@ def _onedal_gpu_supported(self, method_name, *data):
7990
hasattr(self, '_class_count') and self._class_count == 2 and \
8091
hasattr(self, '_is_sparse') and not self._is_sparse
8192
if method_name in ['svm.SVC.predict',
82-
'svm.SVC._predict_proba',
93+
'svm.SVC.predict_proba',
8394
'svm.SVC.decision_function']:
8495
return hasattr(self, '_onedal_estimator') and \
8596
self._onedal_gpu_supported('svm.SVC.fit', *data)
@@ -89,7 +100,7 @@ def _onedal_cpu_supported(self, method_name, *data):
89100
if method_name == 'svm.SVC.fit':
90101
return self.kernel in ['linear', 'rbf', 'poly', 'sigmoid']
91102
if method_name in ['svm.SVC.predict',
92-
'svm.SVC._predict_proba',
103+
'svm.SVC.predict_proba',
93104
'svm.SVC.decision_function']:
94105
return hasattr(self, '_onedal_estimator')
95106
raise RuntimeError(f'Unknown method {method_name} in {self.__class__.__name__}')

0 commit comments

Comments
 (0)