Skip to content

Commit 8c605e1

Browse files
committed
refactoring algorithms + doc enhancement
1 parent 3dceb50 commit 8c605e1

File tree

4 files changed

+116
-76
lines changed

4 files changed

+116
-76
lines changed

.coverage

0 Bytes
Binary file not shown.

docs/source/api.rst

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,19 @@
11
API Reference
22
=============
33

4-
.. automodule:: radius_clustering
4+
This page documents the implementation details of the `radius_clustering` package.
5+
6+
RadiusClustering Class
7+
----------------------
8+
9+
.. autoclass:: radius_clustering.RadiusClustering
10+
:members:
11+
:undoc-members:
12+
:show-inheritance:
13+
14+
Algorithms Module
15+
-----------------
16+
.. automodule:: radius_clustering.algorithms
517
:members:
618
:undoc-members:
719
:show-inheritance:

src/radius_clustering/algorithms.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
"""
2+
This module contains the implementation of the clustering algorithms.
3+
It provides two main functions: `clustering_approx` and `clustering_exact`.
4+
5+
These functions can be replaced in the `RadiusClustering` class
6+
to perform clustering using another algorithm.
7+
8+
.. versionadded:: 2.0.0
9+
Refactoring the structure of the code to separate the clustering algorithms.
10+
This allows for easier maintenance and extensibility of the codebase.
11+
Plus, this allows for the addition of new clustering algorithms
12+
such as `Curgraph` added in this version.
13+
"""
14+
from __future__ import annotations
15+
16+
import numpy as np
17+
18+
from .utils._mds_approx import solve_mds
19+
from .utils._emos import py_emos_main
20+
21+
def clustering_approx(
    n: int, edges: np.ndarray, nb_edges: int,
    random_state: int | None = None) -> tuple[list, float]:
    """
    Perform approximate MDS clustering.

    The edge array is flattened, cast to ``np.int32`` and handed to
    ``solve_mds`` together with ``n``, ``nb_edges`` and ``random_state``.

    .. tip::
        The seed is what makes the results of the approximate method
        reproducible: pass the same integer to get the same solution.

    .. important::
        ``random_state`` is forwarded to the solver unchanged. This function
        neither substitutes a default value nor derives a seed from it.
        When called through `RadiusClustering`, the caller prepares the seed
        beforehand:

        1. If `random_state` is None, a default value of 42 is used.

        2. `check_random_state` builds a `RandomState` instance from it.

        3. The `randint` method of that instance generates the integer
           that is passed here as `random_state`.

        This ensures that the seed passed to the C++ code is always an
        integer, which is required by the MDS solver. Callers using this
        function directly should therefore pass an integer themselves.

    Parameters:
    -----------
    n : int
        The number of points in the dataset.
    edges : np.ndarray
        The edges of the graph. The array is flattened and cast to
        ``np.int32`` before being passed to the solver.
        NOTE(review): presumably an array of vertex-index pairs; confirm
        against the caller.
    nb_edges : int
        The number of edges, passed to the solver as-is.
    random_state : int | None
        The seed forwarded to the solver (see the note above).

    Returns:
    --------
    centers : list
        The sorted content of the solver's ``"solution_set"`` entry.
    mds_exec_time : float
        The solver's ``"Time"`` entry.
        NOTE(review): presumably the solver run time as a float; the type
        and unit are not visible from this module.

    Notes:
    ------
    This function uses the approximation method to solve the MDS problem.
    See [casado]_ for more details.
    """
    result = solve_mds(
        n, edges.flatten().astype(np.int32), nb_edges, random_state
    )
    # sorted() accepts any iterable, so no intermediate list is needed.
    centers = sorted(result["solution_set"])
    mds_exec_time = result["Time"]
    return centers, mds_exec_time
65+
66+
def clustering_exact(n: int, edges: np.ndarray, nb_edges: int) -> tuple:
    """
    Perform exact MDS clustering.

    This function uses the EMOS algorithm to solve the MDS problem.

    .. important::
        The EMOS algorithm is an exact algorithm for solving the MDS problem.
        It is a branch and bound algorithm that uses graph theory tricks
        to efficiently cut the search space. See [jiang]_ for more details.

    Parameters:
    -----------
    n : int
        The number of points in the dataset.
    edges : np.ndarray
        The edges of the graph. The array is flattened before being passed
        to the solver.
        NOTE(review): presumably an array of vertex-index pairs; confirm
        against the caller.
    nb_edges : int
        The number of edges, passed to the solver as-is.

    Returns:
    --------
    centers
        The centers returned by ``py_emos_main``, sorted in place so that
        their order is consistent between runs.
    mds_exec_time
        The second value returned by ``py_emos_main``.
        NOTE(review): presumably the solver run time; the type and unit
        are not visible from this module.
    """
    # Note the argument order: the flattened edges come first for this solver.
    centers, mds_exec_time = py_emos_main(
        edges.flatten(), n, nb_edges
    )
    centers.sort()  # Sort the centers to ensure consistent order
    return centers, mds_exec_time

src/radius_clustering/radius_clustering.py

Lines changed: 17 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@
1818
from sklearn.metrics import pairwise_distances
1919
from sklearn.utils.validation import check_random_state, validate_data
2020

