
Commit 45b15be

Changed the strategy to enforce a static sparsity mask in response to review comments; extended the MNIST test with additional checks.
1 parent e60f936 commit 45b15be

File tree

4 files changed (+53, -19 lines)


tensorflow_model_optimization/python/core/clustering/keras/BUILD

Lines changed: 1 addition & 0 deletions
@@ -229,5 +229,6 @@ py_strict_test(
         ":cluster",
         ":cluster_config",
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/clustering/keras/experimental",
     ],
 )

tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py

Lines changed: 1 addition & 1 deletion
@@ -199,7 +199,7 @@ def testSparsityIsPreservedDuringTraining(self):
         stripped_model_after_tuning, 0, 'kernel')
     # Check after sparsity-aware clustering, despite zero centroid can drift,
     # the final number of unique weights remains the same
-    self.assertEqual(nr_of_unique_weights_before, nr_of_unique_weights_after)
+    self.assertLessEqual(nr_of_unique_weights_after, nr_of_unique_weights_before)
     # Check that the null weights stayed the same before and after tuning.
     # There might be new weights that become zeros but sparsity-aware
     # clustering preserves the original null weights in the original positions
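
The relaxed assertion reflects that, with a static zero centroid, previously distinct clusters can collapse into the zero cluster during fine-tuning, so the number of unique weights may shrink but should never grow, while the original zero positions must stay zero. A minimal sketch of these two checks (the NumPy helpers below are illustrative, not the test's own utilities):

import numpy as np

def count_unique(kernel):
  # Number of distinct values in a layer's kernel after clustering.
  return len(np.unique(kernel))

def null_weights_preserved(kernel_before, kernel_after):
  # Every position that was zero before fine-tuning must still be zero after;
  # new zeros may appear, but the original sparsity pattern is never lost.
  return np.all(kernel_after[kernel_before == 0] == 0)

# Expected after sparsity-aware clustering and fine-tuning:
#   count_unique(kernel_after) <= count_unique(kernel_before)
#   null_weights_preserved(kernel_before, kernel_after) is True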

tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py

Lines changed: 25 additions & 9 deletions
@@ -107,6 +107,9 @@ def __init__(self,
 
     # Stores the pairs of weight names and their respective sparsity masks
     self.sparsity_masks = {}
+    self.zero_idx = {}
+
+    # Stores the pairs of weight names and the zero centroids
 
     # Map weight names to original clusterable weights variables
     # Those weights will still be updated during backpropagation
@@ -199,10 +202,32 @@ def build(self, input_shape):
                 pulling_indices, original_weight))
         self.sparsity_masks[weight_name] = (
             tf.cast(tf.math.not_equal(clustered_weights, 0), dtype=tf.float32))
+        # If the model is pruned (which we suppose), this is approximately zero
+        self.zero_idx[weight_name] = tf.argmin(
+            tf.abs(self.cluster_centroids[weight_name]), axis=-1)
 
   def update_clustered_weights_associations(self):
     for weight_name, original_weight in self.original_clusterable_weights.items(
     ):
+
+      if self.preserve_sparsity:
+        # Set the smallest centroid to zero to force sparsity
+        # and avoid extra cluster from forming
+        zero_idx_mask = (
+            tf.cast(tf.math.not_equal(
+                self.cluster_centroids[weight_name],
+                self.cluster_centroids[weight_name][self.zero_idx[weight_name]]),
+                    dtype=tf.float32)
+        )
+        self.cluster_centroids[weight_name].assign(
+            tf.math.multiply(self.cluster_centroids[weight_name],
+                             zero_idx_mask))
+        # During training, the original zero weights can drift slightly.
+        # We want to prevent this by forcing them to stay zero at the places
+        # where they were originally zero to begin with.
+        original_weight = tf.math.multiply(original_weight,
+                                           self.sparsity_masks[weight_name])
+
       # Update pulling indices (cluster associations)
       pulling_indices = (
           self.clustering_algorithms[weight_name].get_pulling_indices(
@@ -214,15 +239,6 @@ def update_clustered_weights_associations(self):
           self.clustering_algorithms[weight_name].get_clustered_weight(
               pulling_indices, original_weight))
 
-      if self.preserve_sparsity:
-        # Re-discover the sparsity masks to avoid drifting
-        self.sparsity_masks[weight_name] = (
-            tf.cast(tf.math.not_equal(clustered_weights, 0), dtype=tf.float32)
-        )
-        # Apply the sparsity mask to the clustered weights
-        clustered_weights = tf.math.multiply(clustered_weights,
-                                             self.sparsity_masks[weight_name])
-
       # Replace the weights with their clustered counterparts
       setattr(self.layer, weight_name, clustered_weights)
 
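Taken together, the wrapper changes replace the per-step rediscovery of the sparsity mask with a static scheme: the index of the smallest-magnitude centroid is recorded once at build time, that centroid is pinned to exactly zero on every association update, and the build-time sparsity mask is applied to the original weights before they are re-associated. Below is a standalone sketch of that idea with made-up centroid and weight values; the nearest-centroid association at the end stands in for the wrapper's actual clustering algorithm.

import tensorflow as tf

# Toy per-weight state; in the wrapper these live in per-weight-name dicts.
centroids = tf.Variable([-0.7, 0.02, 0.5, 1.1])    # cluster centroids
weights = tf.constant([[0.0, 0.49], [1.08, 0.0]])  # pruned original weights
sparsity_mask = tf.cast(tf.math.not_equal(weights, 0), tf.float32)

# Recorded once: index of the centroid with the smallest magnitude
# (approximately zero for a pruned model).
zero_idx = tf.argmin(tf.abs(centroids), axis=-1)

# On every association update, pin that centroid to exactly zero so no
# extra near-zero cluster can form.
zero_idx_mask = tf.cast(
    tf.math.not_equal(centroids, centroids[zero_idx]), tf.float32)
centroids.assign(centroids * zero_idx_mask)

# Apply the static sparsity mask so originally-zero weights cannot drift.
masked_weights = weights * sparsity_mask

# Re-associate each masked weight with its nearest centroid.
pulling_indices = tf.argmin(
    tf.abs(tf.expand_dims(masked_weights, -1) - centroids), axis=-1)
clustered_weights = tf.gather(centroids, pulling_indices)
print(clustered_weights.numpy())  # originally-zero positions stay exactly zero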

tensorflow_model_optimization/python/core/clustering/keras/mnist_clustering_test.py

Lines changed: 26 additions & 9 deletions
@@ -14,10 +14,12 @@
 # ==============================================================================
 """Tests for a simple convnet with clusterable layer on the MNIST dataset."""
 
+from absl.testing import parameterized
 import tensorflow as tf
 
 from tensorflow_model_optimization.python.core.clustering.keras import cluster
 from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
+from tensorflow_model_optimization.python.core.clustering.keras.experimental import cluster as experimental_cluster
 
 tf.random.set_seed(42)
 
@@ -63,19 +65,21 @@ def _train_model(model):
   model.fit(x_train, y_train, epochs=EPOCHS)
 
 
-def _cluster_model(model, number_of_clusters):
+def _cluster_model(model, number_of_clusters, preserve_sparsity=False):
 
   (x_train, y_train), _ = _get_dataset()
 
   clustering_params = {
       'number_of_clusters':
           number_of_clusters,
       'cluster_centroids_init':
-          cluster_config.CentroidInitialization.KMEANS_PLUS_PLUS
+          cluster_config.CentroidInitialization.KMEANS_PLUS_PLUS,
+      'preserve_sparsity':
+          preserve_sparsity,
   }
 
   # Cluster model
-  clustered_model = cluster.cluster_weights(model, **clustering_params)
+  clustered_model = experimental_cluster.cluster_weights(model, **clustering_params)
 
   # Use smaller learning rate for fine-tuning
   # clustered model
@@ -106,13 +110,27 @@ def _get_number_of_unique_weights(stripped_model, layer_nr, weight_name):
 
   return nr_of_unique_weights
 
+def _deepcopy_model(model):
+  model_copy = keras.models.clone_model(model)
+  model_copy.set_weights(model.get_weights())
+  return model_copy
 
-class FunctionalTest(tf.test.TestCase):
+class FunctionalTest(tf.test.TestCase, parameterized.TestCase):
 
-  def testMnist(self):
-    """In this test we test that 'kernel' weights are clustered."""
+  def setUp(self):
     model = _build_model()
     _train_model(model)
+    self.model = model
+    self.dataset = _get_dataset()
+
+  @parameterized.parameters(
+      (False),
+      (True),
+  )
+  def testMnist(self, preserve_sparsity):
+    """In this test we test that 'kernel' weights are clustered."""
+    model = self.model
+    _, (x_test, y_test) = self.dataset
 
     # Checks that number of original weights('kernel') is greater than the
     # number of clusters
@@ -123,12 +141,11 @@ def testMnist(self):
     nr_of_bias_weights = _get_number_of_unique_weights(model, -1, 'bias')
     self.assertGreater(nr_of_bias_weights, NUMBER_OF_CLUSTERS)
 
-    _, (x_test, y_test) = _get_dataset()
-
     results_original = model.evaluate(x_test, y_test)
     self.assertGreater(results_original[1], 0.8)
 
-    clustered_model = _cluster_model(model, NUMBER_OF_CLUSTERS)
+    model_copy = _deepcopy_model(model)
+    clustered_model = _cluster_model(model_copy, NUMBER_OF_CLUSTERS, preserve_sparsity)
 
     results = clustered_model.evaluate(x_test, y_test)
 
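The extended test clusters a deep copy of the same trained baseline once per preserve_sparsity value, so both parameterized cases start from identical weights. A hedged usage sketch of that pattern outside the test harness (the helper name and cluster count of 8 are illustrative, not taken from the test):

import tensorflow as tf
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
from tensorflow_model_optimization.python.core.clustering.keras.experimental import (
    cluster as experimental_cluster)

def cluster_copy(trained_model, preserve_sparsity):
  # Deep-copy so each variant starts from the identical trained weights.
  model_copy = tf.keras.models.clone_model(trained_model)
  model_copy.set_weights(trained_model.get_weights())
  clustering_params = {
      'number_of_clusters': 8,  # illustrative value
      'cluster_centroids_init':
          cluster_config.CentroidInitialization.KMEANS_PLUS_PLUS,
      'preserve_sparsity': preserve_sparsity,
  }
  return experimental_cluster.cluster_weights(model_copy, **clustering_params)

# e.g. cluster_copy(trained_model, preserve_sparsity=False) and
#      cluster_copy(trained_model, preserve_sparsity=True)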
