Skip to content

Commit 445aaf8

Browse files
Add float32 compatibility to KMedoids (#120)
* add float32 compatibility * black * fix CI and add test on type * add test on transform
1 parent c1dfd25 commit 445aaf8

File tree

2 files changed

+16
-7
lines changed

2 files changed

+16
-7
lines changed

sklearn_extra/cluster/_k_medoids.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,8 @@ class KMedoids(BaseEstimator, ClusterMixin, TransformerMixin):
9494
>>> kmedoids.predict([[0,0], [4,4]])
9595
array([0, 1])
9696
>>> kmedoids.cluster_centers_
97-
array([[1, 2],
98-
[4, 2]])
97+
array([[1., 2.],
98+
[4., 2.]])
9999
>>> kmedoids.inertia_
100100
8.0
101101
@@ -185,7 +185,9 @@ def fit(self, X, y=None):
185185
random_state_ = check_random_state(self.random_state)
186186

187187
self._check_init_args()
188-
X = check_array(X, accept_sparse=["csr", "csc"])
188+
X = check_array(
189+
X, accept_sparse=["csr", "csc"], dtype=[np.float64, np.float32]
190+
)
189191
if self.n_clusters > X.shape[0]:
190192
raise ValueError(
191193
"The number of medoids (%d) must be less "
@@ -315,7 +317,9 @@ def transform(self, X):
315317
X_new : {array-like, sparse matrix}, shape=(n_query, n_clusters)
316318
X transformed in the new space of distances to cluster centers.
317319
"""
318-
X = check_array(X, accept_sparse=["csr", "csc"])
320+
X = check_array(
321+
X, accept_sparse=["csr", "csc"], dtype=[np.float64, np.float32]
322+
)
319323

320324
if self.metric == "precomputed":
321325
check_is_fitted(self, "medoid_indices_")
@@ -345,7 +349,9 @@ def predict(self, X):
345349
labels : array, shape = (n_query,)
346350
Index of the cluster each sample belongs to.
347351
"""
348-
X = check_array(X, accept_sparse=["csr", "csc"])
352+
X = check_array(
353+
X, accept_sparse=["csr", "csc"], dtype=[np.float64, np.float32]
354+
)
349355

350356
if self.metric == "precomputed":
351357
check_is_fitted(self, "medoid_indices_")

sklearn_extra/cluster/tests/test_k_medoids.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,18 @@
3434
@pytest.mark.parametrize(
3535
"init", ["random", "heuristic", "build", "k-medoids++"]
3636
)
37-
def test_kmedoid_results(method, init):
37+
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
38+
def test_kmedoid_results(method, init, dtype):
3839
expected = np.hstack([np.zeros(50), np.ones(50)])
3940
km = KMedoids(n_clusters=2, init=init, method=method, random_state=rng)
40-
km.fit(X_cc)
41+
km.fit(X_cc.astype(dtype))
4142
# This test use data that are not perfectly separable so the
4243
# accuracy is not 1. Accuracy around 0.85
4344
assert (np.mean(km.labels_ == expected) > 0.8) or (
4445
1 - np.mean(km.labels_ == expected) > 0.8
4546
)
47+
assert dtype is np.dtype(km.cluster_centers_.dtype).type
48+
assert dtype is np.dtype(km.transform(X_cc.astype(dtype)).dtype).type
4649

4750

4851
def test_medoids_invalid_method():

0 commit comments

Comments
 (0)