Skip to content

Commit 168c113

Browse files
Merge pull request #443 from SaoirseARM:toupstream/cluster_kmeans_plus_plus
PiperOrigin-RevId: 326691434
2 parents c4ad2ce + efe9af4 commit 168c113

File tree

3 files changed

+49
-3
lines changed

3 files changed

+49
-3
lines changed

tensorflow_model_optimization/python/core/clustering/keras/cluster_config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ class CentroidInitialization(str, enum.Enum):
2929
axis is evenly spaced into as many regions as many clusters we want to
3030
have. After this the corresponding X values are obtained and used to
3131
initialize the clusters centroids.
32+
* `KMEANS_PLUS_PLUS`: cluster centroids using the kmeans++ algorithm
3233
"""
3334
LINEAR = "LINEAR"
3435
RANDOM = "RANDOM"
3536
DENSITY_BASED = "DENSITY_BASED"
37+
KMEANS_PLUS_PLUS = "KMEANS_PLUS_PLUS"

tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,11 @@
1717
import abc
1818
import six
1919
import tensorflow as tf
20-
20+
from tensorflow.python.ops import clustering_ops
2121
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
2222

2323
k = tf.keras.backend
2424
CentroidInitialization = cluster_config.CentroidInitialization
25-
2625
@six.add_metaclass(abc.ABCMeta)
2726
class AbstractCentroidsInitialisation:
2827
"""
@@ -53,6 +52,20 @@ def get_cluster_centroids(self):
5352
self.number_of_clusters)
5453
return cluster_centroids
5554

55+
class KmeansPlusPlusCentroidsInitialisation(AbstractCentroidsInitialisation):
56+
"""
57+
Cluster centroids based on kmeans++ algorithm
58+
"""
59+
def get_cluster_centroids(self):
60+
61+
weights = tf.reshape(self.weights, [-1, 1])
62+
63+
cluster_centroids = clustering_ops.kmeans_plus_plus_initialization(weights,
64+
self.number_of_clusters,
65+
seed=9,
66+
num_retries_per_sample=-1)
67+
68+
return cluster_centroids
5669

5770
class RandomCentroidsInitialisation(AbstractCentroidsInitialisation):
5871
"""
@@ -192,7 +205,9 @@ class CentroidsInitializerFactory:
192205
CentroidInitialization.LINEAR : LinearCentroidsInitialisation,
193206
CentroidInitialization.RANDOM : RandomCentroidsInitialisation,
194207
CentroidInitialization.DENSITY_BASED :
195-
DensityBasedCentroidsInitialisation
208+
DensityBasedCentroidsInitialisation,
209+
CentroidInitialization.KMEANS_PLUS_PLUS :
210+
KmeansPlusPlusCentroidsInitialisation,
196211
}
197212

198213
@classmethod

tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids_test.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def setUp(self):
4040
(CentroidInitialization.LINEAR),
4141
(CentroidInitialization.RANDOM),
4242
(CentroidInitialization.DENSITY_BASED),
43+
(CentroidInitialization.KMEANS_PLUS_PLUS),
4344
)
4445
def testExistingInitsAreSupported(self, init_type):
4546
"""
@@ -63,6 +64,10 @@ def testNonExistingInitIsNotSupported(self):
6364
CentroidInitialization.DENSITY_BASED,
6465
clustering_centroids.DensityBasedCentroidsInitialisation
6566
),
67+
(
68+
CentroidInitialization.KMEANS_PLUS_PLUS,
69+
clustering_centroids.KmeansPlusPlusCentroidsInitialisation
70+
),
6671
)
6772
def testReturnsMethodForExistingInit(self, init_type, method):
6873
"""
@@ -177,6 +182,30 @@ def testClusterCentroids(self, weights, number_of_clusters, centroids):
177182
calc_centroids = K.batch_get_value([dbci.get_cluster_centroids()])[0]
178183
self.assertSequenceAlmostEqual(centroids, calc_centroids, places=4)
179184

185+
@parameterized.parameters(
186+
(
187+
[0, 1, 2, 3, 3.1, 3.2, 3.3, 3.4, 3.5],
188+
5,
189+
[3.1, 0., 2., 1., 3.4]
190+
),
191+
(
192+
[0, 1, 2, 3, 3.1, 3.2, 3.3, 3.4, 3.5],
193+
3,
194+
[3.1, 0., 2.]
195+
),
196+
(
197+
[0., 1., 2., 3., 4., 5., 6., 7., 8., 9.],
198+
3,
199+
[6., 1., 8.]
200+
)
201+
)
202+
def testKmeanPlusPlusValues(self, weights, number_of_clusters, centroids):
203+
kmci = clustering_centroids.KmeansPlusPlusCentroidsInitialisation(
204+
weights,
205+
number_of_clusters
206+
)
207+
calc_centroids = K.batch_get_value([kmci.get_cluster_centroids()])[0]
208+
self.assertSequenceAlmostEqual(centroids, calc_centroids, places=4)
180209

181210
if __name__ == '__main__':
182211
test.main()

0 commit comments

Comments
 (0)