Skip to content

Commit 6a2c821

Browse files
committed
[Clustering] Average Gradient Aggregation
* Add Average Gradient Aggregation parameter to the clustering API * Cleaning of the wrapper logic Change-Id: Ia7f94a1bc8de2e13663ddcf676748413ff1f9295
1 parent 2e85bc3 commit 6a2c821

File tree

14 files changed

+560
-368
lines changed

14 files changed

+560
-368
lines changed

tensorflow_model_optimization/python/core/clustering/keras/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ py_library(
8282
visibility = ["//visibility:public"],
8383
deps = [
8484
# tensorflow dep1,
85+
"//tensorflow_model_optimization/python/core/clustering/keras:cluster_config",
8586
],
8687
)
8788

tensorflow_model_optimization/python/core/clustering/keras/cluster.py

Lines changed: 17 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -293,36 +293,25 @@ def _strip_clustering_wrapper(layer):
293293
if isinstance(layer, keras.Model):
294294
return keras.models.clone_model(
295295
layer, input_tensors=None, clone_function=_strip_clustering_wrapper)
296+
296297
elif isinstance(layer, cluster_wrapper.ClusterWeights):
297-
if not hasattr(layer.layer, '_batch_input_shape') and\
298-
hasattr(layer, '_batch_input_shape'):
299-
# pylint: disable=protected-access
300-
layer.layer._batch_input_shape = layer._batch_input_shape
298+
# Update cluster associations in order to get the latest weights
299+
layer.update_clustered_weights_associations()
300+
301+
# Construct a list of weights to initialize the clean layer
302+
updated_weights = layer.layer.get_weights() # non clusterable weights only
303+
for position_variable, weight_name in layer.position_original_weights.items():
304+
# Add the clustered weights at the correct position
305+
clustered_weight = getattr(layer.layer, weight_name)
306+
updated_weights.insert(position_variable, clustered_weight)
307+
308+
# Construct a clean layer with the updated weights
309+
clean_layer = layer.layer.from_config(layer.layer.get_config())
310+
clean_layer.build(layer.build_input_shape)
311+
clean_layer.set_weights(updated_weights)
312+
313+
return clean_layer
301314

302-
# We reset both arrays of weights, so that we can guarantee the correct
303-
# order of newly created weights
304-
# pylint: disable=protected-access
305-
layer.layer._trainable_weights = []
306-
layer.layer._non_trainable_weights = []
307-
for i in range(len(layer.restore)):
308-
# This is why we used integers as keys
309-
name, weight_name, weight = layer.restore[i]
310-
# In both cases we use k.batch_get_value since we need physical copies
311-
# of the arrays to initialize a new tensor
312-
if i in layer.gone_variables:
313-
# If the variable was removed because it was clustered, we restore it
314-
# by using updater we created earlier
315-
new_weight_value = k.batch_get_value([weight()])[0]
316-
else:
317-
# If the value was not clustered(e.g. bias), we still store a valid
318-
# reference to the tensor. We use this reference to get the value
319-
new_weight_value = k.batch_get_value([weight])[0]
320-
setattr(layer.layer,
321-
name,
322-
k.variable(new_weight_value, name=weight_name))
323-
# When all weights are filled with the values, just return the underlying
324-
# layer since it is now fully autonomous from its wrapper
325-
return layer.layer
326315
return layer
327316

328317
# Just copy the model with the right callback

