Skip to content

Commit c7447ef

Browse files
committed
adding support for custom solvers
1 parent 8c605e1 commit c7447ef

File tree

5 files changed

+159
-16
lines changed

5 files changed

+159
-16
lines changed

.coverage

0 Bytes
Binary file not shown.

README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,14 @@ Radius clustering is a Python package that implements clustering under radius co
1818
- Supports radius-constrained clustering
1919
- Provides options for exact and approximate solutions
2020

21+
## Roadmap
22+
23+
- [ ] Version 2.0.0beta:
24+
- [x] Add support for custom MDS solvers
25+
- [ ] Improve documentation and examples
26+
- [ ] Add an iterative algorithm in both exact and approximate versions, which will allow the package to be used more flexibly, especially when the radius is not known beforehand.
27+
- [ ] Add more examples and tutorials
28+
2129
## Installation
2230

2331
You can install Radius Clustering using pip:

src/radius_clustering/algorithms.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,16 +45,28 @@ def clustering_approx(
4545
This ensures that the seed passed to the C++ code is always an integer,
4646
which is required by the MDS solver, and allows for
4747
reproducibility of the results.
48+
49+
.. note::
50+
This function uses the approximation method to solve the MDS problem.
51+
See [casado]_ for more details.
4852
4953
Parameters:
5054
-----------
5155
n : int
5256
The number of points in the dataset.
53-
54-
Notes:
55-
------
56-
This function uses the approximation method to solve the MDS problem.
57-
See [casado]_ for more details.
57+
edges : np.ndarray
58+
The edges of the graph, flattened into a 1D array.
59+
nb_edges : int
60+
The number of edges in the graph.
61+
random_state : int | None
62+
The random state to use for reproducibility.
63+
If None, a default value of 42 is used.
64+
Returns:
65+
--------
66+
centers : list
67+
A sorted list of the centers of the clusters.
68+
mds_exec_time : float
69+
The execution time of the MDS algorithm in seconds.
5870
"""
5971
result = solve_mds(
6072
n, edges.flatten().astype(np.int32), nb_edges, random_state
@@ -63,7 +75,7 @@ def clustering_approx(
6375
mds_exec_time = result["Time"]
6476
return centers, mds_exec_time
6577

66-
def clustering_exact(n: int, edges: np.ndarray, nb_edges: int) -> None:
78+
def clustering_exact(n: int, edges: np.ndarray, nb_edges: int, seed: None = None) -> None:
6779
"""
6880
Perform exact MDS clustering.
6981
@@ -78,6 +90,20 @@ def clustering_exact(n: int, edges: np.ndarray, nb_edges: int) -> None:
7890
-----------
7991
n : int
8092
The number of points in the dataset.
93+
edges : np.ndarray
94+
The edges of the graph, flattened into a 1D array.
95+
nb_edges : int
96+
The number of edges in the graph.
97+
seed : None
98+
This parameter is not used in the exact method, but it is kept for
99+
compatibility with the approximate method.
100+
101+
Returns:
102+
--------
103+
centers : list
104+
A sorted list of the centers of the clusters.
105+
mds_exec_time : float
106+
The execution time of the MDS algorithm in seconds.
81107
"""
82108
centers, mds_exec_time = py_emos_main(
83109
edges.flatten(), n, nb_edges

src/radius_clustering/radius_clustering.py

Lines changed: 60 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,10 @@ class RadiusClustering(ClusterMixin, BaseEstimator):
7575
"""
7676

7777
_estimator_type = "clusterer"
78+
_algorithms = {
79+
"exact": clustering_exact,
80+
"approx": clustering_approx,
81+
}
7882

7983
def __init__(
8084
self,
@@ -211,7 +215,7 @@ def fit(self, X: np.ndarray, y: None = None, metric: str | callable = "euclidean
211215
np.uint32
212216
) # Edges in the adjacency matrix
213217
# uint32 is used to use less memory. Max number of features is 2^32-1
214-
218+
self.clusterer_ = self._algorithms.get(self.manner, self._algorithms["approx"])
215219
self._clustering()
216220
self._compute_effective_radius()
217221
self._compute_labels()
@@ -253,16 +257,16 @@ def _clustering(self):
253257
Perform the clustering using either the exact or approximate MDS method.
254258
"""
255259
n = self.X_checked_.shape[0]
256-
if self.manner != "exact" and self.manner != "approx":
257-
raise ValueError("Invalid manner. Choose either 'exact' or 'approx'.")
258-
if self.manner == "exact":
259-
self.centers_, self.mds_exec_time_ = clustering_exact(n, self.edges_, self.nb_edges_)
260-
else:
260+
if self.manner not in self._algorithms:
261+
raise ValueError(f"Invalid manner. Please choose in {list(self._algorithms.keys())}.")
262+
if self.clusterer_ == clustering_approx:
261263
if self.random_state is None:
262264
self.random_state = 42
263265
self.random_state_ = check_random_state(self.random_state)
264266
seed = self.random_state_.randint(np.iinfo(np.int32).max)
265-
self.centers_, self.mds_exec_time_ = clustering_approx(n, self.edges_, self.nb_edges_, seed)
267+
else:
268+
seed = None
269+
self.centers_, self.mds_exec_time_ = self.clusterer_(n, self.edges_, self.nb_edges_, seed)
266270

267271
def _compute_effective_radius(self):
268272
"""
@@ -282,3 +286,52 @@ def _compute_labels(self):
282286

283287
min_dist = np.min(distances, axis=1)
284288
self.labels_[min_dist > self.radius] = -1
289+
290+
def set_solver(self, solver: callable) -> None:
291+
"""
292+
Set a custom solver for resolving the MDS problem.
293+
This method allows users to replace the default MDS solver with a custom one.
294+
295+
.. important::
296+
The custom solver must accept the same parameters as the default solvers
297+
and return a tuple containing the cluster centers and the execution time.
298+
For example, it should have the following signature:
299+
```python
300+
def custom_solver(
301+
n: int,
302+
edges: np.ndarray,
303+
nb_edges: int,
304+
random_state: int | None = None
305+
) -> tuple[list, float]:
306+
# Custom implementation details
307+
centers = [...]
308+
exec_time = ...
309+
# Return the centers and execution time
310+
return centers, exec_time
311+
```
312+
313+
Parameters:
314+
----------
315+
solver : callable
316+
The custom solver function to use for MDS clustering.
317+
It should accept the same parameters as the default solvers
318+
and return a tuple containing the cluster centers and the execution time.
319+
320+
Raises:
321+
-------
322+
ValueError
323+
If the provided solver is not callable, or if calling it on a small test graph raises an exception (e.g. because it does not have the correct signature).
324+
"""
325+
if not callable(solver):
326+
raise ValueError("The provided solver must be callable.")
327+
328+
# Check if the solver has the correct signature
329+
try:
330+
n = 3
331+
edges = np.array([[0, 1], [1, 2], [2, 0]])
332+
nb_edges = edges.shape[0]
333+
solver(n, edges, nb_edges, random_state=None)
334+
except Exception as e:
335+
raise ValueError(f"The provided solver does not have the correct signature: {e}") from e
336+
self.manner = "custom"
337+
self._algorithms["custom"] = solver

tests/test_unit.py

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,10 @@ def test_radius_clustering_invalid_manner():
8484
"""
8585
Test that an error is raised when an invalid manner is provided.
8686
"""
87-
with pytest.raises(ValueError, match="Invalid manner. Choose either 'exact' or 'approx'."):
87+
with pytest.raises(ValueError):
8888
RadiusClustering(manner="invalid", radius=1.43).fit([[0, 1], [1, 0], [2, 1]])
8989

90-
with pytest.raises(ValueError, match="Invalid manner. Choose either 'exact' or 'approx'."):
90+
with pytest.raises(ValueError):
9191
RadiusClustering(manner="", radius=1.43).fit([[0, 1], [1, 0], [2, 1]])
9292

9393

@@ -102,4 +102,60 @@ def test_radius_clustering_invalid_radius():
102102
RadiusClustering(manner="approx", radius=0.0).fit([[0, 1], [1, 0], [2, 1]])
103103

104104
with pytest.raises(ValueError, match="Radius must be a positive float."):
105-
RadiusClustering(manner="exact", radius="invalid").fit([[0, 1], [1, 0], [2, 1]])
105+
RadiusClustering(manner="exact", radius="invalid").fit([[0, 1], [1, 0], [2, 1]])
106+
107+
def test_radius_clustering_fit_without_data():
108+
"""
109+
Test that an error is raised when fitting without data.
110+
"""
111+
clustering = RadiusClustering(manner="exact", radius=1.5)
112+
with pytest.raises(ValueError):
113+
clustering.fit(None)
114+
115+
def test_radius_clustering_new_clusterer():
116+
"""
117+
Test that a custom clusterer can be set within the RadiusClustering class.
118+
"""
119+
def custom_clusterer(n, edges, nb_edges, random_state=None):
120+
# A mock custom clusterer that returns a fixed set of centers
121+
# and a fixed execution time
122+
return [0, 1], 0.1
123+
clustering = RadiusClustering(manner="exact", radius=1.5)
124+
# Set the custom clusterer
125+
assert hasattr(clustering, 'set_solver'), "RadiusClustering should have a set_solver method."
126+
assert callable(clustering.set_solver), "set_solver should be callable."
127+
clustering.set_solver(custom_clusterer)
128+
# Fit the clustering with the custom clusterer
129+
X = np.array([[0, 1],
130+
[1, 0],
131+
[2, 1]])
132+
clustering.fit(X)
133+
assert clustering.clusterer_ == custom_clusterer, "The custom clusterer should be set correctly."
134+
# Check that the labels are assigned correctly
135+
assert len(clustering.labels_) == X.shape[0], "Labels length should match number of samples."
136+
assert clustering.nb_edges_ > 0, "There should be edges in the graph."
137+
assert clustering.centers_ == [0, 1], "The centers should match the custom clusterer's output."
138+
assert clustering.mds_exec_time_ == 0.1, "The MDS execution time should match the custom clusterer's output."
139+
140+
def test_invalid_clusterer():
141+
"""
142+
Test that an error is raised when an invalid clusterer is set.
143+
"""
144+
clustering = RadiusClustering(manner="exact", radius=1.5)
145+
with pytest.raises(ValueError, match="The provided solver must be callable."):
146+
clustering.set_solver("not_a_callable")
147+
148+
with pytest.raises(ValueError, match="The provided solver must be callable."):
149+
clustering.set_solver(12345) # Not a callable
150+
with pytest.raises(ValueError, match="The provided solver must be callable."):
151+
clustering.set_solver(None)
152+
153+
def invalid_signature():
154+
return [0, 1], 0.1
155+
156+
with pytest.raises(ValueError):
157+
clustering.set_solver(invalid_signature)
158+
def invalid_clusterer(n, edges, nb_edges):
159+
return [0, 1], 0.1
160+
with pytest.raises(ValueError):
161+
clustering.set_solver(invalid_clusterer)

0 commit comments

Comments
 (0)