Skip to content

Commit c5f42ab

Browse files
committed
Addressed reviewer's comments. Added test that demonstrates that 'bias' is not clustered by default.
Change-Id: I6e21506d8c44cb6dbaa9200d1f87918df1982da9
1 parent 1155eee commit c5f42ab

File tree

4 files changed

+179
-42
lines changed

4 files changed

+179
-42
lines changed

tensorflow_model_optimization/python/core/clustering/keras/BUILD

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,19 @@ py_test(
166166
)
167167

168168
py_test(
169-
name = "mnist_customerable_test",
170-
srcs = ["mnist_customerable_test.py"],
169+
name = "mnist_clusterable_layer_test",
170+
srcs = ["mnist_clusterable_layer_test.py"],
171+
python_version = "PY3",
172+
visibility = ["//visibility:public"],
173+
deps = [
174+
":cluster"
175+
# tensorflow dep1,
176+
],
177+
)
178+
179+
py_test(
180+
name = "mnist_clustering_test",
181+
srcs = ["mnist_clustering_test.py"],
171182
python_version = "PY3",
172183
visibility = ["//visibility:public"],
173184
deps = [

tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py

Lines changed: 12 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ class CustomNonClusterableLayer(layers.Dense):
5555
pass
5656

5757

58-
class MyCustomerableLayer(keras.layers.Dense,
58+
class MyClusterableLayer(keras.layers.Dense,
5959
clusterable_layer.ClusterableLayer):
6060

6161
def __init__(self, num_units):
@@ -65,7 +65,7 @@ def get_clusterable_weights(self):
6565
# Cluster kernel and bias.
6666
return [('kernel', self.kernel), ('bias', self.bias)]
6767

68-
class MyCustomerableLayerInvalid(keras.layers.Dense,
68+
class MyClusterableLayerInvalid(keras.layers.Dense,
6969
clusterable_layer.ClusterableLayer):
7070
""" This layer is invalid, because it does not provide
7171
get_clusterable_weights function.
@@ -107,8 +107,7 @@ def setUp(self):
107107
self.custom_clusterable_layer = CustomClusterableLayer(10)
108108
self.custom_non_clusterable_layer = CustomNonClusterableLayer(10)
109109
self.keras_depthwiseconv2d_layer = layers.DepthwiseConv2D((3, 3), (1, 1))
110-
self.customerable_layer = MyCustomerableLayer(10)
111-
self.keras_custom_layer = KerasCustomLayer()
110+
self.clusterable_layer = MyClusterableLayer(10)
112111

113112
clustering_registry.ClusteringLookupRegistry.register_new_implementation(
114113
{
@@ -225,12 +224,12 @@ def testClusterCustomNonClusterableLayer(self):
225224
cluster_wrapper.ClusterWeights(custom_non_clusterable_layer,
226225
**self.params)
227226

228-
def testClusterMyCustomerableLayer(self):
227+
def testClusterMyClusterableLayer(self):
229228
# we have weights to cluster.
230-
customerable_layer = self.customerable_layer
231-
customerable_layer.build(input_shape=(10, 10))
229+
clusterable_layer = self.clusterable_layer
230+
clusterable_layer.build(input_shape=(10, 10))
232231

233-
wrapped_layer = cluster_wrapper.ClusterWeights(customerable_layer,
232+
wrapped_layer = cluster_wrapper.ClusterWeights(clusterable_layer,
234233
**self.params)
235234

236235
self.assertIsInstance(wrapped_layer, cluster_wrapper.ClusterWeights)
@@ -239,40 +238,24 @@ def testKerasCustomLayerClusterable(self):
239238
"""
240239
Verifies that we can wrap a Keras custom layer that is clusterable.
241240
"""
242-
customerable_layer = KerasCustomLayerClusterable()
243-
wrapped_layer = cluster_wrapper.ClusterWeights(customerable_layer,
241+
clusterable_layer = KerasCustomLayerClusterable()
242+
wrapped_layer = cluster_wrapper.ClusterWeights(clusterable_layer,
244243
**self.params)
245244

246245
self.assertIsInstance(wrapped_layer, cluster_wrapper.ClusterWeights)
247246

248-
def testClusterMyCustomerableLayerInvalid(self):
247+
def testClusterMyClusterableLayerInvalid(self):
249248
"""
250249
Verifies that assertion is thrown when function
251250
get_clusterable_weights is not provided.
252251
"""
253252
with self.assertRaises(TypeError):
254-
MyCustomerableLayerInvalid(10) # pylint: disable=abstract-class-instantiated
253+
MyClusterableLayerInvalid(10) # pylint: disable=abstract-class-instantiated
255254

256-
def testClusterKerasCustomLayer(self):
257-
"""
258-
Verifies that attempting to cluster a keras custom layer raises
259-
an exception.
260-
"""
261-
# If layer is not built, it has not weights, so
262-
# we just skip it.
263-
keras_custom_layer = self.keras_custom_layer
264-
cluster_wrapper.ClusterWeights(keras_custom_layer,
265-
**self.params)
266-
# We need to build weights before check that clustering is not supported.
267-
keras_custom_layer.build(input_shape=(10, 10))
268-
with self.assertRaises(ValueError):
269-
cluster_wrapper.ClusterWeights(keras_custom_layer,
270-
**self.params)
271-
272-
>>>>>>> 8fe29ec... MLTOOLS-1031 Customerable layer API.
273255
@keras_parameterized.run_all_keras_modes
274256
def testClusterSequentialModelSelectively(self):
275257
clustered_model = keras.Sequential()
258+
clustered_model.add(cluster.cluster_weights(self.keras_clusterable_layer, **self.params))
276259
clustered_model.add(self.keras_clusterable_layer)
277260
clustered_model.build(input_shape=(1, 10))
278261

tensorflow_model_optimization/python/core/clustering/keras/mnist_customerable_test.py renamed to tensorflow_model_optimization/python/core/clustering/keras/mnist_clusterable_layer_test.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15-
"""Tests for a simple convnet with customerable layer on the MNIST dataset. """
15+
"""Tests for a simple convnet with clusterable layer on the MNIST dataset. """
1616

1717
import tensorflow as tf
1818

@@ -38,7 +38,7 @@ def get_clusterable_weights(self):
3838
# Cluster kernel and bias.
3939
return [('kernel', self.kernel), ('bias', self.bias)]
4040

41-
class CustomerableWeightsCA(clustering_registry.AbstractClusteringAlgorithm):
41+
class ClusterableWeightsCA(clustering_registry.AbstractClusteringAlgorithm):
4242
"""
4343
This class provides a special lookup function for the weights 'w'.
4444
It reshapes and tiles the centroids the same way as the weights. This allows us
@@ -58,10 +58,10 @@ def get_pulling_indices(self, weight):
5858

5959
return pulling_indices
6060

61-
class MyCustomerableLayer(keras.layers.Layer, clusterable_layer.ClusterableLayer):
61+
class MyClusterableLayer(keras.layers.Layer, clusterable_layer.ClusterableLayer):
6262

6363
def __init__(self, units=32):
64-
super(MyCustomerableLayer, self).__init__()
64+
super(MyClusterableLayer, self).__init__()
6565
self.units = units
6666

6767
def build(self, input_shape):
@@ -87,7 +87,7 @@ def get_clusterable_algorithm(self, weight_name):
8787
""" Returns clustering algorithm for the custom weights 'w'.
8888
"""
8989
if weight_name == 'w':
90-
return CustomerableWeightsCA
90+
return ClusterableWeightsCA
9191
else:
9292
# We don't cluster other weights.
9393
return None
@@ -110,7 +110,7 @@ def _build_model():
110110

111111
def _build_model_2():
112112
"""
113-
Builds model with MyCustomerableLayer layer.
113+
Builds model with MyClusterableLayer layer.
114114
"""
115115
i = tf.keras.layers.Input(shape=(28, 28), name='input')
116116
x = tf.keras.layers.Reshape((28, 28, 1))(i)
@@ -119,7 +119,7 @@ def _build_model_2():
119119
x)
120120
x = tf.keras.layers.MaxPool2D(2, 2)(x)
121121
x = tf.keras.layers.Flatten()(x)
122-
output = MyCustomerableLayer(units=10)(x)
122+
output = MyClusterableLayer(units=10)(x)
123123

124124
model = tf.keras.Model(inputs=[i], outputs=[output])
125125
return model
@@ -220,13 +220,13 @@ def testMnistMyDenseLayer(self):
220220
self.assertLessEqual(nr_of_unique_weights, NUMBER_OF_CLUSTERS)
221221

222222
# checks 'bias' weights of the last layer: MyDenseLayer
223-
nr_of_unique_weights = _get_number_of_unique_weights(clustered_model, -1, 0)
223+
nr_of_unique_weights = _get_number_of_unique_weights(clustered_model, -1, 1)
224224
self.assertLessEqual(nr_of_unique_weights, NUMBER_OF_CLUSTERS)
225225

226-
def testMnistCustomerableLayer(self):
226+
def testMnistClusterableLayer(self):
227227
""" We test the keras custom layer with the provided
228-
clustering algorithm (see MyCustomerableLayer above).
229-
We cluster only 'w' weights and the class CustomerableWeightsCA
228+
clustering algorithm (see MyClusterableLayer above).
229+
We cluster only 'w' weights and the class ClusterableWeightsCA
230230
provides the function get_pulling_indices for the
231231
layout of the 'w' weights.
232232
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Tests for a simple convnet with clusterable layer on the MNIST dataset. """
16+
17+
import tensorflow as tf
18+
19+
from tensorflow_model_optimization.python.core.clustering.keras import cluster
20+
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
21+
from tensorflow_model_optimization.python.core.clustering.keras import clusterable_layer
22+
from tensorflow_model_optimization.python.core.clustering.keras import clustering_registry
23+
24+
tf.random.set_seed(42)
25+
26+
keras = tf.keras
27+
28+
EPOCHS = 7
29+
EPOCHS_FINE_TUNING = 4
30+
NUMBER_OF_CLUSTERS = 8
31+
32+
def _build_model():
  """Constructs the small MNIST convnet exercised by this test.

  Layers, in order: Reshape to (28, 28, 1), a 3x3 Conv2D with 12 filters and
  ReLU activation named 'conv1', 2x2 max pooling, Flatten, and a final Dense
  layer with 10 units (no activation argument is passed to it).

  Returns:
    A functional `tf.keras.Model`; it is not compiled here.
  """
  layers = tf.keras.layers

  inputs = layers.Input(shape=(28, 28), name='input')
  net = layers.Reshape((28, 28, 1))(inputs)
  conv1 = layers.Conv2D(
      filters=12, kernel_size=(3, 3), activation='relu', name='conv1')
  net = conv1(net)
  net = layers.MaxPool2D(2, 2)(net)
  net = layers.Flatten()(net)
  logits = layers.Dense(units=10)(net)

  return tf.keras.Model(inputs=[inputs], outputs=[logits])
47+
48+
def _get_dataset():
  """Loads MNIST, rescales the images and truncates the training split.

  Images of both splits are divided by 255.0. Only the first 1000 training
  examples (images and labels) are kept; the test split is returned in full.

  Returns:
    A pair `((x_train, y_train), (x_test, y_test))`.
  """
  (train_images, train_labels), (test_images, test_labels) = (
      tf.keras.datasets.mnist.load_data())

  # Keep only a small slice of the training examples so the unit test stays
  # fast.
  train_images = train_images[0:1000] / 255.0
  train_labels = train_labels[0:1000]
  test_images = test_images / 255.0

  return (train_images, train_labels), (test_images, test_labels)
56+
57+
58+
def _train_model(model):
  """Compiles `model` and fits it on the reduced MNIST training split.

  The model is compiled with the 'adam' optimizer, sparse categorical
  cross-entropy computed from logits, and an 'accuracy' metric, then trained
  in place for EPOCHS epochs.

  Args:
    model: The Keras model to compile and train.
  """
  model.compile(
      optimizer='adam',
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=['accuracy'])

  train_split, _ = _get_dataset()
  images, labels = train_split

  model.fit(images, labels, epochs=EPOCHS)
66+
67+
def _cluster_model(model, number_of_clusters):
  """Clusters the weights of `model`, fine-tunes it and strips the wrappers.

  Args:
    model: The Keras model to cluster.
    number_of_clusters: Value passed as 'number_of_clusters' to
      `cluster.cluster_weights`.

  Returns:
    The model returned by `cluster.strip_clustering` for the fine-tuned
    clustered model, compiled with the same loss type, optimizer and metric
    that were used for fine-tuning.
  """
  (x_train, y_train), _ = _get_dataset()

  clustering_params = {
      # Use the argument rather than the module-level NUMBER_OF_CLUSTERS so
      # that the caller actually controls the cluster count.
      'number_of_clusters': number_of_clusters,
      'cluster_centroids_init':
          cluster_config.CentroidInitialization.KMEANS_PLUS_PLUS,
  }

  # Cluster model
  clustered_model = cluster.cluster_weights(model, **clustering_params)

  # Use smaller learning rate for fine-tuning
  # clustered model
  opt = tf.keras.optimizers.Adam(learning_rate=1e-5)

  clustered_model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      optimizer=opt,
      metrics=['accuracy'])

  # Fine-tune clustered model
  clustered_model.fit(
      x_train,
      y_train,
      epochs=EPOCHS_FINE_TUNING)

  # Remove the clustering wrappers, then compile again so that the returned
  # model can be evaluated by the caller.
  stripped_model = cluster.strip_clustering(clustered_model)
  stripped_model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      optimizer=opt,
      metrics=['accuracy'])

  return stripped_model
101+
102+
def _get_number_of_unique_weights(stripped_model, layer_nr, weights_nr):
103+
weights_as_list = stripped_model.layers[layer_nr].get_weights()[weights_nr].reshape(-1,).tolist()
104+
nr_of_unique_weights = len(set(weights_as_list))
105+
106+
return nr_of_unique_weights
107+
108+
class FunctionalTest(tf.test.TestCase):
  """End-to-end clustering check on the small MNIST convnet."""

  def testMnist(self):
    """Checks that 'kernel' of the last layer is clustered and 'bias' is not.

    The model is trained, clustered and stripped; the number of distinct
    values in weight 0 ('kernel') and weight 1 ('bias') of the last layer is
    compared before and after clustering.
    """
    model = _build_model()
    _train_model(model)

    # Before clustering, 'kernel' must hold more distinct values than the
    # number of clusters.
    kernel_unique_before = _get_number_of_unique_weights(model, -1, 0)
    self.assertGreater(kernel_unique_before, NUMBER_OF_CLUSTERS)

    # Remember how many distinct 'bias' values exist before clustering.
    bias_unique_before = _get_number_of_unique_weights(model, -1, 1)

    _, (x_test, y_test) = _get_dataset()

    # Entry 1 of the evaluation result is checked against the threshold;
    # presumably this is the 'accuracy' metric set at compile time.
    baseline_metrics = model.evaluate(x_test, y_test)
    self.assertGreater(baseline_metrics[1], 0.85)

    clustered_model = _cluster_model(model, NUMBER_OF_CLUSTERS)

    clustered_metrics = clustered_model.evaluate(x_test, y_test)
    self.assertGreater(clustered_metrics[1], 0.85)

    # After clustering, 'kernel' may hold at most NUMBER_OF_CLUSTERS distinct
    # values.
    kernel_unique_after = _get_number_of_unique_weights(clustered_model, -1, 0)
    self.assertLessEqual(kernel_unique_after, NUMBER_OF_CLUSTERS)

    # 'bias' is not clustered, so its distinct-value count must be unchanged.
    bias_unique_after = _get_number_of_unique_weights(clustered_model, -1, 1)
    self.assertEqual(bias_unique_before, bias_unique_after)
141+
142+
if __name__ == '__main__':
143+
tf.test.main()

0 commit comments

Comments
 (0)