Skip to content

Commit 6f51b3d

Browse files
Merge pull request #553 from wwwind:clustering_deep_layers
PiperOrigin-RevId: 335505128
2 parents d1a23df + 39089ce commit 6f51b3d

File tree

3 files changed

+151
-30
lines changed

3 files changed

+151
-30
lines changed

tensorflow_model_optimization/python/core/clustering/keras/cluster.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,18 @@ def cluster_weights(to_cluster,
126126
format(cluster_centroids_init))
127127

128128
def _add_clustering_wrapper(layer):
129+
130+
if (isinstance(layer, keras.Model)):
131+
# Check whether the model is a subclass.
132+
# NB: This check is copied from keras.py file in tensorflow.
133+
# There is no available public API to do this check.
134+
if (not layer._is_graph_network and
135+
not isinstance(layer, keras.models.Sequential)):
136+
raise ValueError("Subclassed models are not supported currently.")
137+
138+
return keras.models.clone_model(layer,
139+
input_tensors=None,
140+
clone_function=_add_clustering_wrapper)
129141
if isinstance(layer, cluster_wrapper.ClusterWeights):
130142
return layer
131143
if isinstance(layer, InputLayer):
@@ -185,7 +197,11 @@ def strip_clustering(model):
185197
'Expected model to be a `tf.keras.Model` instance but got: ', model)
186198

187199
def _strip_clustering_wrapper(layer):
188-
if isinstance(layer, cluster_wrapper.ClusterWeights):
200+
if isinstance(layer, keras.Model):
201+
return keras.models.clone_model(layer,
202+
input_tensors=None,
203+
clone_function=_strip_clustering_wrapper)
204+
elif isinstance(layer, cluster_wrapper.ClusterWeights):
189205
if not hasattr(layer.layer, '_batch_input_shape') and\
190206
hasattr(layer, '_batch_input_shape'):
191207
layer.layer._batch_input_shape = layer._batch_input_shape

tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py

Lines changed: 120 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -43,19 +43,57 @@ def setUp(self):
4343
}
4444

4545
self.x_train = np.array(
46-
[[0.0, 1.0], [2.0, 0.0], [0.0, 3.0], [4.0, 1.0], [5.0, 1.0]],
46+
[[0.0, 1.0, 2.0, 3.0, 4.0], [2.0, 0.0, 2.0, 3.0, 4.0], [0.0, 3.0, 2.0, 3.0, 4.0],
47+
[4.0, 1.0, 2.0, 3.0, 4.0], [5.0, 1.0, 2.0, 3.0, 4.0]],
4748
dtype="float32",
4849
)
4950

5051
self.y_train = np.array(
51-
[[0.0, 1.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]],
52+
[[0.0, 1.0, 2.0, 3.0, 4.0], [1.0, 0.0, 2.0, 3.0, 4.0], [1.0, 0.0, 2.0, 3.0, 4.0],
53+
[0.0, 1.0, 2.0, 3.0, 4.0], [0.0, 1.0, 2.0, 3.0, 4.0]],
54+
dtype="float32",
55+
)
56+
57+
self.x_test = np.array(
58+
[[1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0, 10.0], [1.0, 2.0, 3.0, 4.0, 5.0],
59+
[6.0, 1.0, 2.0, 3.0, 4.0], [9.0, 1.0, 0.0, 3.0, 0.0]],
5260
dtype="float32",
5361
)
5462

5563
def dataset_generator(self):
5664
for x, y in zip(self.x_train, self.y_train):
5765
yield np.array([x]), np.array([y])
5866

