Skip to content

Commit 75dbaa4

Browse files
FIG bug indices medoids CLARA (#127)
1 parent 5c47ba2 commit 75dbaa4

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

sklearn_extra/cluster/_k_medoids.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -685,7 +685,7 @@ def fit(self, X, y=None):
685685
medoids_idxs = pam.medoid_indices_
686686
best_sample_idxs = sample_idxs
687687

688-
self.medoid_indices_ = medoids_idxs
688+
self.medoid_indices_ = sample_idxs[medoids_idxs]
689689
self.labels_ = np.argmin(self.transform(X), axis=1)
690690
self.n_iter_ = self.n_sampling_iter
691691

sklearn_extra/cluster/tests/test_k_medoids.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,3 +390,22 @@ def test_seuclidean():
390390
km.predict(np.array([0, 0, 0, 1]).reshape((4, 1)))
391391
km.transform(np.array([0, 0, 0, 1]).reshape((4, 1)))
392392
assert len(record) == 0
393+
394+
395+
def test_medoids_indices():
396+
rng = np.random.RandomState(seed)
397+
X_iris = load_iris()["data"]
398+
399+
clara = CLARA(
400+
n_clusters=3,
401+
n_sampling_iter=1,
402+
n_sampling=len(X_iris),
403+
random_state=rng,
404+
)
405+
406+
model = KMedoids(n_clusters=3, init="build", random_state=rng)
407+
408+
model.fit(X_iris)
409+
clara.fit(X_iris)
410+
assert_array_equal(X_iris[model.medoid_indices_], model.cluster_centers_)
411+
assert_array_equal(X_iris[clara.medoid_indices_], clara.cluster_centers_)

0 commit comments

Comments
 (0)