Skip to content

Commit 82fc26d

Browse files
committed
refactor: move/delete some methods in neighbors.py
1 parent 7317aef commit 82fc26d

File tree

4 files changed

+101
-11
lines changed

4 files changed

+101
-11
lines changed

sklearnex/neighbors/common.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,74 @@
3535

3636

3737
class KNeighborsDispatchingBase(oneDALEstimator):
38+
39+
def _parse_auto_method(self, method, n_samples, n_features):
40+
"""Parse auto method selection for neighbors algorithm."""
41+
result_method = method
42+
43+
if method in ["auto", "ball_tree"]:
44+
condition = (
45+
self.n_neighbors is not None and self.n_neighbors >= n_samples // 2
46+
)
47+
if self.metric == "precomputed" or n_features > 15 or condition:
48+
result_method = "brute"
49+
else:
50+
if self.metric == "euclidean":
51+
result_method = "kd_tree"
52+
else:
53+
result_method = "brute"
54+
55+
return result_method
56+
57+
def _get_weights(self, dist, weights):
58+
"""Get weights for neighbors based on distance and weights parameter."""
59+
if weights in (None, "uniform"):
60+
return None
61+
if weights == "distance":
62+
# if user attempts to classify a point that was zero distance from one
63+
# or more training points, those training points are weighted as 1.0
64+
# and the other points as 0.0
65+
if dist.dtype is np.dtype(object):
66+
for point_dist_i, point_dist in enumerate(dist):
67+
# check if point_dist is iterable
68+
# (ex: RadiusNeighborClassifier.predict may set an element of
69+
# dist to 1e-6 to represent an 'outlier')
70+
if hasattr(point_dist, "__contains__") and 0.0 in point_dist:
71+
dist[point_dist_i] = point_dist == 0.0
72+
else:
73+
dist[point_dist_i] = 1.0 / point_dist
74+
else:
75+
with np.errstate(divide="ignore"):
76+
dist = 1.0 / dist
77+
inf_mask = np.isinf(dist)
78+
inf_row = np.any(inf_mask, axis=1)
79+
dist[inf_row] = inf_mask[inf_row]
80+
return dist
81+
elif callable(weights):
82+
return weights(dist)
83+
else:
84+
raise ValueError(
85+
"weights not recognized: should be 'uniform', "
86+
"'distance', or a callable function"
87+
)
88+
89+
def _validate_targets(self, y, dtype):
90+
"""Validate and convert target values."""
91+
from onedal.utils.validation import _column_or_1d
92+
arr = _column_or_1d(y, warn=True)
93+
94+
try:
95+
return arr.astype(dtype, copy=False)
96+
except ValueError:
97+
return arr
98+
99+
def _validate_n_classes(self):
100+
"""Validate that we have at least 2 classes for classification."""
101+
length = 0 if self.classes_ is None else len(self.classes_)
102+
if length < 2:
103+
raise ValueError(
104+
f"The number of classes has to be greater than one; got {length}"
105+
)
38106
def _fit_validation(self, X, y=None):
39107
if sklearn_check_version("1.2"):
40108
self._validate_params()

sklearnex/neighbors/knn_classification.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515
# ===============================================================================
1616

