2020from sklearn .svm import SVC as sklearn_SVC
2121from sklearn .utils .validation import _deprecate_positional_args
2222from sklearn .exceptions import NotFittedError
23+ from sklearn import __version__ as sklearn_version
24+ from distutils .version import LooseVersion
2325
2426from onedal .svm import SVC as onedal_SVC
2527
@@ -52,11 +54,20 @@ def predict(self, X):
5254 'sklearn' : sklearn_SVC .predict ,
5355 }, X )
5456
57+ @property
58+ def predict_proba (self ):
59+ self ._check_proba ()
60+ return self ._predict_proba
61+
5562 @wrap_output_data
5663 def _predict_proba (self , X ):
57- return dispatch (self , 'svm.SVC._predict_proba' , {
64+ sklearn_pred_proba = (sklearn_SVC .predict_proba
65+ if LooseVersion (sklearn_version ) >= LooseVersion ("1.0" )
66+ else sklearn_SVC ._predict_proba )
67+
68+ return dispatch (self , 'svm.SVC.predict_proba' , {
5869 'onedal' : self .__class__ ._onedal_predict_proba ,
59- 'sklearn' : sklearn_SVC . _predict_proba ,
70+ 'sklearn' : sklearn_pred_proba ,
6071 }, X )
6172
6273 @wrap_output_data
@@ -79,7 +90,7 @@ def _onedal_gpu_supported(self, method_name, *data):
7990 hasattr (self , '_class_count' ) and self ._class_count == 2 and \
8091 hasattr (self , '_is_sparse' ) and not self ._is_sparse
8192 if method_name in ['svm.SVC.predict' ,
82- 'svm.SVC._predict_proba ' ,
93+ 'svm.SVC.predict_proba ' ,
8394 'svm.SVC.decision_function' ]:
8495 return hasattr (self , '_onedal_estimator' ) and \
8596 self ._onedal_gpu_supported ('svm.SVC.fit' , * data )
@@ -89,7 +100,7 @@ def _onedal_cpu_supported(self, method_name, *data):
89100 if method_name == 'svm.SVC.fit' :
90101 return self .kernel in ['linear' , 'rbf' , 'poly' , 'sigmoid' ]
91102 if method_name in ['svm.SVC.predict' ,
92- 'svm.SVC._predict_proba ' ,
103+ 'svm.SVC.predict_proba ' ,
93104 'svm.SVC.decision_function' ]:
94105 return hasattr (self , '_onedal_estimator' )
95106 raise RuntimeError (f'Unknown method { method_name } in { self .__class__ .__name__ } ' )
0 commit comments