Skip to content

Commit 59675a6

Browse files
FIX: Fix KNN predictions on X=None (#2821)
* fix predictions on X=None * don't fall back on predict_proba * add comment about misleading log * don't deselect valid conformance tests * limit scope of deselections
1 parent dba369d commit 59675a6

File tree

4 files changed

+20
-8
lines changed

4 files changed

+20
-8
lines changed

deselected_tests.yaml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,6 @@ deselected_tests:
3434
- neighbors/tests/test_neighbors.py::test_nan_euclidean_support[KNeighborsClassifier-params0]
3535
- neighbors/tests/test_neighbors.py::test_nan_euclidean_support[KNeighborsRegressor-params1]
3636
- neighbors/tests/test_neighbors.py::test_nan_euclidean_support[LocalOutlierFactor-params6]
37-
- neighbors/tests/test_neighbors.py::test_neighbor_classifiers_loocv[ball_tree-nn_model0]
38-
- neighbors/tests/test_neighbors.py::test_neighbor_classifiers_loocv[brute-nn_model0]
39-
- neighbors/tests/test_neighbors.py::test_neighbor_classifiers_loocv[kd_tree-nn_model0]
40-
- neighbors/tests/test_neighbors.py::test_neighbor_classifiers_loocv[auto-nn_model0]
4137
# sklearn 1.7 unsupported features
4238
- tests/test_common.py::test_estimators[LinearRegression()-check_sample_weight_equivalence_on_dense_data]
4339

onedal/neighbors/neighbors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,7 @@ def predict_proba(self, X, queue=None):
472472
_y = self._y.reshape((-1, 1))
473473
classes_ = [self.classes_]
474474

475-
n_queries = _num_samples(X)
475+
n_queries = _num_samples(X if X is not None else self._fit_X)
476476

477477
weights = self._get_weights(neigh_dist, self.weights)
478478
if weights is None:

sklearnex/neighbors/common.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,19 @@ def _onedal_supported(self, device, method_name, *data):
156156
patching_status = PatchingConditionsChain(
157157
f"sklearn.neighbors.{class_name}.{method_name}"
158158
)
159+
# TODO: with verbosity enabled, here it would emit a log saying that it fell
160+
# back to sklearn, but internally, sklearn will end up calling 'kneighbors'
161+
# which is overridden in the sklearnex classes, thus it will end up calling
162+
# oneDAL in the end, but the log will say otherwise. Find a way to make the
163+
# log consistent with what happens in practice.
164+
patching_status.and_conditions(
165+
[
166+
(
167+
not (data[0] is None and method_name in ["predict", "score"]),
168+
"Predictions on 'None' data are handled by internal sklearn methods.",
169+
)
170+
]
171+
)
159172
if not patching_status.and_condition(
160173
"radius" not in method_name, "RadiusNeighbors not implemented in sklearnex"
161174
):

sklearnex/neighbors/knn_classification.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@ def fit(self, X, y):
7979
@wrap_output_data
8080
def predict(self, X):
8181
check_is_fitted(self)
82-
check_feature_names(self, X, reset=False)
82+
if X is not None:
83+
check_feature_names(self, X, reset=False)
8384
return dispatch(
8485
self,
8586
"predict",
@@ -93,7 +94,8 @@ def predict(self, X):
9394
@wrap_output_data
9495
def predict_proba(self, X):
9596
check_is_fitted(self)
96-
check_feature_names(self, X, reset=False)
97+
if X is not None:
98+
check_feature_names(self, X, reset=False)
9799
return dispatch(
98100
self,
99101
"predict_proba",
@@ -107,7 +109,8 @@ def predict_proba(self, X):
107109
@wrap_output_data
108110
def score(self, X, y, sample_weight=None):
109111
check_is_fitted(self)
110-
check_feature_names(self, X, reset=False)
112+
if X is not None:
113+
check_feature_names(self, X, reset=False)
111114
return dispatch(
112115
self,
113116
"score",

0 commit comments

Comments
 (0)