Skip to content

Commit 2cc363d

Browse files
committed
API modifications for scikit-learn
1 parent bc3b081 commit 2cc363d

File tree

5 files changed

+100
-43
lines changed

5 files changed

+100
-43
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ docs/build/
33

44
# env and caches
55

6-
mdsenv/
6+
mds-env/
77
**/__pycache__/
88
.pytest_cache/
99
.ruff_cache/

src/radius_clustering/radius_clustering.py

Lines changed: 80 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,15 @@
1212
import numpy as np
1313
from sklearn.metrics import pairwise_distances
1414
from sklearn.base import BaseEstimator, ClusterMixin
15-
from sklearn.utils.validation import check_array
15+
from sklearn.utils.validation import check_array, validate_data, check_random_state
1616

1717
from radius_clustering.utils._emos import py_emos_main
1818
from radius_clustering.utils._mds_approx import solve_mds
1919

2020
DIR_PATH = os.path.dirname(os.path.realpath(__file__))
2121

2222

23-
class RadiusClustering(BaseEstimator, ClusterMixin):
23+
class RadiusClustering(ClusterMixin, BaseEstimator):
2424
"""
2525
Radius Clustering algorithm.
2626
@@ -42,29 +42,56 @@ class RadiusClustering(BaseEstimator, ClusterMixin):
4242
The indices of the cluster centers.
4343
labels\_ : array-like, shape (n_samples,)
4444
The cluster labels for each point in the input data.
45-
effective_radius : float
45+
effective_radius\_ : float
4646
The maximum distance between any point and its assigned cluster center.
47+
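random_state\_ : numpy.random.RandomState
48+
The random number generator obtained from the `random_state` parameter through `check_random_state`. It is used to draw the integer seed passed to the approximate MDS solver, and is only set when the approximate method runs. If `random_state` is None, a default seed of 42 is used.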
49+
50+
.. note::
51+
The `random_state_` attribute is not used when the `manner` is set to "exact".
52+
53+
.. versionadded:: 1.3.0
54+
The *random_state* parameter was added to allow reproducibility in the approximate method.
55+
56+
.. versionchanged:: 1.3.0
57+
All publicly accessible attributes are now suffixed with an underscore (e.g., `centers_`, `labels_`).
58+
This is particularly useful for compatibility with scikit-learn's API.
4759
"""
4860

49-
def __init__(self, manner="approx", threshold=0.5):
61+
_estimator_type = "clusterer"
62+
63+
def __init__(self, manner: str ="approx", threshold: float =0.5, random_state: int | None = None) -> None:
5064
self.manner = manner
5165
self.threshold = threshold
66+
self.random_state = random_state
5267

53-
def _check_symmetric(self, a, tol=1e-8):
68+
def _check_symmetric(self, a: np.ndarray, tol: float =1e-8) -> bool:
5469
if a.ndim != 2:
5570
raise ValueError("Input must be a 2D array.")
5671
if a.shape[0] != a.shape[1]:
5772
return False
5873
return np.allclose(a, a.T, atol=tol)
5974

60-
def fit(self, X, y=None):
75+
def fit(self, X: np.ndarray, y: None = None) -> "RadiusClustering":
6176
"""
6277
Fit the MDS clustering model to the input data.
6378
79+
This method computes the distance matrix if the input is a feature matrix,
80+
or uses the provided distance matrix directly if the input is already a distance matrix.
81+
82+
.. note::
83+
If the input is a distance matrix, it should be symmetric and square.
84+
If the input is a feature matrix, the distance matrix will be computed using Euclidean distance.
85+
86+
.. tip::
87+
The next version will support other metrics, or even custom callables, for computing the distance matrix.
88+
6489
Parameters:
6590
-----------
6691
X : array-like, shape (n_samples, n_features)
67-
The input data to cluster.
92+
The input data to cluster. X should be a 2D array-like structure. It can be either:
93+
- A distance matrix (symmetric, square) with shape (n_samples, n_samples).
94+
- A feature matrix with shape (n_samples, n_features) where the distance matrix will be computed.
6895
y : Ignored
6996
Not used, present here for API consistency by convention.
7097
@@ -91,38 +118,43 @@ def fit(self, X, y=None):
91118
For examples on common datasets and differences with kmeans,
92119
see :ref:`sphx_glr_auto_examples_plot_iris_example.py`
93120
"""
94-
self.X = check_array(X)
121+
self.X_checked_ = validate_data(self, X)
95122

96123
# Create dist and adj matrices
97-
if not self._check_symmetric(self.X):
98-
dist_mat = pairwise_distances(self.X, metric="euclidean")
124+
if not self._check_symmetric(self.X_checked_):
125+
dist_mat = pairwise_distances(self.X_checked_, metric="euclidean")
99126
else:
100-
dist_mat = self.X
127+
dist_mat = self.X_checked_
101128
adj_mask = np.triu((dist_mat <= self.threshold), k=1)
102-
self.nb_edges = np.sum(adj_mask)
103-
if self.nb_edges == 0:
104-
self.centers_ = list(range(self.X.shape[0]))
105-
self.labels_ = self.centers_
106-
self.effective_radius = 0
107-
self._mds_exec_time = 0
129+
self.nb_edges_ = np.sum(adj_mask)
130+
if self.nb_edges_ == 0:
131+
self.centers_ = list(range(self.X_checked_.shape[0]))
132+
self.labels_ = np.array(self.centers_)
133+
self.effective_radius_ = 0
134+
self.mds_exec_time_ = 0
108135
return self
109-
self.edges = np.argwhere(adj_mask).astype(np.uint32) #TODO: changer en uint32
110-
self.dist_mat = dist_mat
136+
self.edges_ = np.argwhere(adj_mask).astype(np.uint32) # Edges in the adjacency matrix
137+
# uint32 is used to save memory. Edge endpoints are sample indices, so the largest representable index is 2^32-1
138+
self.dist_mat_ = dist_mat
111139

112140
self._clustering()
113141
self._compute_effective_radius()
114142
self._compute_labels()
115143

116144
return self
117145

118-
def fit_predict(self, X, y=None):
146+
def fit_predict(self, X: np.ndarray, y: None = None) -> np.ndarray:
119147
"""
120148
Fit the model and return the cluster labels.
121149
150+
This method is a convenience function that combines `fit` and `predict`.
151+
122152
Parameters:
123153
-----------
124154
X : array-like, shape (n_samples, n_features)
125-
The input data to cluster.
155+
The input data to cluster. X should be a 2D array-like structure. It can be either:
156+
- A distance matrix (symmetric, square) with shape (n_samples, n_samples).
157+
- A feature matrix with shape (n_samples, n_features) where the distance matrix will be computed.
126158
y : Ignored
127159
Not used, present here for API consistency by convention.
128160
@@ -138,13 +170,13 @@ def _clustering(self):
138170
"""
139171
Perform the clustering using either the exact or approximate MDS method.
140172
"""
141-
n = self.X.shape[0]
173+
n = self.X_checked_.shape[0]
142174
if self.manner == "exact":
143175
self._clustering_exact(n)
144176
else:
145177
self._clustering_approx(n)
146178

147-
def _clustering_exact(self, n):
179+
def _clustering_exact(self, n: int) -> None:
148180
"""
149181
Perform exact MDS clustering.
150182
@@ -158,13 +190,26 @@ def _clustering_exact(self, n):
158190
This function uses the EMOS algorithm to solve the MDS problem.
159191
See: [jiang]_ for more details.
160192
"""
161-
self.centers_, self._mds_exec_time = py_emos_main(
162-
self.edges.flatten(), n, self.nb_edges
193+
self.centers_, self.mds_exec_time_ = py_emos_main(
194+
self.edges_.flatten(), n, self.nb_edges_
163195
)
164196

165-
def _clustering_approx(self, n):
197+
def _clustering_approx(self, n: int) -> None:
166198
"""
167-
Perform approximate MDS clustering.
199+
Perform approximate MDS clustering. This method uses a small trick to seed the random number generator of the C++ MDS solver.
200+
201+
.. tip::
202+
The random state is used to ensure reproducibility of the results when using the approximate method.
203+
If `random_state` is None, a default value of 42 is used.
204+
205+
.. important::
206+
:collapsible: closed
207+
The trick to set the random state is:
208+
1. Use the `check_random_state` function to get a `RandomState` instance seeded with the provided `random_state`.
209+
2. Use the `randint` method of the `RandomState` instance to generate a random integer.
210+
3. Use this random integer as the seed for the C++ code of the MDS solver.
211+
212+
This ensures that the seed passed to the C++ code is always an integer, which is required by the MDS solver, and allows for reproducibility of the results.
168213
169214
Parameters:
170215
-----------
@@ -176,9 +221,13 @@ def _clustering_approx(self, n):
176221
This function uses the approximation method to solve the MDS problem.
177222
See [casado]_ for more details.
178223
"""
179-
result = solve_mds(n, self.edges.flatten().astype(np.int32), self.nb_edges, "test")
224+
if self.random_state is None:
225+
self.random_state = 42
226+
self.random_state_ = check_random_state(self.random_state)
227+
seed = self.random_state_.randint(np.iinfo(np.int32).max)
228+
result = solve_mds(n, self.edges_.flatten().astype(np.int32), self.nb_edges_, seed)
180229
self.centers_ = [x for x in result["solution_set"]]
181-
self._mds_exec_time = result["Time"]
230+
self.mds_exec_time_ = result["Time"]
182231

183232
def _compute_effective_radius(self):
184233
"""
@@ -187,13 +236,13 @@ def _compute_effective_radius(self):
187236
The effective radius is the maximum radius among all clusters.
188237
That means EffRad = max(R(C_i)) for all i.
189238
"""
190-
self.effective_radius = np.min(self.dist_mat[:, self.centers_], axis=1).max()
239+
self.effective_radius_ = np.min(self.dist_mat_[:, self.centers_], axis=1).max()
191240

192241
def _compute_labels(self):
193242
"""
194243
Compute the cluster labels for each point in the dataset.
195244
"""
196-
distances = self.dist_mat[:, self.centers_]
245+
distances = self.dist_mat_[:, self.centers_]
197246
self.labels_ = np.argmin(distances, axis=1)
198247

199248
min_dist = np.min(distances, axis=1)

src/radius_clustering/utils/mds.pyx

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@ cdef extern from "mds_core.cpp":
3737
cpp_unordered_set[int] getSolutionSet()
3838
void setSolutionSet(cpp_unordered_set[int] solutionSet)
3939

40-
cdef Result iterated_greedy_wrapper(int numNodes, const vector[int]& edges_list, int nb_edges, string name) nogil
40+
cdef Result iterated_greedy_wrapper(int numNodes, const vector[int]& edges_list, int nb_edges, long seed) nogil
4141

42-
def solve_mds(int num_nodes, np.ndarray[int, ndim=1, mode="c"] edges not None, int nb_edges, str name):
42+
def solve_mds(int num_nodes, np.ndarray[int, ndim=1, mode="c"] edges not None, int nb_edges, int seed):
4343
"""
4444
Solve the Minimum Dominating Set problem for a given graph.
4545
@@ -64,15 +64,12 @@ def solve_mds(int num_nodes, np.ndarray[int, ndim=1, mode="c"] edges not None, i
6464
# Cast the NumPy array to a C++ vector
6565
cpp_edge_list.assign(&edges[0], &edges[0] + edges.shape[0])
6666

67-
cdef string instanceName = name.encode('utf-8')
68-
6967
cdef Result result
7068
with nogil:
71-
result = iterated_greedy_wrapper(num_nodes, cpp_edge_list, nb_edges, instanceName)
69+
result = iterated_greedy_wrapper(num_nodes, cpp_edge_list, nb_edges, seed)
7270

7371
# Convert the C++ Result to a Python dictionary
7472
py_result = {
75-
"instance_name": result.getInstanceName().decode('utf-8'),
7673
"solution_set": set(result.getSolutionSet()),
7774
}
7875

src/radius_clustering/utils/mds_core.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -449,18 +449,18 @@ class Main {
449449
public:
450450
Main() : algorithm(constructive, localSearch) {}
451451

452-
Result execute(int numNodes, const std::vector<int>& edges_list, int nb_edges, std::string name) {
453-
Instance instance(numNodes, edges_list, nb_edges, name);
454-
RandomManager::setSeed(13);
452+
Result execute(int numNodes, const std::vector<int>& edges_list, int nb_edges, long seed) {
453+
Instance instance(numNodes, edges_list, nb_edges, "name");
454+
RandomManager::setSeed(seed);
455455
signal(SIGINT, signal_handler);
456456
return algorithm.execute(instance);
457457
}
458458
};
459459

460460
extern "C" {
461-
inline Result iterated_greedy_wrapper(int numNodes, const std::vector<int>& edges_list, int nb_edges, std::string name) {
461+
inline Result iterated_greedy_wrapper(int numNodes, const std::vector<int>& edges_list, int nb_edges, long seed) {
462462
static Main main; // Create a single static instance
463463

464-
return main.execute(numNodes, edges_list, nb_edges, name);
464+
return main.execute(numNodes, edges_list, nb_edges, seed);
465465
}
466466
}

tests/test_rad.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,21 @@
1+
from logging import getLogger
2+
3+
logger = getLogger(__name__)
4+
logger.setLevel("INFO")
5+
16
def test_imports():
27
import radius_clustering as rad
38

49

510
def test_from_import():
611
from radius_clustering import RadiusClustering
712

13+
def test_check_estimator_api_consistency():
14+
from radius_clustering import RadiusClustering
15+
from sklearn.utils.estimator_checks import check_estimator
16+
17+
# Check the API consistency of the RadiusClustering estimator
18+
stats = check_estimator(RadiusClustering())
819

920
def test_radius_clustering_approx():
1021
from radius_clustering import RadiusClustering

0 commit comments

Comments
 (0)