@@ -102,7 +102,7 @@ def _check_symmetric(self, a: np.ndarray, tol: float = 1e-8) -> bool:
102102 return False
103103 return np .allclose (a , a .T , atol = tol )
104104
105- def fit (self , X : np .ndarray , y : None = None ) -> "RadiusClustering" :
105+ def fit (self , X : np .ndarray , y : None = None , metric : str | callable = "euclidean" ) -> "RadiusClustering" :
106106 """
107107 Fit the MDS clustering model to the input data.
108108
@@ -130,6 +130,35 @@ def fit(self, X: np.ndarray, y: None = None) -> "RadiusClustering":
130130 y : Ignored
131131 Not used, present here for API consistency by convention.
132132
133+ metric : str | callable, optional (default="euclidean")
134+ The metric to use when computing the distance matrix.
135+ The default is "euclidean".
136+ This should be a valid metric string from
137+ `sklearn.metrics.pairwise_distances` or a callable that computes
138+ the distance between two points.
139+
140+ .. note::
141+ The metric parameter *MUST* be a valid metric string from
142+ `sklearn.metrics.pairwise_distances` or a callable that computes
143+ the distance between two points.
144+ Valid metric strings include:
145+ - "euclidean"
146+ - "manhattan"
147+ - "cosine"
148+ - "minkowski"
149+ - and many more supported by scikit-learn.
150+ Please refer to the
151+ `sklearn.metrics.pairwise_distances` documentation for a full list.
152+
153+ .. attention::
154+ If the input is a distance matrix, the metric parameter is ignored.
155+ The distance matrix should be symmetric and square.
156+
157+ .. warning::
158+ If the parameter is a callable, it should:
159+ - Accept two 1D arrays as input.
160+ - Return a single float value representing the distance between the two points.
161+
133162 Returns:
134163 --------
135164 self : object
@@ -157,10 +186,13 @@ def fit(self, X: np.ndarray, y: None = None) -> "RadiusClustering":
157186
158187 # Create dist and adj matrices
159188 if not self ._check_symmetric (self .X_checked_ ):
160- dist_mat = pairwise_distances (self .X_checked_ , metric = "euclidean" )
189+ dist_mat = pairwise_distances (self .X_checked_ , metric = metric )
161190 else :
162191 dist_mat = self .X_checked_
163-
192+
193+ if not self ._check_symmetric (dist_mat ):
194+ raise ValueError ("Input distance matrix must be symmetric. Got a non-symmetric matrix." )
195+ self .dist_mat_ = dist_mat
164196 if not isinstance (self .radius , (float , int )):
165197 raise ValueError ("Radius must be a positive float." )
166198 if self .radius <= 0 :
@@ -177,15 +209,14 @@ def fit(self, X: np.ndarray, y: None = None) -> "RadiusClustering":
177209 np .uint32
178210 ) # Edges in the adjacency matrix
179211 # uint32 is used to use less memory. Max number of features is 2^32-1
180- self .dist_mat_ = dist_mat
181212
182213 self ._clustering ()
183214 self ._compute_effective_radius ()
184215 self ._compute_labels ()
185216
186217 return self
187218
188- def fit_predict (self , X : np .ndarray , y : None = None ) -> np .ndarray :
219+ def fit_predict (self , X : np .ndarray , y : None = None , metric : str | callable = "euclidean" ) -> np .ndarray :
189220 """
190221 Fit the model and return the cluster labels.
191222
@@ -201,6 +232,11 @@ def fit_predict(self, X: np.ndarray, y: None = None) -> np.ndarray:
201232 the distance matrix will be computed.
202233 y : Ignored
203234 Not used, present here for API consistency by convention.
235+
236+ metric : str | callable, optional (default="euclidean")
237+ The metric to use when computing the distance matrix.
238+ The default is "euclidean".
239+ Refer to the `fit` method for more details on valid metrics.
204240
205241 Returns:
206242 --------
0 commit comments