3131from .._utils import PatchingConditionsChain
3232from ..base import oneDALEstimator
3333from ..utils ._array_api import get_namespace
34- from ..utils .validation import check_feature_names
3534
3635
3736class KNeighborsDispatchingBase (oneDALEstimator ):
37+
38+ def _parse_auto_method (self , method , n_samples , n_features ):
39+ """Parse auto method selection for neighbors algorithm."""
40+ result_method = method
41+
42+ if method in ["auto" , "ball_tree" ]:
43+ condition = (
44+ self .n_neighbors is not None and self .n_neighbors >= n_samples // 2
45+ )
46+ if self .metric == "precomputed" or n_features > 15 or condition :
47+ result_method = "brute"
48+ else :
49+ if self .metric == "euclidean" :
50+ result_method = "kd_tree"
51+ else :
52+ result_method = "brute"
53+
54+ return result_method
55+
56+ def _get_weights (self , dist , weights ):
57+ """Get weights for neighbors based on distance and weights parameter."""
58+ if weights in (None , "uniform" ):
59+ return None
60+ if weights == "distance" :
61+ # if user attempts to classify a point that was zero distance from one
62+ # or more training points, those training points are weighted as 1.0
63+ # and the other points as 0.0
64+ if dist .dtype is np .dtype (object ):
65+ for point_dist_i , point_dist in enumerate (dist ):
66+ # check if point_dist is iterable
67+ # (ex: RadiusNeighborClassifier.predict may set an element of
68+ # dist to 1e-6 to represent an 'outlier')
69+ if hasattr (point_dist , "__contains__" ) and 0.0 in point_dist :
70+ dist [point_dist_i ] = point_dist == 0.0
71+ else :
72+ dist [point_dist_i ] = 1.0 / point_dist
73+ else :
74+ with np .errstate (divide = "ignore" ):
75+ dist = 1.0 / dist
76+ inf_mask = np .isinf (dist )
77+ inf_row = np .any (inf_mask , axis = 1 )
78+ dist [inf_row ] = inf_mask [inf_row ]
79+ return dist
80+ elif callable (weights ):
81+ return weights (dist )
82+ else :
83+ raise ValueError (
84+ "weights not recognized: should be 'uniform', "
85+ "'distance', or a callable function"
86+ )
87+
88+ def _validate_targets (self , y , dtype ):
89+ """Validate and convert target values."""
90+ from onedal .utils .validation import _column_or_1d
91+ arr = _column_or_1d (y , warn = True )
92+
93+ try :
94+ return arr .astype (dtype , copy = False )
95+ except ValueError :
96+ return arr
97+
98+ def _validate_n_classes (self ):
99+ """Validate that we have at least 2 classes for classification."""
100+ length = 0 if self .classes_ is None else len (self .classes_ )
101+ if length < 2 :
102+ raise ValueError (
103+ f"The number of classes has to be greater than one; got { length } "
104+ )
38105 def _fit_validation (self , X , y = None ):
39106 if sklearn_check_version ("1.2" ):
40107 self ._validate_params ()
41- check_feature_names ( self , X , reset = True )
108+
42109 if self .metric_params is not None and "p" in self .metric_params :
43110 if self .p is not None :
44111 warnings .warn (
@@ -67,8 +134,9 @@ def _fit_validation(self, X, y=None):
67134 self .effective_metric_ = "chebyshev"
68135
69136 if not isinstance (X , (KDTree , BallTree , _sklearn_NeighborsBase )):
137+ xp , _ = get_namespace (X )
70138 self ._fit_X = _check_array (
71- X , dtype = [np .float64 , np .float32 ], accept_sparse = True
139+ X , dtype = [xp .float64 , xp .float32 ], accept_sparse = True
72140 )
73141 self .n_samples_fit_ = _num_samples (self ._fit_X )
74142 self .n_features_in_ = _num_features (self ._fit_X )
@@ -310,4 +378,4 @@ def kneighbors_graph(self, X=None, n_neighbors=None, mode="connectivity"):
310378
311379 return kneighbors_graph
312380
313- kneighbors_graph .__doc__ = KNeighborsMixin .kneighbors_graph .__doc__
381+ kneighbors_graph .__doc__ = KNeighborsMixin .kneighbors_graph .__doc__
0 commit comments