Skip to content

Commit 476f11f

Browse files
EXA Add example for KMedoids (#67)
1 parent a6569b2 commit 476f11f

File tree

2 files changed

+148
-70
lines changed

2 files changed

+148
-70
lines changed

examples/plot_clustering.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
===================================================================
4+
A demo of several clustering algorithms on a corrupted dataset
5+
===================================================================
6+
In this example we exhibit the results of various
7+
scikit-learn and scikit-learn-extra clustering algorithms on
8+
a dataset with outliers.
9+
KMedoids is the most stable and efficient
10+
algorithm for this application (change the seed to
11+
see different behavior for SpectralClustering and
12+
the robust kmeans).
13+
The mean-shift algorithm, once correctly
14+
parameterized, detects the outliers as a class of
15+
their own.
16+
"""
17+
print(__doc__)
18+
19+
import time
20+
21+
import numpy as np
22+
import matplotlib.pyplot as plt
23+
24+
from sklearn import cluster, mixture
25+
from sklearn.cluster import MiniBatchKMeans, KMeans
26+
from sklearn.datasets import make_blobs
27+
from sklearn.utils import shuffle
28+
29+
from sklearn_extra.robust import RobustWeightedEstimator
30+
from sklearn_extra.cluster import KMedoids
31+
32+
# One shared random state, seeded once, so the whole example is reproducible.
rng = np.random.RandomState(42)

# Centers of the three inlier blobs; the number of clusters follows from them.
centers = [[1, 1], [-1, -1], [1, -1]]
n_clusters = len(centers)

# Estimators that only depend on the number of clusters are built once here.
kmeans = KMeans(
    n_clusters=n_clusters,
    random_state=rng,
)
kmedoid = KMedoids(
    n_clusters=n_clusters,
    random_state=rng,
)
39+
40+
41+
def kmeans_loss(X, pred):
    """Return the per-sample k-means loss.

    For every sample the loss is the squared Euclidean distance between the
    sample and the centroid (coordinate-wise mean) of the cluster it was
    assigned to.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        The samples.
    pred : array-like of shape (n_samples,)
        Cluster label assigned to each sample of ``X``.

    Returns
    -------
    numpy.ndarray of shape (n_samples,)
        ``loss[i] = ||X[i] - centroid(pred[i])|| ** 2``.
    """
    X = np.asarray(X, dtype=float)
    pred = np.asarray(pred)
    losses = np.empty(len(X), dtype=float)
    # Compute each centroid once per cluster instead of once per sample.
    for label in np.unique(pred):
        members = pred == label
        # axis=0 keeps one mean per feature; a global mean would collapse the
        # centroid to a single scalar.
        centroid = X[members].mean(axis=0)
        losses[members] = np.sum((X[members] - centroid) ** 2, axis=1)
    return losses
48+
49+
50+
# Estimators compared in the demo. They are created once and re-fitted on
# every dataset; all randomized ones draw from the shared ``rng``.

# Centroid-based baseline.
two_means = cluster.MiniBatchKMeans(
    n_clusters=n_clusters,
    random_state=rng,
)

# Graph-based clustering on a nearest-neighbors affinity.
spectral = cluster.SpectralClustering(
    n_clusters=n_clusters,
    eigen_solver="arpack",
    affinity="nearest_neighbors",
    random_state=rng,
)

# Density-based methods (DBSCAN with its default parameters).
dbscan = cluster.DBSCAN()
optics = cluster.OPTICS(
    min_samples=20,
    xi=0.1,
    min_cluster_size=0.2,
)

affinity_propagation = cluster.AffinityPropagation(
    damping=0.75,
    preference=-220,
    random_state=rng,
)

birch = cluster.Birch(n_clusters=n_clusters)

# Mixture model with one full covariance matrix per component.
gmm = mixture.GaussianMixture(
    n_components=n_clusters,
    covariance_type="full",
    random_state=rng,
)
66+
67+
68+
# Run every clustering algorithm on a small and a large corrupted dataset and
# draw one figure per dataset (2 rows of subplots, one subplot per algorithm).
for n_samples in [300, 3000]:
    # Construct the dataset
    X, labels_true = make_blobs(
        n_samples=n_samples, centers=centers, cluster_std=0.4, random_state=rng
    )

    # Change the first 1% entries to outliers: points tightly grouped
    # around (20, 3), far away from the three blobs.
    n_outliers = n_samples // 100
    for f in range(n_outliers):
        X[f] = [20, 3] + rng.normal(size=2) * 0.1
    # Shuffle the data so that we don't know where the outlier is.
    X = shuffle(X, random_state=rng)

    # Define two other clustering algorithms; both depend on the dataset
    # (batch size / k / bandwidth), so they are rebuilt for every dataset.
    kmeans_rob = RobustWeightedEstimator(
        MiniBatchKMeans(
            n_clusters=n_clusters,
            batch_size=len(X),
            init="random",
            random_state=rng,
        ),
        # in theory, init=kmeans++ is very non-robust
        burn_in=0,
        eta0=0.01,
        weighting="mom",
        loss=kmeans_loss,
        max_iter=100,
        k=n_samples // 50,
        random_state=rng,
    )
    # ``quantile`` is keyword-only in current scikit-learn releases, so it
    # must not be passed positionally.
    bandwidth = cluster.estimate_bandwidth(X, quantile=0.2)

    ms = cluster.MeanShift(bandwidth=bandwidth, bin_seeding=True)

    clustering_algorithms = (
        ("MiniBatchKMeans", two_means),
        ("AffinityPropagation", affinity_propagation),
        ("MeanShift", ms),
        ("SpectralClustering", spectral),
        ("DBSCAN", dbscan),
        ("OPTICS", optics),
        ("Birch", birch),
        ("GaussianMixture", gmm),
        ("K-Medoid", kmedoid),
        ("Robust K-Means", kmeans_rob),
    )

    fig = plt.figure(figsize=(9 * 2 + 3, 5))
    plt.subplots_adjust(
        left=0.02, right=0.98, bottom=0.001, top=0.85, wspace=0.05, hspace=0.18
    )
    # The figure title does not depend on the algorithm: set it once per
    # figure instead of once per subplot.
    plt.suptitle(
        f"Dataset with {n_samples} samples, {n_samples // 100} outliers.",
        size=20,
    )

    # Two rows of subplots; round up so an odd number of algorithms still fits.
    n_cols = (len(clustering_algorithms) + 1) // 2

    for plot_num, (name, algorithm) in enumerate(clustering_algorithms, 1):
        # Time only the fit itself.
        t0 = time.time()
        algorithm.fit(X)
        t1 = time.time()

        # Use the fitted labels when the estimator exposes them, otherwise
        # fall back to predicting on the training data.
        if hasattr(algorithm, "labels_"):
            # The ``np.int`` alias was removed from NumPy; the builtin
            # ``int`` is the equivalent dtype.
            y_pred = algorithm.labels_.astype(int)
        else:
            y_pred = algorithm.predict(X)

        plt.subplot(2, n_cols, plot_num)
        plt.title(name, size=18)

        plt.scatter(X[:, 0], X[:, 1], s=10, c=y_pred)

        plt.xticks(())
        plt.yticks(())
        # Fit time in the bottom-right corner of the subplot.
        plt.text(
            0.99,
            0.01,
            ("%.2fs" % (t1 - t0)).lstrip("0"),
            transform=plt.gca().transAxes,
            size=15,
            horizontalalignment="right",
        )

plt.show()

examples/plot_robust_kmeans.py

Lines changed: 0 additions & 70 deletions
This file was deleted.

0 commit comments

Comments
 (0)