Skip to content

Commit 3dceb50

Browse files
committed
Implement metric choice in main class
1 parent 85c0474 commit 3dceb50

File tree

3 files changed

+60
-12
lines changed

3 files changed

+60
-12
lines changed

.coverage

0 Bytes
Binary file not shown.

src/radius_clustering/radius_clustering.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def _check_symmetric(self, a: np.ndarray, tol: float = 1e-8) -> bool:
102102
return False
103103
return np.allclose(a, a.T, atol=tol)
104104

105-
def fit(self, X: np.ndarray, y: None = None) -> "RadiusClustering":
105+
def fit(self, X: np.ndarray, y: None = None, metric: str | callable = "euclidean") -> "RadiusClustering":
106106
"""
107107
Fit the MDS clustering model to the input data.
108108
@@ -130,6 +130,35 @@ def fit(self, X: np.ndarray, y: None = None) -> "RadiusClustering":
130130
y : Ignored
131131
Not used, present here for API consistency by convention.
132132
133+
metric : str | callable, optional (default="euclidean")
134+
The metric to use when computing the distance matrix.
135+
The default is "euclidean".
136+
This should be a valid metric string from
137+
`sklearn.metrics.pairwise_distances` or a callable that computes
138+
the distance between two points.
139+
140+
.. note::
141+
The metric parameter *MUST* be a valid metric string from
142+
`sklearn.metrics.pairwise_distances` or a callable that computes
143+
the distance between two points.
144+
Valid metric strings include :
145+
- "euclidean"
146+
- "manhattan"
147+
- "cosine"
148+
- "minkowski"
149+
- and many more supported by scikit-learn.
150+
please refer to the
151+
`sklearn.metrics.pairwise_distances` documentation for a full list.
152+
153+
.. attention::
154+
If the input is a distance matrix, the metric parameter is ignored.
155+
The distance matrix should be symmetric and square.
156+
157+
.. warning::
158+
If the parameter is a callable, it should :
159+
- Accept two 1D arrays as input.
160+
- Return a single float value representing the distance between the two points.
161+
133162
Returns:
134163
--------
135164
self : object
@@ -157,10 +186,13 @@ def fit(self, X: np.ndarray, y: None = None) -> "RadiusClustering":
157186

158187
# Create dist and adj matrices
159188
if not self._check_symmetric(self.X_checked_):
160-
dist_mat = pairwise_distances(self.X_checked_, metric="euclidean")
189+
dist_mat = pairwise_distances(self.X_checked_, metric=metric)
161190
else:
162191
dist_mat = self.X_checked_
163-
192+
193+
if not self._check_symmetric(dist_mat):
194+
raise ValueError("Input distance matrix must be symmetric. Got a non-symmetric matrix.")
195+
self.dist_mat_ = dist_mat
164196
if not isinstance(self.radius, (float, int)):
165197
raise ValueError("Radius must be a positive float.")
166198
if self.radius <= 0:
@@ -177,15 +209,14 @@ def fit(self, X: np.ndarray, y: None = None) -> "RadiusClustering":
177209
np.uint32
178210
) # Edges in the adjacency matrix
179211
# uint32 is used to use less memory. Max number of features is 2^32-1
180-
self.dist_mat_ = dist_mat
181212

182213
self._clustering()
183214
self._compute_effective_radius()
184215
self._compute_labels()
185216

186217
return self
187218

188-
def fit_predict(self, X: np.ndarray, y: None = None) -> np.ndarray:
219+
def fit_predict(self, X: np.ndarray, y: None = None, metric: str | callable = "euclidean") -> np.ndarray:
189220
"""
190221
Fit the model and return the cluster labels.
191222
@@ -201,6 +232,11 @@ def fit_predict(self, X: np.ndarray, y: None = None) -> np.ndarray:
201232
the distance matrix will be computed.
202233
y : Ignored
203234
Not used, present here for API consistency by convention.
235+
236+
metric : str | callable, optional (default="euclidean")
237+
The metric to use when computing the distance matrix.
238+
The default is "euclidean".
239+
Refer to the `fit` method for more details on valid metrics.
204240
205241
Returns:
206242
--------

tests/test_unit.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from radius_clustering import RadiusClustering
22
import pytest
3+
import numpy as np
34

45
def test_symmetric():
56
"""
67
Test that the RadiusClustering class can handle symmetric distance matrices.
78
"""
8-
import numpy as np
99

1010
# Check 1D array input
1111

@@ -35,12 +35,11 @@ def test_symmetric():
3535
assert not clustering._check_symmetric(X_non_square), "The matrix should not be symmetric."
3636

3737

38-
def test_fit():
38+
def test_fit_distance_matrix():
3939
"""
40-
Test that the RadiusClustering class can fit to a distance matrix and to a feature matrix.
40+
Test that the RadiusClustering class can fit to a distance matrix.
4141
This test checks both the exact and approximate methods of clustering.
4242
"""
43-
import numpy as np
4443

4544
# Create a symmetric distance matrix
4645
X = np.array([[0, 1, 2],
@@ -55,14 +54,27 @@ def test_fit():
5554
assert clustering.nb_edges_ > 0, "There should be edges in the graph."
5655
assert np.array_equal(clustering.X_checked_, clustering.dist_mat_), "X_checked_ should be equal to dist_mat_ because X is a distance matrix."
5756

57+
@pytest.mark.parametrize(
58+
"test_data", [
59+
("euclidean",1.5),
60+
("manhattan", 2.1),
61+
("cosine", 1.0)
62+
]
63+
)
64+
def test_fit_features(test_data):
65+
"""
66+
Test that the RadiusClustering class can fit to feature data.
67+
This test checks both the exact and approximate methods of clustering
68+
and multiple metrics methods.
69+
"""
5870
# Create a feature matrix
5971
X_features = np.array([[0, 1],
6072
[1, 0],
6173
[2, 1]])
74+
metric, radius = test_data
6275

63-
clustering = RadiusClustering(manner="approx", radius=1.5)
64-
clustering.fit(X_features)
65-
76+
clustering = RadiusClustering(manner="approx", radius=radius)
77+
clustering.fit(X_features, metric=metric)
6678
# Check that the labels are assigned correctly
6779
assert len(clustering.labels_) == X_features.shape[0], "Labels length should match number of samples."
6880
assert clustering.nb_edges_ > 0, "There should be edges in the graph."

0 commit comments

Comments
 (0)