3535
3636
3737class KNeighborsDispatchingBase (oneDALEstimator ):
38-
39- def _parse_auto_method (self , method , n_samples , n_features ):
40- """Parse auto method selection for neighbors algorithm."""
41- result_method = method
42-
43- if method in ["auto" , "ball_tree" ]:
44- condition = (
45- self .n_neighbors is not None and self .n_neighbors >= n_samples // 2
46- )
47- if self .metric == "precomputed" or n_features > 15 or condition :
48- result_method = "brute"
49- else :
50- if self .metric == "euclidean" :
51- result_method = "kd_tree"
52- else :
53- result_method = "brute"
54-
55- return result_method
56-
57- def _get_weights (self , dist , weights ):
58- """Get weights for neighbors based on distance and weights parameter."""
59- if weights in (None , "uniform" ):
60- return None
61- if weights == "distance" :
62- # if user attempts to classify a point that was zero distance from one
63- # or more training points, those training points are weighted as 1.0
64- # and the other points as 0.0
65- if dist .dtype is np .dtype (object ):
66- for point_dist_i , point_dist in enumerate (dist ):
67- # check if point_dist is iterable
68- # (ex: RadiusNeighborClassifier.predict may set an element of
69- # dist to 1e-6 to represent an 'outlier')
70- if hasattr (point_dist , "__contains__" ) and 0.0 in point_dist :
71- dist [point_dist_i ] = point_dist == 0.0
72- else :
73- dist [point_dist_i ] = 1.0 / point_dist
74- else :
75- with np .errstate (divide = "ignore" ):
76- dist = 1.0 / dist
77- inf_mask = np .isinf (dist )
78- inf_row = np .any (inf_mask , axis = 1 )
79- dist [inf_row ] = inf_mask [inf_row ]
80- return dist
81- elif callable (weights ):
82- return weights (dist )
83- else :
84- raise ValueError (
85- "weights not recognized: should be 'uniform', "
86- "'distance', or a callable function"
87- )
88-
89- def _validate_targets (self , y , dtype ):
90- """Validate and convert target values."""
91- from onedal .utils .validation import _column_or_1d
92- arr = _column_or_1d (y , warn = True )
93-
94- try :
95- return arr .astype (dtype , copy = False )
96- except ValueError :
97- return arr
98-
99- def _validate_n_classes (self ):
100- """Validate that we have at least 2 classes for classification."""
101- length = 0 if self .classes_ is None else len (self .classes_ )
102- if length < 2 :
103- raise ValueError (
104- f"The number of classes has to be greater than one; got { length } "
105- )
10638 def _fit_validation (self , X , y = None ):
10739 if sklearn_check_version ("1.2" ):
10840 self ._validate_params ()
@@ -378,4 +310,4 @@ def kneighbors_graph(self, X=None, n_neighbors=None, mode="connectivity"):
378310
379311 return kneighbors_graph
380312
381- kneighbors_graph .__doc__ = KNeighborsMixin .kneighbors_graph .__doc__
313+ kneighbors_graph .__doc__ = KNeighborsMixin .kneighbors_graph .__doc__
0 commit comments