@@ -110,6 +110,8 @@ def __init__(self, n_neighbors=5, *,
110110 n_jobs = n_jobs , ** kwargs )
111111
112112 def fit (self , X , y ):
113+ if Version (sklearn_version ) >= Version ("1.0" ):
114+ self ._check_feature_names (X , reset = True )
113115 if self .metric_params is not None and 'p' in self .metric_params :
114116 if self .p is not None :
115117 warnings .warn ("Parameter p is found in metric_params. "
@@ -224,6 +226,8 @@ def fit(self, X, y):
224226 @wrap_output_data
225227 def predict (self , X ):
226228 check_is_fitted (self )
229+ if Version (sklearn_version ) >= Version ("1.0" ):
230+ self ._check_feature_names (X , reset = False )
227231 return dispatch (self , 'neighbors.KNeighborsClassifier.predict' , {
228232 'onedal' : self .__class__ ._onedal_predict ,
229233 'sklearn' : sklearn_KNeighborsClassifier .predict ,
@@ -232,6 +236,8 @@ def predict(self, X):
232236 @wrap_output_data
233237 def predict_proba (self , X ):
234238 check_is_fitted (self )
239+ if Version (sklearn_version ) >= Version ("1.0" ):
240+ self ._check_feature_names (X , reset = False )
235241 return dispatch (self , 'neighbors.KNeighborsClassifier.predict_proba' , {
236242 'onedal' : self .__class__ ._onedal_predict_proba ,
237243 'sklearn' : sklearn_KNeighborsClassifier .predict_proba ,
@@ -240,6 +246,8 @@ def predict_proba(self, X):
240246 @wrap_output_data
241247 def kneighbors (self , X = None , n_neighbors = None , return_distance = True ):
242248 check_is_fitted (self )
249+ if Version (sklearn_version ) >= Version ("1.0" ):
250+ self ._check_feature_names (X , reset = False )
243251 return dispatch (self , 'neighbors.KNeighborsClassifier.kneighbors' , {
244252 'onedal' : self .__class__ ._onedal_kneighbors ,
245253 'sklearn' : sklearn_KNeighborsClassifier .kneighbors ,
@@ -267,85 +275,98 @@ def radius_neighbors(self, X=None, radius=None, return_distance=True,
267275
def _onedal_gpu_supported(self, method_name, *data):
    """Return whether ``method_name`` passes every check for the oneDAL GPU path.

    Parameters
    ----------
    method_name : str
        Dispatch name. Recognised values are
        ``'neighbors.KNeighborsClassifier.fit'``, ``'...predict'``,
        ``'...predict_proba'`` and ``'...kneighbors'``.
    *data
        Positional arguments of the dispatched call. ``data[0]`` is treated
        as ``X``. ``data[1]`` is treated as ``y`` for ``fit`` only.

    Returns
    -------
    bool
        ``False`` immediately when ``data[0]`` is a ``KDTree``, ``BallTree``
        or neighbors estimator. Otherwise ``True`` only when the resolved
        algorithm is brute force with a manhattan/minkowski/euclidean
        metric, ``X`` is dense, the targets are single-output and
        ``weights`` is ``'uniform'`` or ``'distance'``; ``fit`` additionally
        needs at least two classes, the other methods need a fitted
        ``_onedal_estimator``.

    Raises
    ------
    RuntimeError
        If ``method_name`` is not one of the recognised names.
    """
    if isinstance(data[0], (KDTree, BallTree, sklearn_NeighborsBase)):
        return False

    # Resolve which neighbor-search algorithm would effectively be used.
    if self._fit_method in ('auto', 'ball_tree'):
        many_neighbors = self.n_neighbors is not None and \
            self.n_neighbors >= self.n_samples_fit_ // 2
        if self.n_features_in_ > 11 or many_neighbors:
            result_method = 'brute'
        elif self.effective_metric_ == 'euclidean':
            result_method = 'kd_tree'
        else:
            result_method = 'brute'
    else:
        result_method = self._fit_method

    is_fit = method_name == 'neighbors.KNeighborsClassifier.fit'
    is_sparse = sp.isspmatrix(data[0])

    # Pick the targets to inspect. Only ``fit`` passes ``y`` as data[1];
    # the other dispatch call sites are not fully visible here and may pass
    # unrelated positional arguments, so data[1] must not be read as ``y``
    # for them. On a refit the freshly supplied ``y`` wins over the targets
    # stored by a previous fit, which would otherwise be stale.
    y = None
    class_count = 1
    if is_fit and len(data) > 1:
        # To check multioutput, might be overhead
        y = np.asarray(data[1])
        class_count = len(np.unique(y))
    elif hasattr(self, '_onedal_estimator'):
        y = self._onedal_estimator._y

    is_single_output = y is not None and \
        (y.ndim == 1 or (y.ndim == 2 and y.shape[1] == 1))

    is_valid_for_brute = result_method == 'brute' and \
        self.effective_metric_ in ('manhattan', 'minkowski', 'euclidean')
    is_valid_weights = self.weights in ['uniform', 'distance']
    main_condition = is_valid_for_brute and not is_sparse and \
        is_single_output and is_valid_weights

    if is_fit:
        return main_condition and class_count >= 2
    if method_name in ('neighbors.KNeighborsClassifier.predict',
                       'neighbors.KNeighborsClassifier.predict_proba',
                       'neighbors.KNeighborsClassifier.kneighbors'):
        return main_condition and hasattr(self, '_onedal_estimator')
    raise RuntimeError(
        f'Unknown method {method_name} in {self.__class__.__name__}')
307321
def _onedal_cpu_supported(self, method_name, *data):
    """Return whether ``method_name`` passes every check for the oneDAL CPU path.

    Parameters
    ----------
    method_name : str
        Dispatch name. Recognised values are
        ``'neighbors.KNeighborsClassifier.fit'``, ``'...predict'``,
        ``'...predict_proba'`` and ``'...kneighbors'``.
    *data
        Positional arguments of the dispatched call. ``data[0]`` is treated
        as ``X``. ``data[1]`` is treated as ``y`` for ``fit`` only.

    Returns
    -------
    bool
        ``False`` immediately when ``data[0]`` is a ``KDTree``, ``BallTree``
        or neighbors estimator. Otherwise ``True`` only when the resolved
        algorithm is either kd-tree with the euclidean metric or brute force
        with a manhattan/minkowski/euclidean/chebyshev/cosine metric, ``X``
        is dense, the targets are single-output and ``weights`` is
        ``'uniform'`` or ``'distance'``; ``fit`` additionally needs at least
        two classes, the other methods need a fitted ``_onedal_estimator``.

    Raises
    ------
    RuntimeError
        If ``method_name`` is not one of the recognised names.
    """
    if isinstance(data[0], (KDTree, BallTree, sklearn_NeighborsBase)):
        return False

    # Resolve which neighbor-search algorithm would effectively be used.
    if self._fit_method in ('auto', 'ball_tree'):
        many_neighbors = self.n_neighbors is not None and \
            self.n_neighbors >= self.n_samples_fit_ // 2
        if self.n_features_in_ > 11 or many_neighbors:
            result_method = 'brute'
        elif self.effective_metric_ == 'euclidean':
            result_method = 'kd_tree'
        else:
            result_method = 'brute'
    else:
        result_method = self._fit_method

    is_fit = method_name == 'neighbors.KNeighborsClassifier.fit'
    is_sparse = sp.isspmatrix(data[0])

    # Pick the targets to inspect. Only ``fit`` passes ``y`` as data[1];
    # the other dispatch call sites are not fully visible here and may pass
    # unrelated positional arguments, so data[1] must not be read as ``y``
    # for them. On a refit the freshly supplied ``y`` wins over the targets
    # stored by a previous fit, which would otherwise be stale.
    y = None
    class_count = 1
    if is_fit and len(data) > 1:
        # To check multioutput, might be overhead
        y = np.asarray(data[1])
        class_count = len(np.unique(y))
    elif hasattr(self, '_onedal_estimator'):
        y = self._onedal_estimator._y

    is_single_output = y is not None and \
        (y.ndim == 1 or (y.ndim == 2 and y.shape[1] == 1))

    is_valid_for_kd_tree = \
        result_method == 'kd_tree' and self.effective_metric_ == 'euclidean'
    is_valid_for_brute = result_method == 'brute' and \
        self.effective_metric_ in ('manhattan',
                                   'minkowski',
                                   'euclidean',
                                   'chebyshev',
                                   'cosine')
    is_valid_weights = self.weights in ['uniform', 'distance']
    main_condition = (is_valid_for_kd_tree or is_valid_for_brute) and \
        not is_sparse and is_single_output and is_valid_weights

    if is_fit:
        return main_condition and class_count >= 2
    if method_name in ('neighbors.KNeighborsClassifier.predict',
                       'neighbors.KNeighborsClassifier.predict_proba',
                       'neighbors.KNeighborsClassifier.kneighbors'):
        return main_condition and hasattr(self, '_onedal_estimator')
    raise RuntimeError(
        f'Unknown method {method_name} in {self.__class__.__name__}')
350371
351372 def _onedal_fit (self , X , y , queue = None ):
@@ -387,7 +408,6 @@ def _save_attributes(self):
387408 self .n_samples_fit_ = self ._onedal_estimator .n_samples_fit_
388409 self ._fit_X = self ._onedal_estimator ._fit_X
389410 self ._y = self ._onedal_estimator ._y
390- self .shape = self ._onedal_estimator .shape
391411 self ._fit_method = self ._onedal_estimator ._fit_method
392412 self .outputs_2d_ = self ._onedal_estimator .outputs_2d_
393413 self ._tree = self ._onedal_estimator ._tree
0 commit comments