@@ -102,7 +102,7 @@ def _check_symmetric(self, a: np.ndarray, tol: float = 1e-8) -> bool:
102
102
return False
103
103
return np .allclose (a , a .T , atol = tol )
104
104
105
- def fit (self , X : np .ndarray , y : None = None ) -> "RadiusClustering" :
105
+ def fit (self , X : np .ndarray , y : None = None , metric : str | callable = "euclidean" ) -> "RadiusClustering" :
106
106
"""
107
107
Fit the MDS clustering model to the input data.
108
108
@@ -130,6 +130,35 @@ def fit(self, X: np.ndarray, y: None = None) -> "RadiusClustering":
130
130
y : Ignored
131
131
Not used, present here for API consistency by convention.
132
132
133
+ metric : str | callable, optional (default="euclidean")
134
+ The metric to use when computing the distance matrix.
135
+ The default is "euclidean".
136
+ This should be a valid metric string from
137
+ `sklearn.metrics.pairwise_distances` or a callable that computes
138
+ the distance between two points.
139
+
140
+ .. note::
141
+ The metric parameter *MUST* be a valid metric string from
142
+ `sklearn.metrics.pairwise_distances` or a callable that computes
143
+ the distance between two points.
144
+ Valid metric strings include:
145
+ - "euclidean"
146
+ - "manhattan"
147
+ - "cosine"
148
+ - "minkowski"
149
+ - and many more supported by scikit-learn.
150
+ Please refer to the
151
+ `sklearn.metrics.pairwise_distances` documentation for a full list.
152
+
153
+ .. attention::
154
+ If the input is a distance matrix, the metric parameter is ignored.
155
+ The distance matrix should be symmetric and square.
156
+
157
+ .. warning::
158
+ If the parameter is a callable, it should:
159
+ - Accept two 1D arrays as input.
160
+ - Return a single float value representing the distance between the two points.
161
+
133
162
Returns:
134
163
--------
135
164
self : object
@@ -157,10 +186,13 @@ def fit(self, X: np.ndarray, y: None = None) -> "RadiusClustering":
157
186
158
187
# Create dist and adj matrices
159
188
if not self ._check_symmetric (self .X_checked_ ):
160
- dist_mat = pairwise_distances (self .X_checked_ , metric = "euclidean" )
189
+ dist_mat = pairwise_distances (self .X_checked_ , metric = metric )
161
190
else :
162
191
dist_mat = self .X_checked_
163
-
192
+
193
+ if not self ._check_symmetric (dist_mat ):
194
+ raise ValueError ("Input distance matrix must be symmetric. Got a non-symmetric matrix." )
195
+ self .dist_mat_ = dist_mat
164
196
if not isinstance (self .radius , (float , int )):
165
197
raise ValueError ("Radius must be a positive float." )
166
198
if self .radius <= 0 :
@@ -177,15 +209,14 @@ def fit(self, X: np.ndarray, y: None = None) -> "RadiusClustering":
177
209
np .uint32
178
210
) # Edges in the adjacency matrix
179
211
# uint32 is used to use less memory. Max number of features is 2^32-1
180
- self .dist_mat_ = dist_mat
181
212
182
213
self ._clustering ()
183
214
self ._compute_effective_radius ()
184
215
self ._compute_labels ()
185
216
186
217
return self
187
218
188
- def fit_predict (self , X : np .ndarray , y : None = None ) -> np .ndarray :
219
+ def fit_predict (self , X : np .ndarray , y : None = None , metric : str | callable = "euclidean" ) -> np .ndarray :
189
220
"""
190
221
Fit the model and return the cluster labels.
191
222
@@ -201,6 +232,11 @@ def fit_predict(self, X: np.ndarray, y: None = None) -> np.ndarray:
201
232
the distance matrix will be computed.
202
233
y : Ignored
203
234
Not used, present here for API consistency by convention.
235
+
236
+ metric : str | callable, optional (default="euclidean")
237
+ The metric to use when computing the distance matrix.
238
+ The default is "euclidean".
239
+ Refer to the `fit` method for more details on valid metrics.
204
240
205
241
Returns:
206
242
--------
0 commit comments