Commit 3c8ff9f

Merge pull request #702 from TamasArm:fix_sparsity_preserve_clustering
PiperOrigin-RevId: 378078640
2 parents eb5d597 + 77d074c commit 3c8ff9f

File tree

5 files changed: +133 -72 lines changed

tensorflow_model_optimization/python/core/clustering/keras/BUILD

Lines changed: 3 additions & 0 deletions
@@ -201,6 +201,7 @@ py_strict_test(
         # absl/testing:parameterized dep1,
         # numpy dep1,
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/clustering/keras/experimental:cluster",
         "//tensorflow_model_optimization/python/core/keras:test_utils",
     ],
 )
@@ -227,6 +228,8 @@ py_strict_test(
     deps = [
         ":cluster",
         ":cluster_config",
+        # absl/testing:parameterized dep1,
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/clustering/keras/experimental:cluster",
     ],
 )

tensorflow_model_optimization/python/core/clustering/keras/cluster_distributed_test.py

Lines changed: 34 additions & 15 deletions
@@ -14,13 +14,16 @@
 # ==============================================================================
 """Distributed clustering test."""
 
+import itertools
+
 from absl.testing import parameterized
 import numpy as np
 import tensorflow as tf
 
 from tensorflow_model_optimization.python.core.clustering.keras import cluster
 from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
 from tensorflow_model_optimization.python.core.clustering.keras import cluster_wrapper
+from tensorflow_model_optimization.python.core.clustering.keras.experimental import cluster as experimental_cluster
 from tensorflow_model_optimization.python.core.keras import test_utils as keras_test_utils
 
 keras = tf.keras
@@ -31,22 +34,35 @@ def _distribution_strategies():
   return [tf.distribute.MirroredStrategy()]
 
 
+def _clustering_strategies():
+  return [
+      {
+          'number_of_clusters': 2,
+          'cluster_centroids_init': CentroidInitialization.LINEAR,
+          'preserve_sparsity': False
+      },
+      {
+          'number_of_clusters': 3,
+          'cluster_centroids_init': CentroidInitialization.KMEANS_PLUS_PLUS,
+          'preserve_sparsity': True
+      }
+  ]
+
+
 class ClusterDistributedTest(tf.test.TestCase, parameterized.TestCase):
   """Distributed tests for clustering."""
 
-  def setUp(self):
-    super(ClusterDistributedTest, self).setUp()
-    self.params = {
-        'number_of_clusters': 2,
-        'cluster_centroids_init': CentroidInitialization.LINEAR
-    }
-
-  @parameterized.parameters(_distribution_strategies())
-  def testClusterSimpleDenseModel(self, distribution):
+  @parameterized.parameters(
+      *itertools.product(
+          _distribution_strategies(),
+          _clustering_strategies()
+      )
+  )
+  def testClusterSimpleDenseModel(self, distribution, clustering):
     """End-to-end test."""
     with distribution.scope():
-      model = cluster.cluster_weights(
-          keras_test_utils.build_simple_dense_model(), **self.params)
+      model = experimental_cluster.cluster_weights(
+          keras_test_utils.build_simple_dense_model(), **clustering)
     model.compile(
         loss='categorical_crossentropy',
         optimizer='sgd',
@@ -64,9 +80,11 @@ def testClusterSimpleDenseModel(self, distribution):
     weights_as_list = stripped_model.layers[0].kernel.numpy().reshape(
         -1,).tolist()
     unique_weights = set(weights_as_list)
-    self.assertLessEqual(len(unique_weights), self.params['number_of_clusters'])
+    self.assertLessEqual(len(unique_weights), clustering['number_of_clusters'])
 
-  @parameterized.parameters(_distribution_strategies())
+  @parameterized.parameters(
+      _distribution_strategies()
+  )
   def testAssociationValuesPerReplica(self, distribution):
     """Verifies that associations of weights are updated per replica."""
     assert tf.distribute.get_replica_context() is not None
@@ -76,8 +94,9 @@ def testAssociationValuesPerReplica(self, distribution):
     output_shape = (2, 8)
     l = cluster_wrapper.ClusterWeights(
         keras.layers.Dense(8, input_shape=input_shape),
-        number_of_clusters=self.params['number_of_clusters'],
-        cluster_centroids_init=self.params['cluster_centroids_init'])
+        number_of_clusters=2,
+        cluster_centroids_init=CentroidInitialization.LINEAR
+    )
     l.build(input_shape)
 
     clusterable_weights = l.layer.get_clusterable_weights()
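
Note: the updated test above feeds @parameterized.parameters with *itertools.product(...), so every distribution strategy is combined with every clustering configuration and each pair becomes its own test case. A minimal, self-contained sketch of the same pattern, with hypothetical helper and test names that are not part of this commit:

import itertools

from absl.testing import parameterized


def _letters():
  return ['a', 'b']


def _numbers():
  return [1, 2, 3]


class CrossProductTest(parameterized.TestCase):

  # Expands to one test case per (letter, number) pair, six in total.
  @parameterized.parameters(*itertools.product(_letters(), _numbers()))
  def testPair(self, letter, number):
    self.assertIn(letter, ['a', 'b'])
    self.assertIn(number, [1, 2, 3])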

tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py

Lines changed: 39 additions & 43 deletions
@@ -128,6 +128,14 @@ def _verify_tflite(tflite_file, x_test):
     interpreter.invoke()
     interpreter.get_tensor(output_index)
 
+  @staticmethod
+  def _get_number_of_unique_weights(stripped_model, layer_nr, weight_name):
+    layer = stripped_model.layers[layer_nr]
+    weight = getattr(layer, weight_name)
+    weights_as_list = weight.numpy().flatten()
+    nr_of_unique_weights = len(set(weights_as_list))
+    return nr_of_unique_weights
+
   @keras_parameterized.run_all_keras_modes
   def testValuesRemainClusteredAfterTraining(self):
     """Verifies that training a clustered model does not destroy the clusters."""
@@ -152,71 +160,59 @@ def testValuesRemainClusteredAfterTraining(self):
 
   @keras_parameterized.run_all_keras_modes
   def testSparsityIsPreservedDuringTraining(self):
-    # Set a specific random seed to ensure that we get some null weights to
-    # test sparsity preservation with.
-    tf.random.set_seed(1)
+    """Set a specific random seed.
 
-    # Verifies that training a clustered model does not destroy the sparsity of
-    # the weights.
+    Ensures that we get some null weights to test sparsity preservation with.
+    """
+    tf.random.set_seed(1)
+    # Verifies that training a clustered model with null weights in it
+    # does not destroy the sparsity of the weights.
     original_model = keras.Sequential([
         layers.Dense(5, input_shape=(5,)),
-        layers.Dense(5),
+        layers.Flatten(),
     ])
-
-    # Using a mininum number of centroids to make it more likely that some
-    # weights will be zero.
+    # Reset the kernel weights to reflect potential zero drifting of
+    # the cluster centroids
+    first_layer_weights = original_model.layers[0].get_weights()
+    first_layer_weights[0][:][0:2] = 0.0
+    first_layer_weights[0][:][3] = [-0.13, -0.08, -0.05, 0.005, 0.13]
+    first_layer_weights[0][:][4] = [-0.13, -0.08, -0.05, 0.005, 0.13]
+    original_model.layers[0].set_weights(first_layer_weights)
     clustering_params = {
-        "number_of_clusters": 3,
+        "number_of_clusters": 6,
         "cluster_centroids_init": CentroidInitialization.LINEAR,
         "preserve_sparsity": True
     }
-
     clustered_model = experimental_cluster.cluster_weights(
         original_model, **clustering_params)
-
     stripped_model_before_tuning = cluster.strip_clustering(clustered_model)
-    weights_before_tuning = stripped_model_before_tuning.layers[0].kernel
-    non_zero_weight_indices_before_tuning = np.nonzero(weights_before_tuning)
-
+    nr_of_unique_weights_before = self._get_number_of_unique_weights(
+        stripped_model_before_tuning, 0, "kernel")
     clustered_model.compile(
         loss=keras.losses.categorical_crossentropy,
         optimizer="adam",
         metrics=["accuracy"],
     )
-    clustered_model.fit(x=self.dataset_generator2(), steps_per_epoch=1)
-
+    clustered_model.fit(x=self.dataset_generator(), steps_per_epoch=100)
     stripped_model_after_tuning = cluster.strip_clustering(clustered_model)
     weights_after_tuning = stripped_model_after_tuning.layers[0].kernel
-    non_zero_weight_indices_after_tuning = np.nonzero(weights_after_tuning)
-    weights_as_list_after_tuning = weights_after_tuning.numpy().reshape(
-        -1,).tolist()
-    unique_weights_after_tuning = set(weights_as_list_after_tuning)
-
+    nr_of_unique_weights_after = self._get_number_of_unique_weights(
+        stripped_model_after_tuning, 0, "kernel")
+    # Check after sparsity-aware clustering, despite zero centroid can drift,
+    # the final number of unique weights remains the same
+    self.assertLessEqual(nr_of_unique_weights_after,
+                         nr_of_unique_weights_before)
     # Check that the null weights stayed the same before and after tuning.
+    # There might be new weights that become zeros but sparsity-aware
+    # clustering preserves the original null weights in the original positions
+    # of the weight array
     self.assertTrue(
-        np.array_equal(non_zero_weight_indices_before_tuning,
-                       non_zero_weight_indices_after_tuning))
+        np.array_equal(first_layer_weights[0][:][0:2],
+                       weights_after_tuning[:][0:2]))
     # Check that the number of unique weights matches the number of clusters.
     self.assertLessEqual(
-        len(unique_weights_after_tuning), self.params["number_of_clusters"])
-
-  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
-  def testEndToEndSequential(self):
-    """Test End to End clustering - sequential model."""
-    original_model = keras.Sequential([
-        layers.Dense(5, input_shape=(5,)),
-        layers.Dense(5),
-    ])
-
-    def clusters_check(stripped_model):
-      # dense layer
-      weights_as_list = stripped_model.get_weights()[0].reshape(-1,).tolist()
-      unique_weights = set(weights_as_list)
-      self.assertLessEqual(
-          len(unique_weights), self.params["number_of_clusters"])
-
-    self.end_to_end_testing(original_model, clusters_check)
+        nr_of_unique_weights_after,
+        clustering_params["number_of_clusters"])
 
   @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
   def testEndToEndFunctional(self):
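
Note: the reworked testSparsityIsPreservedDuringTraining now checks two properties: the number of unique kernel values does not grow during fine-tuning, and the rows that were explicitly zeroed out before clustering are still zero afterwards. A rough standalone sketch of those two checks with NumPy (helper names are illustrative, not from the repo; stripped_model is assumed to be the output of cluster.strip_clustering):

import numpy as np


def unique_weight_count(stripped_model, layer_nr=0, weight_name='kernel'):
  # Number of distinct values left in one weight tensor after stripping.
  weight = getattr(stripped_model.layers[layer_nr], weight_name)
  return len(set(weight.numpy().flatten()))


def zeros_preserved(weights_before, weights_after):
  # True if every position that was exactly zero before tuning is still zero.
  before = np.asarray(weights_before)
  after = np.asarray(weights_after)
  return bool(np.all(after[before == 0.0] == 0.0))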

tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py

Lines changed: 26 additions & 5 deletions
@@ -108,6 +108,9 @@ def __init__(self,
     # Stores the pairs of weight names and their respective sparsity masks
     self.sparsity_masks = {}
 
+    # Stores the pairs of weight names and the zero centroids
+    self.zero_idx = {}
+
     # Map weight names to original clusterable weights variables
     # Those weights will still be updated during backpropagation
     self.original_clusterable_weights = {}
@@ -205,10 +208,33 @@ def build(self, input_shape):
             pulling_indices, original_weight))
         self.sparsity_masks[weight_name] = (
             tf.cast(tf.math.not_equal(clustered_weights, 0), dtype=tf.float32))
+        # If the model is pruned (which we suppose), this is approximately zero
+        self.zero_idx[weight_name] = tf.argmin(
+            tf.abs(self.cluster_centroids[weight_name]), axis=-1)
 
   def update_clustered_weights_associations(self):
     for weight_name, original_weight in self.original_clusterable_weights.items(
     ):
+
+      if self.preserve_sparsity:
+        # Set the smallest centroid to zero to force sparsity
+        # and avoid extra cluster from forming
+        zero_idx_mask = (
+            tf.cast(
+                tf.math.not_equal(
+                    self.cluster_centroids[weight_name],
+                    self.cluster_centroids[weight_name][
+                        self.zero_idx[weight_name]]),
+                dtype=tf.float32))
+        self.cluster_centroids[weight_name].assign(
+            tf.math.multiply(self.cluster_centroids[weight_name],
+                             zero_idx_mask))
+        # During training, the original zero weights can drift slightly.
+        # We want to prevent this by forcing them to stay zero at the places
+        # where they were originally zero to begin with.
+        original_weight = tf.math.multiply(original_weight,
+                                           self.sparsity_masks[weight_name])
+
       # Update pulling indices (cluster associations)
       pulling_indices = (
          self.clustering_algorithms[weight_name].get_pulling_indices(
@@ -220,11 +246,6 @@ def update_clustered_weights_associations(self):
           self.clustering_algorithms[weight_name].get_clustered_weight(
               pulling_indices, original_weight))
 
-      if self.preserve_sparsity:
-        # Apply the sparsity mask to the clustered weights
-        clustered_weights = tf.math.multiply(clustered_weights,
-                                             self.sparsity_masks[weight_name])
-
       # Replace the weights with their clustered counterparts
       self.set_weight_to_layer(weight_name, clustered_weights)
 
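
Note: the change above moves sparsity preservation from a post-hoc mask on the clustered weights into the association step itself: the centroid closest to zero (recorded as zero_idx at build time) is snapped back to exactly zero, and the original sparsity mask is applied to the weights before the pulling indices are computed, so originally-zero weights are associated with the zero centroid instead of drifting. A toy sketch of that mechanism with plain TensorFlow ops on a small tensor (variable names are illustrative only, not the wrapper's internals):

import tensorflow as tf

weights = tf.Variable([[0.0, 0.3], [-0.2, 0.0]])
centroids = tf.Variable([0.01, 0.25, -0.22])  # the "zero" centroid has drifted
sparsity_mask = tf.cast(tf.math.not_equal(weights, 0), tf.float32)

# Index of the centroid closest to zero (assumed to exist for a pruned model).
zero_idx = tf.argmin(tf.abs(centroids), axis=-1)

# Snap that centroid back to exactly zero so no extra cluster forms.
zero_idx_mask = tf.cast(
    tf.math.not_equal(centroids, centroids[zero_idx]), tf.float32)
centroids.assign(centroids * zero_idx_mask)

# Re-apply the original sparsity mask before computing cluster associations,
# so the originally-zero weights are pulled to the zero centroid.
masked_weights = weights * sparsity_mask
pulling_indices = tf.argmin(
    tf.abs(tf.expand_dims(masked_weights, -1) - centroids), axis=-1)
clustered_weights = tf.gather(centroids, pulling_indices)
print(clustered_weights.numpy())  # the zero positions stay exactly zero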

tensorflow_model_optimization/python/core/clustering/keras/mnist_clustering_test.py

Lines changed: 31 additions & 9 deletions
@@ -14,10 +14,12 @@
 # ==============================================================================
 """Tests for a simple convnet with clusterable layer on the MNIST dataset."""
 
+from absl.testing import parameterized
 import tensorflow as tf
 
 from tensorflow_model_optimization.python.core.clustering.keras import cluster
 from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
+from tensorflow_model_optimization.python.core.clustering.keras.experimental import cluster as experimental_cluster
 
 tf.random.set_seed(42)
 
@@ -63,19 +65,22 @@ def _train_model(model):
   model.fit(x_train, y_train, epochs=EPOCHS)
 
 
-def _cluster_model(model, number_of_clusters):
+def _cluster_model(model, number_of_clusters, preserve_sparsity=False):
 
   (x_train, y_train), _ = _get_dataset()
 
   clustering_params = {
       'number_of_clusters':
           number_of_clusters,
       'cluster_centroids_init':
-          cluster_config.CentroidInitialization.KMEANS_PLUS_PLUS
+          cluster_config.CentroidInitialization.KMEANS_PLUS_PLUS,
+      'preserve_sparsity':
+          preserve_sparsity,
   }
 
   # Cluster model
-  clustered_model = cluster.cluster_weights(model, **clustering_params)
+  clustered_model = experimental_cluster.cluster_weights(model,
+                                                         **clustering_params)
 
   # Use smaller learning rate for fine-tuning
   # clustered model
@@ -107,12 +112,29 @@ def _get_number_of_unique_weights(stripped_model, layer_nr, weight_name):
   return nr_of_unique_weights
 
 
-class FunctionalTest(tf.test.TestCase):
+def _deepcopy_model(model):
+  model_copy = keras.models.clone_model(model)
+  model_copy.set_weights(model.get_weights())
+  return model_copy
 
-  def testMnist(self):
-    """In this test we test that 'kernel' weights are clustered."""
+
+class FunctionalTest(tf.test.TestCase, parameterized.TestCase):
+
+  def setUp(self):
+    super(FunctionalTest, self).setUp()
     model = _build_model()
     _train_model(model)
+    self.model = model
+    self.dataset = _get_dataset()
+
+  @parameterized.parameters(
+      (False),
+      (True),
+  )
+  def testMnist(self, preserve_sparisty):
+    """In this test we test that 'kernel' weights are clustered."""
+    model = self.model
+    _, (x_test, y_test) = self.dataset
 
     # Checks that number of original weights('kernel') is greater than the
     # number of clusters
@@ -123,12 +145,12 @@ def testMnist(self):
     nr_of_bias_weights = _get_number_of_unique_weights(model, -1, 'bias')
     self.assertGreater(nr_of_bias_weights, NUMBER_OF_CLUSTERS)
 
-    _, (x_test, y_test) = _get_dataset()
-
     results_original = model.evaluate(x_test, y_test)
     self.assertGreater(results_original[1], 0.8)
 
-    clustered_model = _cluster_model(model, NUMBER_OF_CLUSTERS)
+    model_copy = _deepcopy_model(model)
+    clustered_model = _cluster_model(model_copy, NUMBER_OF_CLUSTERS,
+                                     preserve_sparisty)
 
     results = clustered_model.evaluate(x_test, y_test)
 
134156
