Skip to content

Commit 50c60ba

Browse files
authored
Several fixes for kNN classification and unsupervised (#962)
1 parent 59fee38 commit 50c60ba

File tree

4 files changed

+148
-127
lines changed

4 files changed

+148
-127
lines changed

deselected_tests.yaml

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -209,17 +209,8 @@ deselected_tests:
209209
# Some sklearnex docstrings differ from scikit-learn.
210210
- tests/test_docstrings.py >=1.0.2
211211

212-
# Results with scikit-learn are different for KNN in new ifaces. Need to fix.
213-
- neighbors/tests/test_neighbors.py::test_KNeighborsClassifier_multioutput
214-
215212
# Temporarily deselected until the 2021.6 release. Needs to be fixed.
216213
- ensemble/tests/test_bagging.py::test_classification
217-
- neighbors/tests/test_neighbors.py::test_KNeighborsClassifier_multioutput
218-
- tests/test_common.py::test_estimators[KNeighborsClassifier()-check_classifier_data_not_an_array]
219-
- tests/test_common.py::test_estimators[KNeighborsClassifier()-check_dont_overwrite_parameters]
220-
- tests/test_common.py::test_estimators[NearestNeighbors()-check_dont_overwrite_parameters]
221-
- tests/test_common.py::test_pandas_column_name_consistency[KNeighborsClassifier()]
222-
- tests/test_common.py::test_pandas_column_name_consistency[NearestNeighbors()]
223214

224215
# --------------------------------------------------------
225216
# No need to test for daal4py patching

onedal/neighbors/neighbors.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -157,15 +157,15 @@ def _validate_n_classes(self):
157157
def _fit(self, X, y, queue):
158158
self._onedal_model = None
159159
self._tree = None
160-
self.shape = None
160+
self._shape = None
161161
self.classes_ = None
162162
self.effective_metric_ = getattr(self, 'effective_metric_', self.metric)
163163
self.effective_metric_params_ = getattr(
164164
self, 'effective_metric_params_', self.metric_params)
165165

166166
if y is not None or self.requires_y:
167167
X, y = super()._validate_data(X, y, dtype=[np.float64, np.float32])
168-
self.shape = y.shape
168+
self._shape = y.shape
169169

170170
if _is_classifier(self) or _is_regressor(self):
171171
if y.ndim == 1 or y.ndim == 2 and y.shape[1] == 1:
@@ -186,7 +186,8 @@ def _fit(self, X, y, queue):
186186
if not self.outputs_2d_:
187187
self.classes_ = self.classes_[0]
188188
self._y = self._y.ravel()
189-
self._validate_n_classes()
189+
if _is_classifier(self):
190+
self._validate_n_classes()
190191
else:
191192
self._y = y
192193
else:
@@ -212,12 +213,12 @@ def _fit(self, X, y, queue):
212213
self.algorithm,
213214
self.n_samples_fit_, self.n_features_in_)
214215

215-
if _is_classifier(self) and y.dtype != X.dtype:
216+
if (_is_classifier(self) or _is_regressor(self)) and y.dtype != X.dtype:
216217
y = self._validate_targets(self._y, X.dtype).reshape((-1, 1))
217218
result = self._onedal_fit(X, y, queue)
218219

219220
if y is not None and _is_regressor(self):
220-
self._y = y if self.shape is None else y.reshape(self.shape)
221+
self._y = y if self._shape is None else y.reshape(self._shape)
221222

222223
self._onedal_model = result.model
223224
result = self

sklearnex/neighbors/knn_classification.py

Lines changed: 84 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ def __init__(self, n_neighbors=5, *,
110110
n_jobs=n_jobs, **kwargs)
111111