tensorflow_model_optimization/python/core/clustering/keras/cluster_config.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,19 @@ class CentroidInitialization(str, enum.Enum):
3131
initialize the clusters centroids.
3232
* `KMEANS_PLUS_PLUS`: cluster centroids using the kmeans++ algorithm
3333
"""
34-
LINEAR = "LINEAR"
35-
RANDOM = "RANDOM"
36-
DENSITY_BASED = "DENSITY_BASED"
37-
KMEANS_PLUS_PLUS = "KMEANS_PLUS_PLUS"
34+
LINEAR = "CentroidInitialization.LINEAR"
35+
RANDOM = "CentroidInitialization.RANDOM"
36+
DENSITY_BASED = "CentroidInitialization.DENSITY_BASED"
37+
KMEANS_PLUS_PLUS = "CentroidInitialization.KMEANS_PLUS_PLUS"
38+
39+
40+
class GradientAggregation(str, enum.Enum):
41+
"""Specifies how the cluster gradient should be aggregated.
42+
43+
* `SUM`: The gradient of each cluster centroid is the sum of their
44+
respective child’s weight gradient.
45+
* `AVG`: The gradient of each cluster centroid is the averaged sum of
46+
their respective child’s weight gradient.
47+
"""
48+
SUM = "GradientAggregation.SUM"
49+
AVG = "GradientAggregation.AVG"

tensorflow_model_optimization/python/core/clustering/keras/cluster_distributed_test.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ def setUp(self):
4343
"cluster_centroids_init": CentroidInitialization.LINEAR
4444
}
4545

46-
4746
@parameterized.parameters(_distribution_strategies())
4847
def testClusterSimpleDenseModel(self, distribution):
4948
"""End-to-end test."""
@@ -64,7 +63,7 @@ def testClusterSimpleDenseModel(self, distribution):
6463
model.predict(np.random.rand(20, 10))
6564

6665
stripped_model = cluster.strip_clustering(model)
67-
weights_as_list = stripped_model.get_weights()[0].reshape(-1,).tolist()
66+
weights_as_list = stripped_model.layers[0].kernel.numpy().reshape(-1,).tolist()
6867
unique_weights = set(weights_as_list)
6968
self.assertLessEqual(len(unique_weights), self.params["number_of_clusters"])
7069

@@ -87,7 +86,7 @@ def testAssociationValuesPerReplica(self, distribution):
8786
self.assertEqual(len(clusterable_weights), 1)
8887
weights_name = clusterable_weights[0][0]
8988
self.assertEqual(weights_name, 'kernel')
90-
centroids1 = l.cluster_centroids_tf[weights_name]
89+
centroids1 = l.cluster_centroids[weights_name]
9190

9291
mean_weight = tf.reduce_mean(l.layer.kernel)
9392
min_weight = tf.reduce_min(l.layer.kernel)
@@ -119,18 +118,18 @@ def update_fn(v, val):
119118
centroids1, update_fn, args=(initial_val,))
120119
l.call(tf.ones(shape=input_shape))
121120

122-
clst_indices = l.pulling_indices_tf[weights_name]
121+
clst_indices = l.pulling_indices[weights_name]
123122
per_replica = distribution.experimental_local_results(clst_indices)
124123
assert_all_cluster_indices(per_replica, 0)
125124

126125
second_val = tf.Variable([mean_weight - 2.0 * max_dist, mean_weight], \
127126
aggregation=tf.VariableAggregation.MEAN)
128-
centroids2 = l.cluster_centroids_tf[weights_name]
127+
centroids2 = l.cluster_centroids[weights_name]
129128
centroids2 = distribution.extended.update(
130129
centroids2, update_fn, args=(second_val,))
131130
l.call(tf.ones(shape=input_shape))
132131

133-
clst_indices = l.pulling_indices_tf[weights_name]
132+
clst_indices = l.pulling_indices[weights_name]
134133
per_replica = distribution.experimental_local_results(clst_indices)
135134
assert_all_cluster_indices(per_replica, 1)
136135

tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def testSparsityIsPreservedDuringTraining(self):
174174
original_model, **clustering_params)
175175

176176
stripped_model_before_tuning = cluster.strip_clustering(clustered_model)
177-
weights_before_tuning = stripped_model_before_tuning.get_weights()[0]
177+
weights_before_tuning = stripped_model_before_tuning.layers[0].kernel
178178
non_zero_weight_indices_before_tuning = np.nonzero(weights_before_tuning)
179179

180180
clustered_model.compile(
@@ -185,9 +185,9 @@ def testSparsityIsPreservedDuringTraining(self):
185185
clustered_model.fit(x=self.dataset_generator2(), steps_per_epoch=1)
186186

187187
stripped_model_after_tuning = cluster.strip_clustering(clustered_model)
188-
weights_after_tuning = stripped_model_after_tuning.get_weights()[0]
188+
weights_after_tuning = stripped_model_after_tuning.layers[0].kernel
189189
non_zero_weight_indices_after_tuning = np.nonzero(weights_after_tuning)
190-
weights_as_list_after_tuning = weights_after_tuning.reshape(-1,).tolist()
190+
weights_as_list_after_tuning = weights_after_tuning.numpy().reshape(-1,).tolist()
191191
unique_weights_after_tuning = set(weights_as_list_after_tuning)
192192

193193
# Check that the null weights stayed the same before and after tuning.
@@ -299,6 +299,54 @@ def clusters_check(stripped_model):
299299

300300
self.end_to_end_testing(original_model, clusters_check)
301301

302+
@keras_parameterized.run_all_keras_modes
303+
def testWeightsAreLearningDuringClustering(self):
304+
"""Verifies that training a clustered model does update
305+
original_weights, clustered_centroids and bias."""
306+
original_model = keras.Sequential([
307+
layers.Dense(5, input_shape=(5,))
308+
])
309+
310+
clustered_model = cluster.cluster_weights(original_model, **self.params)
311+
312+
clustered_model.compile(
313+
loss=keras.losses.categorical_crossentropy,
314+
optimizer="adam",
315+
metrics=["accuracy"],
316+
)
317+
318+
class CheckWeightsCallback(keras.callbacks.Callback):
319+
def on_train_batch_begin(self, batch, logs=None):
320+
# Save weights before batch
321+
self.original_weight_kernel = (
322+
self.model.layers[0].original_clusterable_weights['kernel'].numpy()
323+
)
324+
self.cluster_centroids_kernel = (
325+
self.model.layers[0].cluster_centroids['kernel'].numpy()
326+
)
327+
self.bias = (
328+
self.model.layers[0].layer.bias.numpy()
329+
)
330+
331+
def on_train_batch_end(self, batch, logs=None):
332+
# Check weights are different after batch
333+
assert not np.array_equal(
334+
self.original_weight_kernel,
335+
self.model.layers[0].original_clusterable_weights['kernel'].numpy()
336+
)
337+
assert not np.array_equal(
338+
self.cluster_centroids_kernel,
339+
self.model.layers[0].cluster_centroids['kernel'].numpy()
340+
)
341+
assert not np.array_equal(
342+
self.bias,
343+
self.model.layers[0].layer.bias.numpy()
344+
)
345+
346+
clustered_model.fit(x=self.dataset_generator(),
347+
steps_per_epoch=5,
348+
callbacks=[CheckWeightsCallback()])
349+
302350

303351
if __name__ == "__main__":
304352
test.main()

tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py

Lines changed: 94 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616

1717
import json
1818

19+
import tempfile
20+
21+
import os
1922
from absl.testing import parameterized
2023
import tensorflow as tf
2124

@@ -253,6 +256,76 @@ def testClusterKerasCustomLayer(self):
253256
with self.assertRaises(ValueError):
254257
cluster_wrapper.ClusterWeights(keras_custom_layer, **self.params)
255258

259+
def testStripClusteringSequentialModelWithKernelRegularizer(self):
260+
"""Verifies that stripping the clustering wrappers from a sequential model produces the expected config."""
261+
model = keras.Sequential([
262+
layers.Dense(10, input_shape=(10,)),
263+
layers.Dense(10, kernel_regularizer=tf.keras.regularizers.L1(0.01)),
264+
])
265+
clustered_model = cluster.cluster_weights(model, **self.params)
266+
stripped_model = cluster.strip_clustering(clustered_model)
267+
# check that kernel regularizer is present in the second dense layer
268+
self.assertIsNotNone(stripped_model.layers[1].kernel_regularizer)
269+
with tempfile.TemporaryDirectory() as tmp_dir_name:
270+
keras_file = os.path.join(tmp_dir_name, 'cluster_test')
271+
stripped_model.save(keras_file, save_traces = True)
272+
273+
def testStripClusteringSequentialModelWithBiasRegularizer(self):
274+
"""Verifies that stripping the clustering wrappers from a sequential model produces the expected config."""
275+
model = keras.Sequential([
276+
layers.Dense(10, input_shape=(10,)),
277+
layers.Dense(10, bias_regularizer=tf.keras.regularizers.L1(0.01)),
278+
])
279+
clustered_model = cluster.cluster_weights(model, **self.params)
280+
stripped_model = cluster.strip_clustering(clustered_model)
281+
# check that kernel regularizer is present in the second dense layer
282+
self.assertIsNotNone(stripped_model.layers[1].bias_regularizer)
283+
with tempfile.TemporaryDirectory() as tmp_dir_name:
284+
keras_file = os.path.join(tmp_dir_name, 'cluster_test')
285+
stripped_model.save(keras_file, save_traces = True)
286+
287+
def testStripClusteringSequentialModelWithActivityRegularizer(self):
288+
"""Verifies that stripping the clustering wrappers from a sequential model produces the expected config."""
289+
model = keras.Sequential([
290+
layers.Dense(10, input_shape=(10,)),
291+
layers.Dense(10, activity_regularizer=tf.keras.regularizers.L1(0.01)),
292+
])
293+
clustered_model = cluster.cluster_weights(model, **self.params)
294+
stripped_model = cluster.strip_clustering(clustered_model)
295+
# check that kernel regularizer is present in the second dense layer
296+
self.assertIsNotNone(stripped_model.layers[1].activity_regularizer)
297+
with tempfile.TemporaryDirectory() as tmp_dir_name:
298+
keras_file = os.path.join(tmp_dir_name, 'cluster_test')
299+
stripped_model.save(keras_file, save_traces = True)
300+
301+
def testStripClusteringSequentialModelWithKernelConstraint(self):
302+
"""Verifies that stripping the clustering wrappers from a sequential model produces the expected config."""
303+
model = keras.Sequential([
304+
layers.Dense(10, input_shape=(10,)),
305+
layers.Dense(10, kernel_constraint=tf.keras.constraints.max_norm(2.)),
306+
])
307+
clustered_model = cluster.cluster_weights(model, **self.params)
308+
stripped_model = cluster.strip_clustering(clustered_model)
309+
# check that kernel regularizer is present in the second dense layer
310+
self.assertIsNotNone(stripped_model.layers[1].kernel_constraint)
311+
with tempfile.TemporaryDirectory() as tmp_dir_name:
312+
keras_file = os.path.join(tmp_dir_name, 'cluster_test')
313+
stripped_model.save(keras_file, save_traces = True)
314+
315+
def testStripClusteringSequentialModelWithBiasConstraint(self):
316+
"""Verifies that stripping the clustering wrappers from a sequential model produces the expected config."""
317+
model = keras.Sequential([
318+
layers.Dense(10, input_shape=(10,)),
319+
layers.Dense(10, bias_constraint=tf.keras.constraints.max_norm(2.)),
320+
])
321+
clustered_model = cluster.cluster_weights(model, **self.params)
322+
stripped_model = cluster.strip_clustering(clustered_model)
323+
# check that kernel regularizer is present in the second dense layer
324+
self.assertIsNotNone(stripped_model.layers[1].bias_constraint)
325+
with tempfile.TemporaryDirectory() as tmp_dir_name:
326+
keras_file = os.path.join(tmp_dir_name, 'cluster_test')
327+
stripped_model.save(keras_file, save_traces = True)
328+
256329
def testClusterMyClusterableLayer(self):
257330
# we have weights to cluster.
258331
layer = self.clusterable_layer
@@ -539,7 +612,7 @@ def testClusterSubclassModelAsSubmodel(self):
539612
def testStripClusteringSequentialModel(self):
540613
"""Verifies that stripping the clustering wrappers from a sequential model produces the expected config."""
541614
model = keras.Sequential([
542-
layers.Dense(10),
615+
layers.Dense(10, input_shape=(5,)),
543616
layers.Dense(10),
544617
])
545618

@@ -582,7 +655,7 @@ def testClusterWeightsStrippedWeights(self):
582655

583656
@keras_parameterized.run_all_keras_modes
584657
def testStrippedKernel(self):
585-
"""Verifies that stripping the clustering wrappers from a functional model restores the layers kernel and the layers weight array to the new clustered weight value ."""
658+
"""Verifies that stripping the clustering wrappers from a functional model restores the layers kernel and the layers weight array to the new clustered weight value."""
586659
i1 = keras.Input(shape=(1, 1, 1))
587660
x1 = layers.Conv2D(1, 1)(i1)
588661
outputs = x1
@@ -596,8 +669,7 @@ def testStrippedKernel(self):
596669

597670
self.assertEqual(self._count_clustered_layers(stripped_model), 0)
598671
self.assertIsNot(stripped_conv2d_layer.kernel, clustered_kernel)
599-
self.assertEqual(stripped_conv2d_layer.kernel,
600-
stripped_conv2d_layer.weights[0])
672+
self.assertIn(stripped_conv2d_layer.kernel, stripped_conv2d_layer.weights)
601673

602674
@keras_parameterized.run_all_keras_modes
603675
def testStripSelectivelyClusteredFunctionalModel(self):
@@ -628,5 +700,23 @@ def testStripSelectivelyClusteredSequentialModel(self):
628700
self.assertEqual(self._count_clustered_layers(stripped_model), 0)
629701
self.assertIsInstance(stripped_model.layers[0], layers.Dense)
630702

703+
@keras_parameterized.run_all_keras_modes
704+
def testStripClusteringAndSetOriginalWeightsBack(self):
705+
"""Verifies that we can set_weights onto the stripped model."""
706+
model = keras.Sequential([
707+
layers.Dense(10, input_shape=(5,)),
708+
layers.Dense(10),
709+
])
710+
711+
# Save original weights
712+
original_weights = model.get_weights()
713+
714+
# Cluster and strip
715+
clustered_model = cluster.cluster_weights(model, **self.params)
716+
stripped_model = cluster.strip_clustering(clustered_model)
717+
718+
# Set back original weights onto the strip model
719+
stripped_model.set_weights(original_weights)
720+
631721
if __name__ == '__main__':
632722
test.main()

0 commit comments

Comments
 (0)