Skip to content

Commit 006e377

Browse files
Merge pull request #667 from johan-gras:feature/avg-grad
PiperOrigin-RevId: 370031366
2 parents 2a09e28 + 6a2c821 commit 006e377

File tree

15 files changed

+583
-373
lines changed

15 files changed

+583
-373
lines changed

tensorflow_model_optimization/python/core/clustering/keras/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ py_library(
8282
visibility = ["//visibility:public"],
8383
deps = [
8484
# tensorflow dep1,
85+
"//tensorflow_model_optimization/python/core/clustering/keras:cluster_config",
8586
],
8687
)
8788

tensorflow_model_optimization/python/core/clustering/keras/cluster.py

Lines changed: 17 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -293,36 +293,25 @@ def _strip_clustering_wrapper(layer):
293293
if isinstance(layer, keras.Model):
294294
return keras.models.clone_model(
295295
layer, input_tensors=None, clone_function=_strip_clustering_wrapper)
296+
296297
elif isinstance(layer, cluster_wrapper.ClusterWeights):
297-
if not hasattr(layer.layer, '_batch_input_shape') and\
298-
hasattr(layer, '_batch_input_shape'):
299-
# pylint: disable=protected-access
300-
layer.layer._batch_input_shape = layer._batch_input_shape
298+
# Update cluster associations in order to get the latest weights
299+
layer.update_clustered_weights_associations()
300+
301+
# Construct a list of weights to initialize the clean layer
302+
updated_weights = layer.layer.get_weights() # non clusterable weights only
303+
for position_variable, weight_name in layer.position_original_weights.items():
304+
# Add the clustered weights at the correct position
305+
clustered_weight = getattr(layer.layer, weight_name)
306+
updated_weights.insert(position_variable, clustered_weight)
307+
308+
# Construct a clean layer with the updated weights
309+
clean_layer = layer.layer.from_config(layer.layer.get_config())
310+
clean_layer.build(layer.build_input_shape)
311+
clean_layer.set_weights(updated_weights)
312+
313+
return clean_layer
301314

302-
# We reset both arrays of weights, so that we can guarantee the correct
303-
# order of newly created weights
304-
# pylint: disable=protected-access
305-
layer.layer._trainable_weights = []
306-
layer.layer._non_trainable_weights = []
307-
for i in range(len(layer.restore)):
308-
# This is why we used integers as keys
309-
name, weight_name, weight = layer.restore[i]
310-
# In both cases we use k.batch_get_value since we need physical copies
311-
# of the arrays to initialize a new tensor
312-
if i in layer.gone_variables:
313-
# If the variable was removed because it was clustered, we restore it
314-
# by using updater we created earlier
315-
new_weight_value = k.batch_get_value([weight()])[0]
316-
else:
317-
# If the value was not clustered(e.g. bias), we still store a valid
318-
# reference to the tensor. We use this reference to get the value
319-
new_weight_value = k.batch_get_value([weight])[0]
320-
setattr(layer.layer,
321-
name,
322-
k.variable(new_weight_value, name=weight_name))
323-
# When all weights are filled with the values, just return the underlying
324-
# layer since it is now fully autonomous from its wrapper
325-
return layer.layer
326315
return layer
327316

328317
# Just copy the model with the right callback