112112
def fit(self, X, y):
113+
if Version(sklearn_version) >= Version("1.0"):
114+
self._check_feature_names(X, reset=True)
113115
if self.metric_params is not None and 'p' in self.metric_params:
114116
if self.p is not None:
115117
warnings.warn("Parameter p is found in metric_params. "
@@ -224,6 +226,8 @@ def fit(self, X, y):
224226
@wrap_output_data
225227
def predict(self, X):
226228
check_is_fitted(self)
229+
if Version(sklearn_version) >= Version("1.0"):
230+
self._check_feature_names(X, reset=False)
227231
return dispatch(self, 'neighbors.KNeighborsClassifier.predict', {
228232
'onedal': self.__class__._onedal_predict,
229233
'sklearn': sklearn_KNeighborsClassifier.predict,
@@ -232,6 +236,8 @@ def predict(self, X):
232236
@wrap_output_data
233237
def predict_proba(self, X):
234238
check_is_fitted(self)
239+
if Version(sklearn_version) >= Version("1.0"):
240+
self._check_feature_names(X, reset=False)
235241
return dispatch(self, 'neighbors.KNeighborsClassifier.predict_proba', {
236242
'onedal': self.__class__._onedal_predict_proba,
237243
'sklearn': sklearn_KNeighborsClassifier.predict_proba,
@@ -240,6 +246,8 @@ def predict_proba(self, X):
240246
@wrap_output_data
241247
def kneighbors(self, X=None, n_neighbors=None, return_distance=True):
242248
check_is_fitted(self)
249+
if Version(sklearn_version) >= Version("1.0"):
250+
self._check_feature_names(X, reset=False)
243251
return dispatch(self, 'neighbors.KNeighborsClassifier.kneighbors', {
244252
'onedal': self.__class__._onedal_kneighbors,
245253
'sklearn': sklearn_KNeighborsClassifier.kneighbors,
@@ -267,85 +275,98 @@ def radius_neighbors(self, X=None, radius=None, return_distance=True,
267275

268276
def _onedal_gpu_supported(self, method_name, *data):
269277
X_incorrect_type = isinstance(data[0], (KDTree, BallTree, sklearn_NeighborsBase))
270-
if not X_incorrect_type:
271-
if self._fit_method in ['auto', 'ball_tree']:
272-
condition = self.n_neighbors is not None and \
273-
self.n_neighbors >= self.n_samples_fit_ // 2
274-
if self.n_features_in_ > 11 or condition:
275-
result_method = 'brute'
276-
else:
277-
if self.metric in ['euclidean']:
278-
result_method = 'kd_tree'
279-
else:
280-
result_method = 'brute'
278+
279+
if X_incorrect_type:
280+
return False
281+
282+
if self._fit_method in ['auto', 'ball_tree']:
283+
condition = self.n_neighbors is not None and \
284+
self.n_neighbors >= self.n_samples_fit_ // 2
285+
if self.n_features_in_ > 11 or condition:
286+
result_method = 'brute'
281287
else:
282-
result_method = self._fit_method
283-
if method_name == 'neighbors.KNeighborsClassifier.fit':
284-
if X_incorrect_type:
285-
return False
286-
is_sparse = sp.isspmatrix(data[0])
287-
class_count = None
288+
if self.effective_metric_ in ['euclidean']:
289+
result_method = 'kd_tree'
290+
else:
291+
result_method = 'brute'
292+
else:
293+
result_method = self._fit_method
294+
295+
is_sparse = sp.isspmatrix(data[0])
296+
is_single_output = False
297+
class_count = 1
298+
if len(data) > 1 or hasattr(self, '_onedal_estimator'):
299+
# Needed to check for multioutput; might add overhead
288300
if len(data) > 1:
289-
class_count = len(np.unique(data[1]))
290-
# To check multioutput, might be overhead
291301
y = np.asarray(data[1])
292-
is_single_output = y.ndim == 1 or y.ndim == 2 and y.shape[1] == 1
293-
return result_method in ['brute'] and \
294-
self.effective_metric_ in ['manhattan',
295-
'minkowski',
296-
'euclidean',
297-
'chebyshev',
298-
'cosine'] and \
299-
class_count >= 2 and \
300-
not is_sparse and \
301-
is_single_output
302+
class_count = len(np.unique(y))
303+
if hasattr(self, '_onedal_estimator'):
304+
y = self._onedal_estimator._y
305+
is_single_output = y.ndim == 1 or y.ndim == 2 and y.shape[1] == 1
306+
is_valid_for_brute = result_method in ['brute'] and \
307+
self.effective_metric_ in ['manhattan',
308+
'minkowski',
309+
'euclidean']
310+
is_valid_weights = self.weights in ['uniform', "distance"]
311+
main_condition = is_valid_for_brute and not is_sparse and \
312+
is_single_output and is_valid_weights
313+
314+
if method_name == 'neighbors.KNeighborsClassifier.fit':
315+
return main_condition and class_count >= 2
302316
if method_name in ['neighbors.KNeighborsClassifier.predict',
303317
'neighbors.KNeighborsClassifier.predict_proba',
304318
'neighbors.KNeighborsClassifier.kneighbors']:
305-
return hasattr(self, '_onedal_estimator') and not sp.isspmatrix(data[0])
319+
return main_condition and hasattr(self, '_onedal_estimator')
306320
raise RuntimeError(f'Unknown method {method_name} in {self.__class__.__name__}')
307321

308322
def _onedal_cpu_supported(self, method_name, *data):
309323
X_incorrect_type = isinstance(data[0], (KDTree, BallTree, sklearn_NeighborsBase))
310-
if not X_incorrect_type:
311-
if self._fit_method in ['auto', 'ball_tree']:
312-
condition = self.n_neighbors is not None and \
313-
self.n_neighbors >= self.n_samples_fit_ // 2
314-
if self.n_features_in_ > 11 or condition:
315-
result_method = 'brute'
316-
else:
317-
if self.metric in ['euclidean']:
318-
result_method = 'kd_tree'
319-
else:
320-
result_method = 'brute'
324+
325+
if X_incorrect_type:
326+
return False
327+
328+
if self._fit_method in ['auto', 'ball_tree']:
329+
condition = self.n_neighbors is not None and \
330+
self.n_neighbors >= self.n_samples_fit_ // 2
331+
if self.n_features_in_ > 11 or condition:
332+
result_method = 'brute'
321333
else:
322-
result_method = self._fit_method
323-
if method_name == 'neighbors.KNeighborsClassifier.fit':
324-
if X_incorrect_type:
325-
return False
326-
is_sparse = sp.isspmatrix(data[0])
327-
class_count = None
334+
if self.effective_metric_ in ['euclidean']:
335+
result_method = 'kd_tree'
336+
else:
337+
result_method = 'brute'
338+
else:
339+
result_method = self._fit_method
340+
341+
is_sparse = sp.isspmatrix(data[0])
342+
is_single_output = False
343+
class_count = 1
344+
if len(data) > 1 or hasattr(self, '_onedal_estimator'):
345+
# Needed to check for multioutput; might add overhead
328346
if len(data) > 1:
329-
class_count = len(np.unique(data[1]))
330-
# To check multioutput, might be overhead
331347
y = np.asarray(data[1])
332-
is_single_output = y.ndim == 1 or y.ndim == 2 and y.shape[1] == 1
333-
is_valid_for_kd_tree = \
334-
result_method in ['kd_tree'] and self.effective_metric_ in ['euclidean']
335-
is_valid_for_brute = result_method in ['brute'] and \
336-
self.effective_metric_ in ['manhattan',
337-
'minkowski',
338-
'euclidean',
339-
'chebyshev',
340-
'cosine']
341-
return (is_valid_for_kd_tree or is_valid_for_brute) and \
342-
class_count >= 2 and \
343-
not is_sparse and \
344-
is_single_output
348+
class_count = len(np.unique(y))
349+
if hasattr(self, '_onedal_estimator'):
350+
y = self._onedal_estimator._y
351+
is_single_output = y.ndim == 1 or y.ndim == 2 and y.shape[1] == 1
352+
is_valid_for_kd_tree = \
353+
result_method in ['kd_tree'] and self.effective_metric_ in ['euclidean']
354+
is_valid_for_brute = result_method in ['brute'] and \
355+
self.effective_metric_ in ['manhattan',
356+
'minkowski',
357+
'euclidean',
358+
'chebyshev',
359+
'cosine']
360+
is_valid_weights = self.weights in ['uniform', "distance"]
361+
main_condition = (is_valid_for_kd_tree or is_valid_for_brute) and \
362+
not is_sparse and is_single_output and is_valid_weights
363+
364+
if method_name == 'neighbors.KNeighborsClassifier.fit':
365+
return main_condition and class_count >= 2
345366
if method_name in ['neighbors.KNeighborsClassifier.predict',
346367
'neighbors.KNeighborsClassifier.predict_proba',
347368
'neighbors.KNeighborsClassifier.kneighbors']:
348-
return hasattr(self, '_onedal_estimator') and not sp.isspmatrix(data[0])
369+
return main_condition and hasattr(self, '_onedal_estimator')
349370
raise RuntimeError(f'Unknown method {method_name} in {self.__class__.__name__}')
350371

351372
def _onedal_fit(self, X, y, queue=None):
@@ -387,7 +408,6 @@ def _save_attributes(self):
387408
self.n_samples_fit_ = self._onedal_estimator.n_samples_fit_
388409
self._fit_X = self._onedal_estimator._fit_X
389410
self._y = self._onedal_estimator._y
390-
self.shape = self._onedal_estimator.shape
391411
self._fit_method = self._onedal_estimator._fit_method
392412
self.outputs_2d_ = self._onedal_estimator.outputs_2d_
393413
self._tree = self._onedal_estimator._tree

0 commit comments

Comments
 (0)