17+
import numpy as np
1718
from sklearn.metrics import accuracy_score
1819
from sklearn.neighbors._classification import (
1920
KNeighborsClassifier as _sklearn_KNeighborsClassifier,
@@ -24,6 +25,8 @@
2425
from daal4py.sklearn._utils import sklearn_check_version
2526
from daal4py.sklearn.utils.validation import get_requires_y_tag
2627
from onedal.neighbors import KNeighborsClassifier as onedal_KNeighborsClassifier
28+
from onedal.utils.validation import _check_X_y, _check_classification_targets, _check_n_features
29+
from onedal.common._estimator_checks import _is_classifier
2730

2831
from .._device_offload import dispatch, wrap_output_data
2932
from ..utils.validation import check_feature_names
@@ -141,16 +144,20 @@ def _onedal_fit(self, X, y, queue=None):
141144
onedal_params = {
142145
"n_neighbors": self.n_neighbors,
143146
"weights": self.weights,
144-
"algorithm": self.algorithm,
147+
"algorithm": self._fit_method, # Use parsed method
145148
"metric": self.effective_metric_,
146-
"p": self.effective_metric_params_["p"],
149+
"p": self.effective_metric_params_["p"] if self.effective_metric_params_ else 2,
147150
}
148151

149152
self._onedal_estimator = onedal_KNeighborsClassifier(**onedal_params)
150-
self._onedal_estimator.requires_y = get_requires_y_tag(self)
151153
self._onedal_estimator.effective_metric_ = self.effective_metric_
152154
self._onedal_estimator.effective_metric_params_ = self.effective_metric_params_
153-
self._onedal_estimator.fit(X, y, queue=queue)
155+
self._onedal_estimator._fit_method = self._fit_method
156+
self._onedal_estimator.classes_ = self.classes_
157+
158+
# Prepare y for onedal
159+
fit_y = self._validate_targets(processed_y, X.dtype).reshape((-1, 1))
160+
self._onedal_estimator.fit(X, fit_y, queue=queue)
154161

155162
self._save_attributes()
156163

sklearnex/neighbors/knn_regression.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515
# ==============================================================================
1616

17+
import numpy as np
1718
from sklearn.metrics import r2_score
1819
from sklearn.neighbors._regression import (
1920
KNeighborsRegressor as _sklearn_KNeighborsRegressor,
@@ -24,6 +25,8 @@
2425
from daal4py.sklearn._utils import sklearn_check_version
2526
from daal4py.sklearn.utils.validation import get_requires_y_tag
2627
from onedal.neighbors import KNeighborsRegressor as onedal_KNeighborsRegressor
28+
from onedal.utils.validation import _check_X_y, _check_n_features
29+
from onedal.common._estimator_checks import _is_regressor
2730

2831
from .._device_offload import dispatch, wrap_output_data
2932
from ..utils.validation import check_feature_names
@@ -125,16 +128,23 @@ def _onedal_fit(self, X, y, queue=None):
125128
onedal_params = {
126129
"n_neighbors": self.n_neighbors,
127130
"weights": self.weights,
128-
"algorithm": self.algorithm,
131+
"algorithm": self._fit_method, # Use parsed method
129132
"metric": self.effective_metric_,
130-
"p": self.effective_metric_params_["p"],
133+
"p": self.effective_metric_params_["p"] if self.effective_metric_params_ else 2,
131134
}
132135

133136
self._onedal_estimator = onedal_KNeighborsRegressor(**onedal_params)
134-
self._onedal_estimator.requires_y = get_requires_y_tag(self)
135137
self._onedal_estimator.effective_metric_ = self.effective_metric_
136138
self._onedal_estimator.effective_metric_params_ = self.effective_metric_params_
137-
self._onedal_estimator.fit(X, y, queue=queue)
139+
self._onedal_estimator._fit_method = self._fit_method
140+
141+
# For regression, prepare y data
142+
fit_y = self._validate_targets(y, X.dtype).reshape((-1, 1))
143+
self._onedal_estimator.fit(X, fit_y, queue=queue)
144+
145+
# Reshape y back if needed
146+
if self._shape is not None:
147+
self._y = np.reshape(y, self._shape)
138148

139149
self._save_attributes()
140150

sklearnex/neighbors/knn_unsupervised.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@
1414
# limitations under the License.
1515
# ===============================================================================
1616

17+
import numpy as np
1718
from sklearn.neighbors._unsupervised import NearestNeighbors as _sklearn_NearestNeighbors
1819
from sklearn.utils.validation import _deprecate_positional_args, check_is_fitted
1920

2021
from daal4py.sklearn._n_jobs_support import control_n_jobs
2122
from daal4py.sklearn._utils import sklearn_check_version
2223
from daal4py.sklearn.utils.validation import get_requires_y_tag
2324
from onedal.neighbors import NearestNeighbors as onedal_NearestNeighbors
25+
from onedal.utils.validation import _check_array, _check_n_features
2426

2527
from .._device_offload import dispatch, wrap_output_data
2628
from ..utils.validation import check_feature_names
@@ -131,15 +133,18 @@ def radius_neighbors_graph(
131133
def _onedal_fit(self, X, y=None, queue=None):
132134
onedal_params = {
133135
"n_neighbors": self.n_neighbors,
134-
"algorithm": self.algorithm,
136+
"algorithm": self._fit_method, # Use parsed method
135137
"metric": self.effective_metric_,
136-
"p": self.effective_metric_params_["p"],
138+
"p": self.effective_metric_params_["p"] if self.effective_metric_params_ else 2,
137139
}
138140

139141
self._onedal_estimator = onedal_NearestNeighbors(**onedal_params)
140-
self._onedal_estimator.requires_y = get_requires_y_tag(self)
141142
self._onedal_estimator.effective_metric_ = self.effective_metric_
142143
self._onedal_estimator.effective_metric_params_ = self.effective_metric_params_
144+
self._onedal_estimator._fit_method = self._fit_method
145+
self._onedal_estimator.fit(X, y, queue=queue)
146+
147+
self._save_attributes()
143148
self._onedal_estimator.fit(X, y, queue=queue)
144149

145150
self._save_attributes()

0 commit comments

Comments
 (0)