Skip to content

Commit c210097

Browse files
committed
Updated Clustering initialization default to KMEANS_PLUS_PLUS
* Added integration test which ensures the default is KMEANS_PLUS_PLUS
1 parent f676a3b commit c210097

File tree

5 files changed

+22
-19
lines changed

5 files changed

+22
-19
lines changed

tensorflow_model_optimization/python/core/clustering/keras/cluster.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,11 @@
1818

1919
from tensorflow_model_optimization.python.core.clustering.keras import cluster_wrapper
2020
from tensorflow_model_optimization.python.core.clustering.keras import clustering_centroids
21+
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
2122

2223
k = tf.keras.backend
2324
CustomObjectScope = tf.keras.utils.CustomObjectScope
25+
CentroidInitialization = cluster_config.CentroidInitialization
2426
Layer = tf.keras.layers.Layer
2527
InputLayer = tf.keras.layers.InputLayer
2628

@@ -53,7 +55,7 @@ def cluster_scope():
5355

5456
def cluster_weights(to_cluster,
5557
number_of_clusters,
56-
cluster_centroids_init,
58+
cluster_centroids_init=CentroidInitialization.KMEANS_PLUS_PLUS,
5759
**kwargs):
5860
"""Modifies a keras layer or model to be clustered during training.
5961

tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ def setUp(self):
4040
super(ClusterIntegrationTest, self).setUp()
4141
self.params = {
4242
"number_of_clusters": 8,
43-
"cluster_centroids_init": CentroidInitialization.LINEAR,
4443
}
4544

4645
self.x_train = np.array(
@@ -135,6 +134,19 @@ def _get_number_of_unique_weights(self, stripped_model, layer_nr,
135134
nr_of_unique_weights = len(set(weights_as_list))
136135
return nr_of_unique_weights
137136

137+
def testDefaultClusteringInit(self):
138+
"""Verifies that default initialization method is KMEANS_PLUS_PLUS."""
139+
original_model = keras.Sequential([
140+
layers.Dense(5, input_shape=(5,)),
141+
layers.Dense(5),
142+
])
143+
clustered_model = cluster.cluster_weights(original_model, **self.params)
144+
145+
for i in range(len(clustered_model.layers)):
146+
if hasattr(clustered_model.layers[i], "layer"):
147+
init_method = clustered_model.layers[i].get_config()["cluster_centroids_init"]
148+
self.assertEqual(init_method, CentroidInitialization.KMEANS_PLUS_PLUS)
149+
138150
@keras_parameterized.run_all_keras_modes
139151
def testValuesRemainClusteredAfterTraining(self):
140152
"""Verifies that training a clustered model does not destroy the clusters."""
@@ -179,7 +191,6 @@ def testSparsityIsPreservedDuringTraining(self):
179191
original_model.layers[0].set_weights(first_layer_weights)
180192
clustering_params = {
181193
"number_of_clusters": 6,
182-
"cluster_centroids_init": CentroidInitialization.LINEAR,
183194
"preserve_sparsity": True
184195
}
185196
clustered_model = experimental_cluster.cluster_weights(
@@ -352,7 +363,6 @@ def setUp(self):
352363

353364
self.params_clustering = {
354365
"number_of_clusters": 16,
355-
"cluster_centroids_init": CentroidInitialization.KMEANS_PLUS_PLUS,
356366
}
357367

358368
def _train(self, model):

tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ class ClusterWeights(Wrapper):
5454
def __init__(self,
5555
layer,
5656
number_of_clusters,
57-
cluster_centroids_init,
57+
cluster_centroids_init=CentroidInitialization.KMEANS_PLUS_PLUS,
5858
preserve_sparsity=False,
5959
cluster_gradient_aggregation=GradientAggregation.SUM,
6060
**kwargs):

tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper_test.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ def testCannotBeInitializedWithNonLayerObject(self):
6161
cluster_wrapper.ClusterWeights(
6262
{'this': 'is not a Layer instance'},
6363
number_of_clusters=13,
64-
cluster_centroids_init=CentroidInitialization.LINEAR
6564
)
6665

6766
def testCannotBeInitializedWithNonClusterableLayer(self):
@@ -70,15 +69,13 @@ def testCannotBeInitializedWithNonClusterableLayer(self):
7069
cluster_wrapper.ClusterWeights(
7170
NonClusterableLayer(10),
7271
number_of_clusters=13,
73-
cluster_centroids_init=CentroidInitialization.LINEAR
7472
)
7573

7674
def testCanBeInitializedWithClusterableLayer(self):
7775
"""Verifies that ClusterWeights can be initialized with a built-in clusterable layer."""
7876
l = cluster_wrapper.ClusterWeights(
7977
layers.Dense(10),
8078
number_of_clusters=13,
81-
cluster_centroids_init=CentroidInitialization.LINEAR
8279
)
8380
self.assertIsInstance(l, cluster_wrapper.ClusterWeights)
8481

@@ -87,16 +84,14 @@ def testCannotBeInitializedWithNonIntegerNumberOfClusters(self):
8784
with self.assertRaises(ValueError):
8885
cluster_wrapper.ClusterWeights(
8986
layers.Dense(10),
90-
number_of_clusters='13',
91-
cluster_centroids_init=CentroidInitialization.LINEAR)
87+
number_of_clusters='13')
9288

9389
def testCannotBeInitializedWithFloatNumberOfClusters(self):
9490
"""Verifies that ClusterWeights cannot be initialized with a decimal value provided for the number of clusters."""
9591
with self.assertRaises(ValueError):
9692
cluster_wrapper.ClusterWeights(
9793
layers.Dense(10),
9894
number_of_clusters=13.4,
99-
cluster_centroids_init=CentroidInitialization.LINEAR
10095
)
10196

10297
@parameterized.parameters(
@@ -111,7 +106,6 @@ def testCannotBeInitializedWithNumberOfClustersLessThanTwo(
111106
cluster_wrapper.ClusterWeights(
112107
layers.Dense(10),
113108
number_of_clusters=number_of_clusters,
114-
cluster_centroids_init=CentroidInitialization.LINEAR
115109
)
116110

117111
@parameterized.parameters((0), (2), (-32))
@@ -122,7 +116,6 @@ def testCannotBeInitializedWithSparsityPreservationAndNumberOfClustersLessThanTh
122116
cluster_wrapper.ClusterWeights(
123117
layers.Dense(10),
124118
number_of_clusters=number_of_clusters,
125-
cluster_centroids_init=CentroidInitialization.LINEAR,
126119
preserve_sparsity=True)
127120

128121
def testCanBeInitializedWithAlreadyClusterableLayer(self):
@@ -131,7 +124,6 @@ def testCanBeInitializedWithAlreadyClusterableLayer(self):
131124
l = cluster_wrapper.ClusterWeights(
132125
layer,
133126
number_of_clusters=13,
134-
cluster_centroids_init=CentroidInitialization.LINEAR
135127
)
136128
self.assertIsInstance(l, cluster_wrapper.ClusterWeights)
137129

@@ -140,7 +132,6 @@ def testIfLayerHasBatchShapeClusterWeightsMustHaveIt(self):
140132
l = cluster_wrapper.ClusterWeights(
141133
layers.Dense(10, input_shape=(10,)),
142134
number_of_clusters=13,
143-
cluster_centroids_init=CentroidInitialization.LINEAR
144135
)
145136
self.assertTrue(hasattr(l, '_batch_input_shape'))
146137

@@ -183,7 +174,6 @@ def testClusterReassociation(self):
183174
l = cluster_wrapper.ClusterWeights(
184175
keras.layers.Dense(8, input_shape=input_shape),
185176
number_of_clusters=2,
186-
cluster_centroids_init=CentroidInitialization.LINEAR
187177
)
188178
# Build a layer with the given shape
189179
l.build(input_shape)
@@ -239,7 +229,6 @@ def testSameWeightsAreReturnedBeforeAndAfterSerialisation(self):
239229
original_layer = cluster_wrapper.ClusterWeights(
240230
keras.layers.Dense(8, input_shape=input_shape),
241231
number_of_clusters=2,
242-
cluster_centroids_init=CentroidInitialization.LINEAR
243232
)
244233
# Build a layer with the given shape
245234
original_layer.build(input_shape)

tensorflow_model_optimization/python/core/clustering/keras/experimental/cluster.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@
1515
"""Experimental clustering API functions for Keras models."""
1616

1717
from tensorflow_model_optimization.python.core.clustering.keras.cluster import _cluster_weights
18+
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
1819

20+
CentroidInitialization = cluster_config.CentroidInitialization
1921

20-
def cluster_weights(to_cluster, number_of_clusters, cluster_centroids_init,
21-
preserve_sparsity, **kwargs):
22+
def cluster_weights(to_cluster, number_of_clusters, cluster_centroids_init=CentroidInitialization.KMEANS_PLUS_PLUS,
23+
preserve_sparsity=False, **kwargs):
2224
"""Modify a keras layer or model to be clustered during training (experimental).
2325
2426
This function wraps a keras model or layer with clustering functionality

0 commit comments

Comments
 (0)