Skip to content

Commit d36c5a8

Browse files
Merge pull request #831 from SaoirseARM:toupstream/cluster_default_change
PiperOrigin-RevId: 399130358
2 parents ae0326e + c210097 commit d36c5a8

File tree

7 files changed

+35
-27
lines changed

7 files changed

+35
-27
lines changed

tensorflow_model_optimization/python/core/clustering/keras/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ py_strict_library(
2525
srcs_version = "PY3",
2626
visibility = ["//visibility:public"],
2727
deps = [
28+
":cluster_config",
2829
":cluster_wrapper",
2930
":clustering_centroids",
3031
# tensorflow dep1,

tensorflow_model_optimization/python/core/clustering/keras/cluster.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@
1616

1717
import tensorflow as tf
1818

19+
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
1920
from tensorflow_model_optimization.python.core.clustering.keras import cluster_wrapper
2021
from tensorflow_model_optimization.python.core.clustering.keras import clustering_centroids
2122

2223
k = tf.keras.backend
2324
CustomObjectScope = tf.keras.utils.CustomObjectScope
25+
CentroidInitialization = cluster_config.CentroidInitialization
2426
Layer = tf.keras.layers.Layer
2527
InputLayer = tf.keras.layers.InputLayer
2628

@@ -44,17 +46,14 @@ def cluster_scope():
4446
loaded_model = tf.keras.models.load_model(keras_file)
4547
```
4648
"""
47-
return CustomObjectScope(
48-
{
49-
'ClusterWeights': cluster_wrapper.ClusterWeights
50-
}
51-
)
49+
return CustomObjectScope({'ClusterWeights': cluster_wrapper.ClusterWeights})
5250

5351

54-
def cluster_weights(to_cluster,
55-
number_of_clusters,
56-
cluster_centroids_init,
57-
**kwargs):
52+
def cluster_weights(
53+
to_cluster,
54+
number_of_clusters,
55+
cluster_centroids_init=CentroidInitialization.KMEANS_PLUS_PLUS,
56+
**kwargs):
5857
"""Modifies a keras layer or model to be clustered during training.
5958
6059
This function wraps a keras model or layer with clustering functionality

tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ def setUp(self):
4040
super(ClusterIntegrationTest, self).setUp()
4141
self.params = {
4242
"number_of_clusters": 8,
43-
"cluster_centroids_init": CentroidInitialization.LINEAR,
4443
}
4544

4645
self.x_train = np.array(
@@ -135,6 +134,20 @@ def _get_number_of_unique_weights(self, stripped_model, layer_nr,
135134
nr_of_unique_weights = len(set(weights_as_list))
136135
return nr_of_unique_weights
137136

137+
def testDefaultClusteringInit(self):
138+
"""Verifies that default initialization method is KMEANS_PLUS_PLUS."""
139+
original_model = keras.Sequential([
140+
layers.Dense(5, input_shape=(5,)),
141+
layers.Dense(5),
142+
])
143+
clustered_model = cluster.cluster_weights(original_model, **self.params)
144+
145+
for i in range(len(clustered_model.layers)):
146+
if hasattr(clustered_model.layers[i], "layer"):
147+
init_method = clustered_model.layers[i].get_config(
148+
)["cluster_centroids_init"]
149+
self.assertEqual(init_method, CentroidInitialization.KMEANS_PLUS_PLUS)
150+
138151
@keras_parameterized.run_all_keras_modes
139152
def testValuesRemainClusteredAfterTraining(self):
140153
"""Verifies that training a clustered model does not destroy the clusters."""
@@ -179,7 +192,6 @@ def testSparsityIsPreservedDuringTraining(self):
179192
original_model.layers[0].set_weights(first_layer_weights)
180193
clustering_params = {
181194
"number_of_clusters": 6,
182-
"cluster_centroids_init": CentroidInitialization.LINEAR,
183195
"preserve_sparsity": True
184196
}
185197
clustered_model = experimental_cluster.cluster_weights(
@@ -352,7 +364,6 @@ def setUp(self):
352364

353365
self.params_clustering = {
354366
"number_of_clusters": 16,
355-
"cluster_centroids_init": CentroidInitialization.KMEANS_PLUS_PLUS,
356367
}
357368

358369
def _train(self, model):

tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ class ClusterWeights(Wrapper):
5454
def __init__(self,
5555
layer,
5656
number_of_clusters,
57-
cluster_centroids_init,
57+
cluster_centroids_init=CentroidInitialization.KMEANS_PLUS_PLUS,
5858
preserve_sparsity=False,
5959
cluster_gradient_aggregation=GradientAggregation.SUM,
6060
**kwargs):

tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper_test.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ def testCannotBeInitializedWithNonLayerObject(self):
6161
cluster_wrapper.ClusterWeights(
6262
{'this': 'is not a Layer instance'},
6363
number_of_clusters=13,
64-
cluster_centroids_init=CentroidInitialization.LINEAR
6564
)
6665

6766
def testCannotBeInitializedWithNonClusterableLayer(self):
@@ -70,15 +69,13 @@ def testCannotBeInitializedWithNonClusterableLayer(self):
7069
cluster_wrapper.ClusterWeights(
7170
NonClusterableLayer(10),
7271
number_of_clusters=13,
73-
cluster_centroids_init=CentroidInitialization.LINEAR
7472
)
7573

7674
def testCanBeInitializedWithClusterableLayer(self):
7775
"""Verifies that ClusterWeights can be initialized with a built-in clusterable layer."""
7876
l = cluster_wrapper.ClusterWeights(
7977
layers.Dense(10),
8078
number_of_clusters=13,
81-
cluster_centroids_init=CentroidInitialization.LINEAR
8279
)
8380
self.assertIsInstance(l, cluster_wrapper.ClusterWeights)
8481

@@ -87,16 +84,14 @@ def testCannotBeInitializedWithNonIntegerNumberOfClusters(self):
8784
with self.assertRaises(ValueError):
8885
cluster_wrapper.ClusterWeights(
8986
layers.Dense(10),
90-
number_of_clusters='13',
91-
cluster_centroids_init=CentroidInitialization.LINEAR)
87+
number_of_clusters='13')
9288

9389
def testCannotBeInitializedWithFloatNumberOfClusters(self):
9490
"""Verifies that ClusterWeights cannot be initialized with a decimal value provided for the number of clusters."""
9591
with self.assertRaises(ValueError):
9692
cluster_wrapper.ClusterWeights(
9793
layers.Dense(10),
9894
number_of_clusters=13.4,
99-
cluster_centroids_init=CentroidInitialization.LINEAR
10095
)
10196

10297
@parameterized.parameters(
@@ -111,7 +106,6 @@ def testCannotBeInitializedWithNumberOfClustersLessThanTwo(
111106
cluster_wrapper.ClusterWeights(
112107
layers.Dense(10),
113108
number_of_clusters=number_of_clusters,
114-
cluster_centroids_init=CentroidInitialization.LINEAR
115109
)
116110

117111
@parameterized.parameters((0), (2), (-32))
@@ -122,7 +116,6 @@ def testCannotBeInitializedWithSparsityPreservationAndNumberOfClustersLessThanTh
122116
cluster_wrapper.ClusterWeights(
123117
layers.Dense(10),
124118
number_of_clusters=number_of_clusters,
125-
cluster_centroids_init=CentroidInitialization.LINEAR,
126119
preserve_sparsity=True)
127120

128121
def testCanBeInitializedWithAlreadyClusterableLayer(self):
@@ -131,7 +124,6 @@ def testCanBeInitializedWithAlreadyClusterableLayer(self):
131124
l = cluster_wrapper.ClusterWeights(
132125
layer,
133126
number_of_clusters=13,
134-
cluster_centroids_init=CentroidInitialization.LINEAR
135127
)
136128
self.assertIsInstance(l, cluster_wrapper.ClusterWeights)
137129

@@ -140,7 +132,6 @@ def testIfLayerHasBatchShapeClusterWeightsMustHaveIt(self):
140132
l = cluster_wrapper.ClusterWeights(
141133
layers.Dense(10, input_shape=(10,)),
142134
number_of_clusters=13,
143-
cluster_centroids_init=CentroidInitialization.LINEAR
144135
)
145136
self.assertTrue(hasattr(l, '_batch_input_shape'))
146137

@@ -183,7 +174,6 @@ def testClusterReassociation(self):
183174
l = cluster_wrapper.ClusterWeights(
184175
keras.layers.Dense(8, input_shape=input_shape),
185176
number_of_clusters=2,
186-
cluster_centroids_init=CentroidInitialization.LINEAR
187177
)
188178
# Build a layer with the given shape
189179
l.build(input_shape)
@@ -239,7 +229,6 @@ def testSameWeightsAreReturnedBeforeAndAfterSerialisation(self):
239229
original_layer = cluster_wrapper.ClusterWeights(
240230
keras.layers.Dense(8, input_shape=input_shape),
241231
number_of_clusters=2,
242-
cluster_centroids_init=CentroidInitialization.LINEAR
243232
)
244233
# Build a layer with the given shape
245234
original_layer.build(input_shape)

tensorflow_model_optimization/python/core/clustering/keras/experimental/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,5 +24,6 @@ py_strict_library(
2424
visibility = ["//visibility:public"],
2525
deps = [
2626
"//tensorflow_model_optimization/python/core/clustering/keras:cluster",
27+
"//tensorflow_model_optimization/python/core/clustering/keras:cluster_config",
2728
],
2829
)

tensorflow_model_optimization/python/core/clustering/keras/experimental/cluster.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,18 @@
1414
# ==============================================================================
1515
"""Experimental clustering API functions for Keras models."""
1616

17+
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
1718
from tensorflow_model_optimization.python.core.clustering.keras.cluster import _cluster_weights
1819

20+
CentroidInitialization = cluster_config.CentroidInitialization
1921

20-
def cluster_weights(to_cluster, number_of_clusters, cluster_centroids_init,
21-
preserve_sparsity, **kwargs):
22+
23+
def cluster_weights(
24+
to_cluster,
25+
number_of_clusters,
26+
cluster_centroids_init=CentroidInitialization.KMEANS_PLUS_PLUS,
27+
preserve_sparsity=False,
28+
**kwargs):
2229
"""Modify a keras layer or model to be clustered during training (experimental).
2330
2431
This function wraps a keras model or layer with clustering functionality

0 commit comments

Comments
 (0)