21-
from radius_clustering.utils._emos import py_emos_main
22-
from radius_clustering.utils._mds_approx import solve_mds
21+
from .algorithms import clustering_approx, clustering_exact
2322

2423
DIR_PATH = os.path.dirname(os.path.realpath(__file__))
2524

@@ -53,20 +52,23 @@ class RadiusClustering(ClusterMixin, BaseEstimator):
5352
5453
.. note::
5554
The `random_state_` attribute is not used when the `manner` is set to "exact".
55+
56+
.. versionchanged:: 2.0.0
57+
The `RadiusClustering` class has been refactored.
58+
Clustering algorithms are now separated into their own module
59+
(`algorithms.py`) to improve maintainability and extensibility.
5660
5761
.. versionadded:: 1.3.0
58-
The *random_state* parameter was added to allow reproducibility in
59-
the approximate method.
62+
63+
- The *random_state* parameter was added to allow reproducibility in the approximate method.
64+
65+
- The `radius` parameter replaces the `threshold` parameter for setting the dissimilarity threshold for better clarity and consistency.
6066
6167
.. versionchanged:: 1.3.0
6268
All publicly accessible attributes are now suffixed with an underscore
6369
(e.g., `centers_`, `labels_`).
6470
This is particularly useful for compatibility with scikit-learn's API.
6571
66-
.. versionadded:: 1.3.0
67-
The `radius` parameter replaces the `threshold` parameter for setting
68-
the dissimilarity threshold for better clarity and consistency.
69-
7072
.. deprecated:: 1.3.0
7173
The `threshold` parameter is deprecated. Use `radius` instead.
7274
Will be removed in a future version.
@@ -243,7 +245,7 @@ def fit_predict(self, X: np.ndarray, y: None = None, metric: str | callable = "e
243245
labels : array, shape (n_samples,)
244246
The cluster labels for each point in X.
245247
"""
246-
self.fit(X)
248+
self.fit(X, metric=metric)
247249
return self.labels_
248250

249251
def _clustering(self):
@@ -252,75 +254,15 @@ def _clustering(self):
252254
"""
253255
n = self.X_checked_.shape[0]
254256
if self.manner != "exact" and self.manner != "approx":
255-
print(f"Invalid manner: {self.manner}. Defaulting to 'approx'.")
256257
raise ValueError("Invalid manner. Choose either 'exact' or 'approx'.")
257258
if self.manner == "exact":
258-
self._clustering_exact(n)
259+
self.centers_, self.mds_exec_time_ = clustering_exact(n, self.edges_, self.nb_edges_)
259260
else:
260-
self._clustering_approx(n)
261-
262-
def _clustering_exact(self, n: int) -> None:
263-
"""
264-
Perform exact MDS clustering.
265-
266-
Parameters:
267-
-----------
268-
n : int
269-
The number of points in the dataset.
270-
271-
Notes:
272-
------
273-
This function uses the EMOS algorithm to solve the MDS problem.
274-
See: [jiang]_ for more details.
275-
"""
276-
self.centers_, self.mds_exec_time_ = py_emos_main(
277-
self.edges_.flatten(), n, self.nb_edges_
278-
)
279-
self.centers_.sort() # Sort the centers to ensure consistent order
280-
281-
def _clustering_approx(self, n: int) -> None:
282-
"""
283-
Perform approximate MDS clustering.
284-
This method uses a pretty trick to set the seed for
285-
the random state of the C++ code of the MDS solver.
286-
287-
.. tip::
288-
The random state is used to ensure reproducibility of the results
289-
when using the approximate method.
290-
If `random_state` is None, a default value of 42 is used.
291-
292-
.. important::
293-
:collapsible: closed
294-
The trick to set the random state is :
295-
1. Use the `check_random_state` function to get a `RandomState`singleton
296-
instance, set up with the provided `random_state`.
297-
2. Use the `randint` method of the `RandomState` instance to generate a
298-
random integer.
299-
3. Use this random integer as the seed for the C++ code of the MDS solver.
300-
301-
This ensures that the seed passed to the C++ code is always an integer,
302-
which is required by the MDS solver, and allows for
303-
reproducibility of the results.
304-
305-
Parameters:
306-
-----------
307-
n : int
308-
The number of points in the dataset.
309-
310-
Notes:
311-
------
312-
This function uses the approximation method to solve the MDS problem.
313-
See [casado]_ for more details.
314-
"""
315-
if self.random_state is None:
316-
self.random_state = 42
317-
self.random_state_ = check_random_state(self.random_state)
318-
seed = self.random_state_.randint(np.iinfo(np.int32).max)
319-
result = solve_mds(
320-
n, self.edges_.flatten().astype(np.int32), self.nb_edges_, seed
321-
)
322-
self.centers_ = sorted([x for x in result["solution_set"]])
323-
self.mds_exec_time_ = result["Time"]
261+
if self.random_state is None:
262+
self.random_state = 42
263+
self.random_state_ = check_random_state(self.random_state)
264+
seed = self.random_state_.randint(np.iinfo(np.int32).max)
265+
self.centers_, self.mds_exec_time_ = clustering_approx(n, self.edges_, self.nb_edges_, seed)
324266

325267
def _compute_effective_radius(self):
326268
"""

0 commit comments

Comments
 (0)