67+
def end_to_end_testing(self, original_model, clusters_check=None):
68+
"""Test End to End clustering."""
69+
70+
clustered_model = cluster.cluster_weights(original_model, **self.params)
71+
72+
clustered_model.compile(
73+
loss=keras.losses.categorical_crossentropy,
74+
optimizer="adam",
75+
metrics=["accuracy"],
76+
)
77+
78+
clustered_model.fit(x=self.dataset_generator(), steps_per_epoch=1)
79+
stripped_model = cluster.strip_clustering(clustered_model)
80+
if clusters_check is not None:
81+
clusters_check(stripped_model)
82+
83+
_, tflite_file = tempfile.mkstemp(".tflite")
84+
_, keras_file = tempfile.mkstemp(".h5")
85+
86+
converter = tf.lite.TFLiteConverter.from_keras_model(stripped_model)
87+
tflite_model = converter.convert()
88+
89+
with open(tflite_file, "wb") as f:
90+
f.write(tflite_model)
91+
92+
self._verify_tflite(tflite_file, self.x_test)
93+
94+
os.remove(keras_file)
95+
os.remove(tflite_file)
96+
5997
@staticmethod
6098
def _verify_tflite(tflite_file, x_test):
6199
interpreter = tf.lite.Interpreter(model_path=tflite_file)
@@ -72,8 +110,8 @@ def _verify_tflite(tflite_file, x_test):
72110
def testValuesRemainClusteredAfterTraining(self):
73111
"""Verifies that training a clustered model does not destroy the clusters."""
74112
original_model = keras.Sequential([
75-
layers.Dense(2, input_shape=(2,)),
76-
layers.Dense(2),
113+
layers.Dense(5, input_shape=(5,)),
114+
layers.Dense(5),
77115
])
78116

79117
clustered_model = cluster.cluster_weights(original_model, **self.params)
@@ -91,42 +129,95 @@ def testValuesRemainClusteredAfterTraining(self):
91129
self.assertLessEqual(len(unique_weights), self.params["number_of_clusters"])
92130

93131
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
94-
def testEndToEnd(self):
95-
"""Test End to End clustering."""
132+
def testEndToEndSequential(self):
133+
"""Test End to End clustering - sequential model."""
96134
original_model = keras.Sequential([
97-
layers.Dense(2, input_shape=(2,)),
98-
layers.Dense(2),
135+
layers.Dense(5, input_shape=(5,)),
136+
layers.Dense(5),
99137
])
100138

101-
clustered_model = cluster.cluster_weights(original_model, **self.params)
139+
def clusters_check(stripped_model):
140+
# dense layer
141+
weights_as_list = stripped_model.get_weights()[0].reshape(-1,).tolist()
142+
unique_weights = set(weights_as_list)
143+
self.assertLessEqual(len(unique_weights), self.params["number_of_clusters"])
102144

103-
clustered_model.compile(
104-
loss=keras.losses.categorical_crossentropy,
105-
optimizer="adam",
106-
metrics=["accuracy"],
107-
)
145+
self.end_to_end_testing(original_model, clusters_check)
108146

109-
clustered_model.fit(x=self.dataset_generator(), steps_per_epoch=1)
110-
stripped_model = cluster.strip_clustering(clustered_model)
147+
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
148+
def testEndToEndFunctional(self):
149+
"""Test End to End clustering - functional model."""
150+
inputs = keras.layers.Input(shape=(5,))
151+
layer1 = keras.layers.Dense(5)(inputs)
152+
layer2 = keras.layers.Dense(5)(layer1)
153+
original_model = keras.Model(inputs=inputs, outputs=layer2)
111154

112-
_, tflite_file = tempfile.mkstemp(".tflite")
113-
_, keras_file = tempfile.mkstemp(".h5")
155+
def clusters_check(stripped_model):
156+
# First dense layer
157+
weights_as_list = stripped_model.get_weights()[0].reshape(-1,).tolist()
158+
unique_weights = set(weights_as_list)
159+
self.assertLessEqual(len(unique_weights), self.params["number_of_clusters"])
114160

115-
if not compat.is_v1_apis():
116-
converter = tf.lite.TFLiteConverter.from_keras_model(stripped_model)
117-
else:
118-
tf.keras.models.save_model(stripped_model, keras_file)
119-
converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file)
161+
self.end_to_end_testing(original_model, clusters_check)
120162