tensorflow_model_optimization/python/core/clustering/keras/cluster_config.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,19 @@ class CentroidInitialization(str, enum.Enum):
3131
initialize the clusters centroids.
3232
* `KMEANS_PLUS_PLUS`: cluster centroids using the kmeans++ algorithm
3333
"""
34-
LINEAR = "LINEAR"
35-
RANDOM = "RANDOM"
36-
DENSITY_BASED = "DENSITY_BASED"
37-
KMEANS_PLUS_PLUS = "KMEANS_PLUS_PLUS"
34+
LINEAR = "CentroidInitialization.LINEAR"
35+
RANDOM = "CentroidInitialization.RANDOM"
36+
DENSITY_BASED = "CentroidInitialization.DENSITY_BASED"
37+
KMEANS_PLUS_PLUS = "CentroidInitialization.KMEANS_PLUS_PLUS"
38+
39+
40+
class GradientAggregation(str, enum.Enum):
41+
"""Specifies how the cluster gradient should be aggregated.
42+
43+
* `SUM`: The gradient of each cluster centroid is the sum of their
44+
respective children's weight gradients.
45+
* `AVG`: The gradient of each cluster centroid is the averaged sum of
46+
their respective children's weight gradients.
47+
"""
48+
SUM = "GradientAggregation.SUM"
49+
AVG = "GradientAggregation.AVG"

tensorflow_model_optimization/python/core/clustering/keras/cluster_distributed_test.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ def setUp(self):
4343
"cluster_centroids_init": CentroidInitialization.LINEAR
4444
}
4545

46-
4746
@parameterized.parameters(_distribution_strategies())
4847
def testClusterSimpleDenseModel(self, distribution):
4948
"""End-to-end test."""
@@ -64,7 +63,7 @@ def testClusterSimpleDenseModel(self, distribution):
6463
model.predict(np.random.rand(20, 10))
6564

6665
stripped_model = cluster.strip_clustering(model)
67-
weights_as_list = stripped_model.get_weights()[0].reshape(-1,).tolist()
66+
weights_as_list = stripped_model.layers[0].kernel.numpy().reshape(-1,).tolist()
6867
unique_weights = set(weights_as_list)
6968
self.assertLessEqual(len(unique_weights), self.params["number_of_clusters"])
7069

@@ -87,7 +86,7 @@ def testAssociationValuesPerReplica(self, distribution):
8786
self.assertEqual(len(clusterable_weights), 1)
8887
weights_name = clusterable_weights[0][0]
8988
self.assertEqual(weights_name, 'kernel')
90-
centroids1 = l.cluster_centroids_tf[weights_name]
89+
centroids1 = l.cluster_centroids[weights_name]
9190

9291
mean_weight = tf.reduce_mean(l.layer.kernel)
9392
min_weight = tf.reduce_min(l.layer.kernel)
@@ -119,18 +118,18 @@ def update_fn(v, val):
119118
centroids1, update_fn, args=(initial_val,))
120119
l.call(tf.ones(shape=input_shape))
121120

122-
clst_indices = l.pulling_indices_tf[weights_name]
121+
clst_indices = l.pulling_indices[weights_name]
123122
per_replica = distribution.experimental_local_results(clst_indices)
124123
assert_all_cluster_indices(per_replica, 0)
125124

126125
second_val = tf.Variable([mean_weight - 2.0 * max_dist, mean_weight], \
127126
aggregation=tf.VariableAggregation.MEAN)
128-
centroids2 = l.cluster_centroids_tf[weights_name]
127+
centroids2 = l.cluster_centroids[weights_name]
129128
centroids2 = distribution.extended.update(
130129
centroids2, update_fn, args=(second_val,))
131130
l.call(tf.ones(shape=input_shape))
132131

133-
clst_indices = l.pulling_indices_tf[weights_name]
132+
clst_indices = l.pulling_indices[weights_name]
134133
per_replica = distribution.experimental_local_results(clst_indices)
135134
assert_all_cluster_indices(per_replica, 1)
136135

tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def testSparsityIsPreservedDuringTraining(self):
174174
original_model, **clustering_params)
175175

176176
stripped_model_before_tuning = cluster.strip_clustering(clustered_model)
177-
weights_before_tuning = stripped_model_before_tuning.get_weights()[0]
177+
weights_before_tuning = stripped_model_before_tuning.layers[0].kernel
178178
non_zero_weight_indices_before_tuning = np.nonzero(weights_before_tuning)
179179

180180
clustered_model.compile(
@@ -185,9 +185,9 @@ def testSparsityIsPreservedDuringTraining(self):
185185
clustered_model.fit(x=self.dataset_generator2(), steps_per_epoch=1)
186186

187187
stripped_model_after_tuning = cluster.strip_clustering(clustered_model)
188-
weights_after_tuning = stripped_model_after_tuning.get_weights()[0]
188+
weights_after_tuning = stripped_model_after_tuning.layers[0].kernel
189189
non_zero_weight_indices_after_tuning = np.nonzero(weights_after_tuning)
190-
weights_as_list_after_tuning = weights_after_tuning.reshape(-1,).tolist()
190+
weights_as_list_after_tuning = weights_after_tuning.numpy().reshape(-1,).tolist()
191191
unique_weights_after_tuning = set(weights_as_list_after_tuning)
192192

193193
# Check that the null weights stayed the same before and after tuning.
@@ -299,6 +299,54 @@ def clusters_check(stripped_model):
299299

300300
self.end_to_end_testing(original_model, clusters_check)
301301

302+
@keras_parameterized.run_all_keras_modes
303+
def testWeightsAreLearningDuringClustering(self):
304+
"""Verifies that training a clustered model does update
305+
original_weights, clustered_centroids and bias."""
306+
original_model = keras.Sequential([
307+
layers.Dense(5, input_shape=(5,))
308+
])
309+
310+
clustered_model = cluster.cluster_weights(original_model, **self.params)
311+
312+
clustered_model.compile(
313+
loss=keras.losses.categorical_crossentropy,
314+
optimizer="adam",
315+
metrics=["accuracy"],
316+
)
317+
318+
class CheckWeightsCallback(keras.callbacks.Callback):
319+
def on_train_batch_begin(self, batch, logs=None):
320+
# Save weights before batch
321+
self.original_weight_kernel = (
322+
self.model.layers[0].original_clusterable_weights['kernel'].numpy()
323+
)
324+
self.cluster_centroids_kernel = (
325+
self.model.layers[0].cluster_centroids['kernel'].numpy()
326+
)
327+
self.bias = (
328+
self.model.layers[0].layer.bias.numpy()
329+
)
330+
331+
def on_train_batch_end(self, batch, logs=None):
332+
# Check weights are different after batch
333+
assert not np.array_equal(
334+
self.original_weight_kernel,
335+
self.model.layers[0].original_clusterable_weights['kernel'].numpy()
336+
)
337+
assert not np.array_equal(
338+
self.cluster_centroids_kernel,
339+
self.model.layers[0].cluster_centroids['kernel'].numpy()
340+
)
341+
assert not np.array_equal(
342+
self.bias,
343+
self.model.layers[0].layer.bias.numpy()
344+
)
345+
346+
clustered_model.fit(x=self.dataset_generator(),
347+
steps_per_epoch=5,
348+
callbacks=[CheckWeightsCallback()])
349+
302350

303351
if __name__ == "__main__":
304352
test.main()

tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py

Lines changed: 94 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616

1717
import json
1818

19+
import tempfile
20+
21+
import os
1922
from absl.testing import parameterized
2023
import tensorflow as tf
2124

@@ -282,6 +285,76 @@ def testClusterKerasCustomLayer(self):
282285
with self.assertRaises(ValueError):
283286
cluster_wrapper.ClusterWeights(keras_custom_layer, **self.params)
284287

288+
def testStripClusteringSequentialModelWithKernelRegularizer(self):
289+
"""Verifies that stripping the clustering wrappers from a sequential model preserves the kernel regularizer and that the stripped model can be saved."""
290+
model = keras.Sequential([
291+
layers.Dense(10, input_shape=(10,)),
292+
layers.Dense(10, kernel_regularizer=tf.keras.regularizers.L1(0.01)),
293+
])
294+
clustered_model = cluster.cluster_weights(model, **self.params)
295+
stripped_model = cluster.strip_clustering(clustered_model)
296+
# check that kernel regularizer is present in the second dense layer
297+
self.assertIsNotNone(stripped_model.layers[1].kernel_regularizer)
298+
with tempfile.TemporaryDirectory() as tmp_dir_name:
299+
keras_file = os.path.join(tmp_dir_name, 'cluster_test')
300+
stripped_model.save(keras_file, save_traces = True)
301+
302+
def testStripClusteringSequentialModelWithBiasRegularizer(self):
303+
"""Verifies that stripping the clustering wrappers from a sequential model preserves the bias regularizer and that the stripped model can be saved."""
304+
model = keras.Sequential([
305+
layers.Dense(10, input_shape=(10,)),
306+
layers.Dense(10, bias_regularizer=tf.keras.regularizers.L1(0.01)),
307+
])
308+
clustered_model = cluster.cluster_weights(model, **self.params)
309+
stripped_model = cluster.strip_clustering(clustered_model)
310+
# check that bias regularizer is present in the second dense layer
311+
self.assertIsNotNone(stripped_model.layers[1].bias_regularizer)
312+
with tempfile.TemporaryDirectory() as tmp_dir_name:
313+
keras_file = os.path.join(tmp_dir_name, 'cluster_test')
314+
stripped_model.save(keras_file, save_traces = True)
315+
316+
def testStripClusteringSequentialModelWithActivityRegularizer(self):
317+
"""Verifies that stripping the clustering wrappers from a sequential model preserves the activity regularizer and that the stripped model can be saved."""
318+
model = keras.Sequential([
319+
layers.Dense(10, input_shape=(10,)),
320+
layers.Dense(10, activity_regularizer=tf.keras.regularizers.L1(0.01)),
321+
])
322+
clustered_model = cluster.cluster_weights(model, **self.params)
323+
stripped_model = cluster.strip_clustering(clustered_model)
324+
# check that activity regularizer is present in the second dense layer
325+
self.assertIsNotNone(stripped_model.layers[1].activity_regularizer)
326+
with tempfile.TemporaryDirectory() as tmp_dir_name:
327+
keras_file = os.path.join(tmp_dir_name, 'cluster_test')
328+
stripped_model.save(keras_file, save_traces = True)
329+
330+
def testStripClusteringSequentialModelWithKernelConstraint(self):
331+
"""Verifies that stripping the clustering wrappers from a sequential model preserves the kernel constraint and that the stripped model can be saved."""
332+
model = keras.Sequential([
333+
layers.Dense(10, input_shape=(10,)),
334+
layers.Dense(10, kernel_constraint=tf.keras.constraints.max_norm(2.)),
335+
])
336+
clustered_model = cluster.cluster_weights(model, **self.params)
337+
stripped_model = cluster.strip_clustering(clustered_model)
338+
# check that kernel constraint is present in the second dense layer
339+
self.assertIsNotNone(stripped_model.layers[1].kernel_constraint)
340+
with tempfile.TemporaryDirectory() as tmp_dir_name:
341+
keras_file = os.path.join(tmp_dir_name, 'cluster_test')
342+
stripped_model.save(keras_file, save_traces = True)
343+
344+
def testStripClusteringSequentialModelWithBiasConstraint(self):
345+
"""Verifies that stripping the clustering wrappers from a sequential model preserves the bias constraint and that the stripped model can be saved."""
346+
model = keras.Sequential([
347+
layers.Dense(10, input_shape=(10,)),
348+
layers.Dense(10, bias_constraint=tf.keras.constraints.max_norm(2.)),
349+
])
350+
clustered_model = cluster.cluster_weights(model, **self.params)
351+
stripped_model = cluster.strip_clustering(clustered_model)
352+
# check that bias constraint is present in the second dense layer
353+
self.assertIsNotNone(stripped_model.layers[1].bias_constraint)
354+
with tempfile.TemporaryDirectory() as tmp_dir_name:
355+
keras_file = os.path.join(tmp_dir_name, 'cluster_test')
356+
stripped_model.save(keras_file, save_traces = True)
357+
285358
def testClusterMyClusterableLayer(self):
286359
# we have weights to cluster.
287360
layer = self.clusterable_layer
@@ -567,7 +640,7 @@ def testClusterSubclassModelAsSubmodel(self):
567640
def testStripClusteringSequentialModel(self):
568641
"""Verifies that stripping the clustering wrappers from a sequential model produces the expected config."""
569642
model = keras.Sequential([
570-
layers.Dense(10),
643+
layers.Dense(10, input_shape=(5,)),
571644
layers.Dense(10),
572645
])
573646

@@ -610,7 +683,7 @@ def testClusterWeightsStrippedWeights(self):
610683

611684
@keras_parameterized.run_all_keras_modes
612685
def testStrippedKernel(self):
613-
"""Verifies that stripping the clustering wrappers from a functional model restores the layers kernel and the layers weight array to the new clustered weight value ."""
686+
"""Verifies that stripping the clustering wrappers from a functional model restores the layers kernel and the layers weight array to the new clustered weight value."""
614687
i1 = keras.Input(shape=(1, 1, 1))
615688
x1 = layers.Conv2D(1, 1)(i1)
616689
outputs = x1
@@ -624,8 +697,7 @@ def testStrippedKernel(self):
624697

625698
self.assertEqual(self._count_clustered_layers(stripped_model), 0)
626699
self.assertIsNot(stripped_conv2d_layer.kernel, clustered_kernel)
627-
self.assertEqual(stripped_conv2d_layer.kernel,
628-
stripped_conv2d_layer.weights[0])
700+
self.assertIn(stripped_conv2d_layer.kernel, stripped_conv2d_layer.weights)
629701

630702
@keras_parameterized.run_all_keras_modes
631703
def testStripSelectivelyClusteredFunctionalModel(self):
@@ -656,5 +728,23 @@ def testStripSelectivelyClusteredSequentialModel(self):
656728
self.assertEqual(self._count_clustered_layers(stripped_model), 0)
657729
self.assertIsInstance(stripped_model.layers[0], layers.Dense)
658730

731+
@keras_parameterized.run_all_keras_modes
732+
def testStripClusteringAndSetOriginalWeightsBack(self):
733+
"""Verifies that we can set_weights onto the stripped model."""
734+
model = keras.Sequential([
735+
layers.Dense(10, input_shape=(5,)),
736+
layers.Dense(10),
737+
])
738+
739+
# Save original weights
740+
original_weights = model.get_weights()
741+
742+
# Cluster and strip
743+
clustered_model = cluster.cluster_weights(model, **self.params)
744+
stripped_model = cluster.strip_clustering(clustered_model)
745+
746+
# Set back original weights onto the stripped model
747+
stripped_model.set_weights(original_weights)
748+
659749
if __name__ == '__main__':
660750
test.main()

0 commit comments

Comments
 (0)