Skip to content

Commit d17bb34

Browse files
committed
fix: try it again
1 parent 325753c commit d17bb34

File tree

2 files changed

+73
-5
lines changed

2 files changed

+73
-5
lines changed

sklearnex/neighbors/common.py

Lines changed: 72 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,81 @@
3131
from .._utils import PatchingConditionsChain
3232
from ..base import oneDALEstimator
3333
from ..utils._array_api import get_namespace
34-
from ..utils.validation import check_feature_names
3534

3635

3736
class KNeighborsDispatchingBase(oneDALEstimator):
37+
38+
def _parse_auto_method(self, method, n_samples, n_features):
39+
"""Parse auto method selection for neighbors algorithm."""
40+
result_method = method
41+
42+
if method in ["auto", "ball_tree"]:
43+
condition = (
44+
self.n_neighbors is not None and self.n_neighbors >= n_samples // 2
45+
)
46+
if self.metric == "precomputed" or n_features > 15 or condition:
47+
result_method = "brute"
48+
else:
49+
if self.metric == "euclidean":
50+
result_method = "kd_tree"
51+
else:
52+
result_method = "brute"
53+
54+
return result_method
55+
56+
def _get_weights(self, dist, weights):
57+
"""Get weights for neighbors based on distance and weights parameter."""
58+
if weights in (None, "uniform"):
59+
return None
60+
if weights == "distance":
61+
# if user attempts to classify a point that was zero distance from one
62+
# or more training points, those training points are weighted as 1.0
63+
# and the other points as 0.0
64+
if dist.dtype is np.dtype(object):
65+
for point_dist_i, point_dist in enumerate(dist):
66+
# check if point_dist is iterable
67+
# (ex: RadiusNeighborClassifier.predict may set an element of
68+
# dist to 1e-6 to represent an 'outlier')
69+
if hasattr(point_dist, "__contains__") and 0.0 in point_dist:
70+
dist[point_dist_i] = point_dist == 0.0
71+
else:
72+
dist[point_dist_i] = 1.0 / point_dist
73+
else:
74+
with np.errstate(divide="ignore"):
75+
dist = 1.0 / dist
76+
inf_mask = np.isinf(dist)
77+
inf_row = np.any(inf_mask, axis=1)
78+
dist[inf_row] = inf_mask[inf_row]
79+
return dist
80+
elif callable(weights):
81+
return weights(dist)
82+
else:
83+
raise ValueError(
84+
"weights not recognized: should be 'uniform', "
85+
"'distance', or a callable function"
86+
)
87+
88+
def _validate_targets(self, y, dtype):
89+
"""Validate and convert target values."""
90+
from onedal.utils.validation import _column_or_1d
91+
arr = _column_or_1d(y, warn=True)
92+
93+
try:
94+
return arr.astype(dtype, copy=False)
95+
except ValueError:
96+
return arr
97+
98+
def _validate_n_classes(self):
99+
"""Validate that we have at least 2 classes for classification."""
100+
length = 0 if self.classes_ is None else len(self.classes_)
101+
if length < 2:
102+
raise ValueError(
103+
f"The number of classes has to be greater than one; got {length}"
104+
)
38105
def _fit_validation(self, X, y=None):
39106
if sklearn_check_version("1.2"):
40107
self._validate_params()
41-
check_feature_names(self, X, reset=True)
108+
42109
if self.metric_params is not None and "p" in self.metric_params:
43110
if self.p is not None:
44111
warnings.warn(
@@ -67,8 +134,9 @@ def _fit_validation(self, X, y=None):
67134
self.effective_metric_ = "chebyshev"
68135

69136
if not isinstance(X, (KDTree, BallTree, _sklearn_NeighborsBase)):
137+
xp, _ = get_namespace(X)
70138
self._fit_X = _check_array(
71-
X, dtype=[np.float64, np.float32], accept_sparse=True
139+
X, dtype=[xp.float64, xp.float32], accept_sparse=True
72140
)
73141
self.n_samples_fit_ = _num_samples(self._fit_X)
74142
self.n_features_in_ = _num_features(self._fit_X)
@@ -310,4 +378,4 @@ def kneighbors_graph(self, X=None, n_neighbors=None, mode="connectivity"):
310378

311379
return kneighbors_graph
312380

313-
kneighbors_graph.__doc__ = KNeighborsMixin.kneighbors_graph.__doc__
381+
kneighbors_graph.__doc__ = KNeighborsMixin.kneighbors_graph.__doc__

sklearnex/tests/test_common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -601,4 +601,4 @@ def test_estimator(estimator, method, design_pattern, estimator_trace):
601601
if key in _DESIGN_RULE_VIOLATIONS:
602602
pytest.xfail(_DESIGN_RULE_VIOLATIONS[key])
603603
else:
604-
raise
604+
raise

0 commit comments

Comments
 (0)