Skip to content

Commit e60f936

Browse files
committed
Fix for sparsity-preserving clustering
* modified unit tests to check the number of unique weights more extensively after sparsity-preserving clustering
1 parent 28b68e3 commit e60f936

File tree

1 file changed

+37
-43
lines changed

1 file changed

+37
-43
lines changed

tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py

Lines changed: 37 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,14 @@ def _verify_tflite(tflite_file, x_test):
128128
interpreter.invoke()
129129
interpreter.get_tensor(output_index)
130130

131+
@staticmethod
132+
def _get_number_of_unique_weights(stripped_model, layer_nr, weight_name):
133+
layer = stripped_model.layers[layer_nr]
134+
weight = getattr(layer, weight_name)
135+
weights_as_list = weight.numpy().flatten()
136+
nr_of_unique_weights = len(set(weights_as_list))
137+
return nr_of_unique_weights
138+
131139
@keras_parameterized.run_all_keras_modes
132140
def testValuesRemainClusteredAfterTraining(self):
133141
"""Verifies that training a clustered model does not destroy the clusters."""
@@ -150,73 +158,59 @@ def testValuesRemainClusteredAfterTraining(self):
150158
unique_weights = set(weights_as_list)
151159
self.assertLessEqual(len(unique_weights), self.params["number_of_clusters"])
152160

161+
153162
@keras_parameterized.run_all_keras_modes
154163
def testSparsityIsPreservedDuringTraining(self):
155-
# Set a specific random seed to ensure that we get some null weights to
156-
# test sparsity preservation with.
164+
"""Set a specific random seed to ensure that we get some null weights
165+
to test sparsity preservation with."""
157166
tf.random.set_seed(1)
158-
159-
# Verifies that training a clustered model does not destroy the sparsity of
160-
# the weights.
167+
# Verifies that training a clustered model with null weights in it
168+
# does not destroy the sparsity of the weights.
161169
original_model = keras.Sequential([
162170
layers.Dense(5, input_shape=(5,)),
163-
layers.Dense(5),
171+
layers.Flatten(),
164172
])
165-
166-
# Using a minimum number of centroids to make it more likely that some
167-
# weights will be zero.
173+
# Reset the kernel weights to reflect potential zero drifting of
174+
# the cluster centroids
175+
first_layer_weights = original_model.layers[0].get_weights()
176+
first_layer_weights[0][:][0:2] = 0.0
177+
first_layer_weights[0][:][3] = [-0.13, -0.08, -0.05, 0.005, 0.13]
178+
first_layer_weights[0][:][4] = [-0.13, -0.08, -0.05, 0.005, 0.13]
179+
original_model.layers[0].set_weights(first_layer_weights)
168180
clustering_params = {
169-
"number_of_clusters": 3,
181+
"number_of_clusters": 6,
170182
"cluster_centroids_init": CentroidInitialization.LINEAR,
171183
"preserve_sparsity": True
172184
}
173-
174185
clustered_model = experimental_cluster.cluster_weights(
175186
original_model, **clustering_params)
176-
177187
stripped_model_before_tuning = cluster.strip_clustering(clustered_model)
178-
weights_before_tuning = stripped_model_before_tuning.layers[0].kernel
179-
non_zero_weight_indices_before_tuning = np.nonzero(weights_before_tuning)
180-
188+
nr_of_unique_weights_before = self._get_number_of_unique_weights(
189+
stripped_model_before_tuning, 0, 'kernel')
181190
clustered_model.compile(
182191
loss=keras.losses.categorical_crossentropy,
183192
optimizer="adam",
184193
metrics=["accuracy"],
185194
)
186-
clustered_model.fit(x=self.dataset_generator2(), steps_per_epoch=1)
187-
195+
clustered_model.fit(x=self.dataset_generator(), steps_per_epoch=100)
188196
stripped_model_after_tuning = cluster.strip_clustering(clustered_model)
189197
weights_after_tuning = stripped_model_after_tuning.layers[0].kernel
190-
non_zero_weight_indices_after_tuning = np.nonzero(weights_after_tuning)
191-
weights_as_list_after_tuning = weights_after_tuning.numpy().reshape(
192-
-1,).tolist()
193-
unique_weights_after_tuning = set(weights_as_list_after_tuning)
194-
198+
nr_of_unique_weights_after = self._get_number_of_unique_weights(
199+
stripped_model_after_tuning, 0, 'kernel')
200+
# Check that after sparsity-aware clustering, even though the zero centroid can drift,
201+
# the final number of unique weights remains the same
202+
self.assertEqual(nr_of_unique_weights_before, nr_of_unique_weights_after)
195203
# Check that the null weights stayed the same before and after tuning.
204+
# There might be new weights that become zero, but sparsity-aware
205+
# clustering preserves the original null weights in the original positions
206+
# of the weight array
196207
self.assertTrue(
197-
np.array_equal(non_zero_weight_indices_before_tuning,
198-
non_zero_weight_indices_after_tuning))
199-
208+
np.array_equal(first_layer_weights[0][:][0:2],
209+
weights_after_tuning[:][0:2]))
200210
# Check that the number of unique weights matches the number of clusters.
201211
self.assertLessEqual(
202-
len(unique_weights_after_tuning), self.params["number_of_clusters"])
203-
204-
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
205-
def testEndToEndSequential(self):
206-
"""Test End to End clustering - sequential model."""
207-
original_model = keras.Sequential([
208-
layers.Dense(5, input_shape=(5,)),
209-
layers.Dense(5),
210-
])
211-
212-
def clusters_check(stripped_model):
213-
# dense layer
214-
weights_as_list = stripped_model.get_weights()[0].reshape(-1,).tolist()
215-
unique_weights = set(weights_as_list)
216-
self.assertLessEqual(
217-
len(unique_weights), self.params["number_of_clusters"])
218-
219-
self.end_to_end_testing(original_model, clusters_check)
212+
nr_of_unique_weights_after,
213+
clustering_params["number_of_clusters"])
220214

221215
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
222216
def testEndToEndFunctional(self):

0 commit comments

Comments
 (0)