121-
tflite_model = converter.convert()
122-
with open(tflite_file, "wb") as f:
123-
f.write(tflite_model)
163+
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
164+
def testEndToEndDeepLayer(self):
165+
"""Test End to End clustering for the model with deep layer."""
166+
internal_model = tf.keras.Sequential([tf.keras.layers.Dense(5, input_shape=(5,))])
167+
original_model = keras.Sequential([
168+
internal_model,
169+
layers.Dense(5),
170+
])
124171

125-
self._verify_tflite(tflite_file, self.x_train)
172+
def clusters_check(stripped_model):
173+
# inner dense layer
174+
weights_as_list = stripped_model._layers[1]._layers[1].trainable_weights[0].\
175+
numpy().flatten()
176+
unique_weights = set(weights_as_list)
177+
self.assertLessEqual(len(unique_weights), self.params["number_of_clusters"])
126178

127-
os.remove(keras_file)
128-
os.remove(tflite_file)
179+
# outer dense layer
180+
weights_as_list = stripped_model._layers[2].trainable_weights[0].\
181+
numpy().flatten()
182+
unique_weights = set(weights_as_list)
183+
self.assertLessEqual(len(unique_weights), self.params["number_of_clusters"])
184+
185+
self.end_to_end_testing(original_model, clusters_check)
186+
187+
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
188+
def testEndToEndDeepLayer2(self):
189+
"""Test End to End clustering for the model with 2 deep layers."""
190+
internal_model = tf.keras.Sequential([tf.keras.layers.Dense(5, input_shape=(5,))])
191+
intermediate_model = keras.Sequential([
192+
internal_model,
193+
layers.Dense(5),
194+
])
195+
original_model = keras.Sequential([
196+
intermediate_model,
197+
layers.Dense(5),
198+
])
129199

200+
def clusters_check(stripped_model):
201+
# first inner dense layer
202+
weights_as_list = stripped_model._layers[1]._layers[1].trainable_weights[0].\
203+
numpy().flatten()
204+
unique_weights = set(weights_as_list)
205+
self.assertLessEqual(len(unique_weights), self.params["number_of_clusters"])
206+
207+
# second inner dense layer
208+
weights_as_list = stripped_model._layers[1]._layers[1]._layers[1].\
209+
trainable_weights[0].\
210+
numpy().flatten()
211+
unique_weights = set(weights_as_list)
212+
self.assertLessEqual(len(unique_weights), self.params["number_of_clusters"])
213+
214+
# outer dense layer
215+
weights_as_list = stripped_model._layers[2].trainable_weights[0].\
216+
numpy().flatten()
217+
unique_weights = set(weights_as_list)
218+
self.assertLessEqual(len(unique_weights), self.params["number_of_clusters"])
219+
220+
self.end_to_end_testing(original_model, clusters_check)
130221

131222
if __name__ == "__main__":
132223
test.main()

tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,20 @@ def testClusterSubclassModel(self):
424424
with self.assertRaises(ValueError):
425425
_ = cluster.cluster_weights(model, **self.params)
426426

427+
@keras_parameterized.run_all_keras_modes
428+
def testClusterSubclassModelAsSubmodel(self):
429+
"""
430+
Verifies that attempting to cluster a model with submodel
431+
that is a subclass throws an exception.
432+
"""
433+
model_subclass = TestModel()
434+
model = keras.Sequential([
435+
layers.Dense(10),
436+
model_subclass
437+
])
438+
with self.assertRaisesRegexp(ValueError, "Subclassed models.*"):
439+
_ = cluster.cluster_weights(model, **self.params)
440+
427441
@keras_parameterized.run_all_keras_modes
428442
def testStripClusteringSequentialModel(self):
429443
"""

0 commit comments

Comments
 (0)