Skip to content

Commit 0fdaf14

Browse files
committed
Add Anti Zero-Drift functionality for Sparsity-Aware clustering
* Implemented zero-centroid initialization for all clustering methods
* Implemented sparsity masks for forward and backward propagation
* Added the preserve_sparsity class member to ClusterWeights to make sparsity preservation optional for all clustering methods
* Refactored AbstractCentroidsInitialisation to include zero-centroid initialization for all init types
* Added unit tests covering the new changes
1 parent 6b32758 commit 0fdaf14

File tree

7 files changed

+509
-68
lines changed

7 files changed

+509
-68
lines changed

tensorflow_model_optimization/python/core/clustering/keras/cluster.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def cluster_scope():
5656
def cluster_weights(to_cluster,
5757
number_of_clusters,
5858
cluster_centroids_init,
59+
preserve_sparsity=False,
5960
**kwargs):
6061
"""Modify a keras layer or model to be clustered during training.
6162
@@ -108,8 +109,19 @@ def cluster_weights(to_cluster,
108109
number_of_clusters: the number of cluster centroids to form when
109110
clustering a layer/model. For example, if number_of_clusters=8 then only
110111
8 unique values will be used in each weight array.
111-
cluster_centroids_init: `tfmot.clustering.keras.CentroidInitialization`
112-
instance that determines how the cluster centroids will be initialized.
112+
cluster_centroids_init: enum value that determines how the cluster
113+
centroids will be initialized.
114+
Can have the following values:
115+
1. RANDOM : centroids are sampled using the uniform distribution
116+
between the minimum and maximum weight values in a given layer
117+
2. DENSITY_BASED : density-based sampling. First, cumulative
118+
distribution function is built for weights, then y-axis is evenly
119+
spaced into number_of_clusters regions. After this the corresponding x
120+
values are obtained and used to initialize clusters centroids.
121+
3. LINEAR : cluster centroids are evenly spaced between the minimum
122+
and maximum values of a given weight
123+
preserve_sparsity: optional boolean value that determines whether or not
124+
sparsity preservation will be enforced during training
113125
**kwargs: Additional keyword arguments to be passed to the keras layer.
114126
Ignored when to_cluster is not a keras layer.
115127
@@ -146,6 +158,7 @@ def _add_clustering_wrapper(layer):
146158
return cluster_wrapper.ClusterWeights(layer,
147159
number_of_clusters,
148160
cluster_centroids_init,
161+
preserve_sparsity,
149162
**kwargs)
150163

151164
def _wrap_list(layers):

tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,26 @@ def setUp(self):
6060
dtype="float32",
6161
)
6262

63+
self.x_train2 = np.array(
64+
[[0.0, 1.0, 2.0, 3.0, 4.0], [2.0, 0.0, 2.0, 3.0, 4.0], [0.0, 3.0, 2.0, 3.0, 4.0],
65+
[4.0, 1.0, 2.0, 3.0, 4.0], [5.0, 1.0, 2.0, 3.0, 4.0]],
66+
dtype="float32",
67+
)
68+
69+
self.y_train2 = np.array(
70+
[[0.0, 1.0, 2.0, 3.0, 4.0], [1.0, 0.0, 2.0, 3.0, 4.0], [1.0, 0.0, 2.0, 3.0, 4.0],
71+
[0.0, 1.0, 2.0, 3.0, 4.0], [0.0, 1.0, 2.0, 3.0, 4.0]],
72+
dtype="float32",
73+
)
74+
6375
def dataset_generator(self):
6476
for x, y in zip(self.x_train, self.y_train):
6577
yield np.array([x]), np.array([y])
6678

79+
def dataset_generator2(self):
80+
for x, y in zip(self.x_train2, self.y_train2):
81+
yield np.array([x]), np.array([y])
82+
6783
def end_to_end_testing(self, original_model, clusters_check=None):
6884
"""Test End to End clustering."""
6985

@@ -128,6 +144,47 @@ def testValuesRemainClusteredAfterTraining(self):
128144
unique_weights = set(weights_as_list)
129145
self.assertLessEqual(len(unique_weights), self.params["number_of_clusters"])
130146

147+
@keras_parameterized.run_all_keras_modes
148+
def testSparsityIsPreservedDuringTraining(self):
149+
"""Verifies that training a clustered model does not destroy the sparsity of the weights."""
150+
original_model = keras.Sequential([
151+
layers.Dense(5, input_shape=(5,)),
152+
layers.Dense(5),
153+
])
154+
155+
"""Using a mininum number of centroids to make it more likely that some weights will be zero."""
156+
clustering_params = {
157+
"number_of_clusters": 3,
158+
"cluster_centroids_init": CentroidInitialization.LINEAR,
159+
"preserve_sparsity": True
160+
}
161+
162+
clustered_model = cluster.cluster_weights(original_model, **clustering_params)
163+
164+
stripped_model_before_tuning = cluster.strip_clustering(clustered_model)
165+
weights_before_tuning = stripped_model_before_tuning.get_weights()[0]
166+
non_zero_weight_indices_before_tuning = np.nonzero(weights_before_tuning)
167+
168+
clustered_model.compile(
169+
loss=keras.losses.categorical_crossentropy,
170+
optimizer="adam",
171+
metrics=["accuracy"],
172+
)
173+
clustered_model.fit(x=self.dataset_generator2(), steps_per_epoch=1)
174+
175+
stripped_model_after_tuning = cluster.strip_clustering(clustered_model)
176+
weights_after_tuning = stripped_model_after_tuning.get_weights()[0]
177+
non_zero_weight_indices_after_tuning = np.nonzero(weights_after_tuning)
178+
weights_as_list_after_tuning = weights_after_tuning.reshape(-1,).tolist()
179+
unique_weights_after_tuning = set(weights_as_list_after_tuning)
180+
181+
"""Check that the null weights stayed the same before and after tuning."""
182+
self.assertTrue(np.array_equal(non_zero_weight_indices_before_tuning,
183+
non_zero_weight_indices_after_tuning))
184+
185+
"""Check that the number of unique weights matches the number of clusters."""
186+
self.assertLessEqual(len(unique_weights_after_tuning), self.params["number_of_clusters"])
187+
131188
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
132189
def testEndToEndSequential(self):
133190
"""Test End to End clustering - sequential model."""

tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,18 @@ def testClusterKerasClusterableLayer(self):
112112

113113
self._validate_clustered_layer(self.keras_clusterable_layer, wrapped_layer)
114114

115+
@keras_parameterized.run_all_keras_modes
116+
def testClusterKerasClusterableLayerWithSparsityPreservation(self):
117+
"""
118+
Verifies that a built-in keras layer marked as clusterable is being
119+
clustered correctly when sparsity preservation is enabled.
120+
"""
121+
preserve_sparsity_params = { 'preserve_sparsity': True }
122+
params = { **self.params, **preserve_sparsity_params }
123+
wrapped_layer = cluster.cluster_weights(self.keras_clusterable_layer, **params)
124+
125+
self._validate_clustered_layer(self.keras_clusterable_layer, wrapped_layer)
126+
115127
@keras_parameterized.run_all_keras_modes
116128
def testClusterKerasNonClusterableLayer(self):
117129
"""
@@ -164,6 +176,22 @@ def testClusterCustomClusterableLayer(self):
164176
self.assertEqual([('kernel', wrapped_layer.layer.kernel)],
165177
wrapped_layer.layer.get_clusterable_weights())
166178

179+
@keras_parameterized.run_all_keras_modes
180+
def testClusterCustomClusterableLayerWithSparsityPreservation(self):
181+
"""
182+
Verifies that a custom clusterable layer is being clustered correctly
183+
when sparsity preservation is enabled.
184+
"""
185+
preserve_sparsity_params = { 'preserve_sparsity': True }
186+
params = { **self.params, **preserve_sparsity_params }
187+
wrapped_layer = cluster.cluster_weights(self.custom_clusterable_layer, **params)
188+
self.model.add(wrapped_layer)
189+
self.model.build(input_shape=(10, 1))
190+
191+
self._validate_clustered_layer(self.custom_clusterable_layer, wrapped_layer)
192+
self.assertEqual([('kernel', wrapped_layer.layer.kernel)],
193+
wrapped_layer.layer.get_clusterable_weights())
194+
167195
def testClusterCustomNonClusterableLayer(self):
168196
"""
169197
Verifies that attempting to cluster a custom non-clusterable layer raises
@@ -193,6 +221,22 @@ def testClusterSequentialModelSelectively(self):
193221
self.assertIsInstance(clustered_model.layers[0], cluster_wrapper.ClusterWeights)
194222
self.assertNotIsInstance(clustered_model.layers[1], cluster_wrapper.ClusterWeights)
195223

224+
@keras_parameterized.run_all_keras_modes
225+
def testClusterSequentialModelSelectivelyWithSparsityPreservation(self):
226+
"""
227+
Verifies that layers within a sequential model can be clustered
228+
selectively when sparsity preservation is enabled.
229+
"""
230+
preserve_sparsity_params = { 'preserve_sparsity': True }
231+
params = { **self.params, **preserve_sparsity_params }
232+
clustered_model = keras.Sequential()
233+
clustered_model.add(cluster.cluster_weights(self.keras_clusterable_layer, **params))
234+
clustered_model.add(self.keras_clusterable_layer)
235+
clustered_model.build(input_shape=(1, 10))
236+
237+
self.assertIsInstance(clustered_model.layers[0], cluster_wrapper.ClusterWeights)
238+
self.assertNotIsInstance(clustered_model.layers[1], cluster_wrapper.ClusterWeights)
239+
196240
@keras_parameterized.run_all_keras_modes
197241
def testClusterFunctionalModelSelectively(self):
198242
"""
@@ -209,6 +253,24 @@ def testClusterFunctionalModelSelectively(self):
209253
self.assertIsInstance(clustered_model.layers[2], cluster_wrapper.ClusterWeights)
210254
self.assertNotIsInstance(clustered_model.layers[3], cluster_wrapper.ClusterWeights)
211255

256+
@keras_parameterized.run_all_keras_modes
257+
def testClusterFunctionalModelSelectivelyWithSparsityPreservation(self):
258+
"""
259+
Verifies that layers within a functional model can be clustered
260+
selectively when sparsity preservation is enabled.
261+
"""
262+
preserve_sparsity_params = { 'preserve_sparsity': True }
263+
params = { **self.params, **preserve_sparsity_params }
264+
i1 = keras.Input(shape=(10,))
265+
i2 = keras.Input(shape=(10,))
266+
x1 = cluster.cluster_weights(layers.Dense(10), **params)(i1)
267+
x2 = layers.Dense(10)(i2)
268+
outputs = layers.Add()([x1, x2])
269+
clustered_model = keras.Model(inputs=[i1, i2], outputs=outputs)
270+
271+
self.assertIsInstance(clustered_model.layers[2], cluster_wrapper.ClusterWeights)
272+
self.assertNotIsInstance(clustered_model.layers[3], cluster_wrapper.ClusterWeights)
273+
212274
@keras_parameterized.run_all_keras_modes
213275
def testClusterModelValidLayersSuccessful(self):
214276
"""
@@ -227,6 +289,26 @@ def testClusterModelValidLayersSuccessful(self):
227289
for layer, clustered_layer in zip(model.layers, clustered_model.layers):
228290
self._validate_clustered_layer(layer, clustered_layer)
229291

292+
@keras_parameterized.run_all_keras_modes
293+
def testClusterModelValidLayersSuccessfulWithSparsityPreservation(self):
294+
"""
295+
Verifies that clustering a sequential model results in all clusterable
296+
layers within the model being clustered when sparsity preservation is enabled.
297+
"""
298+
preserve_sparsity_params = { 'preserve_sparsity': True }
299+
params = { **self.params, **preserve_sparsity_params }
300+
model = keras.Sequential([
301+
self.keras_clusterable_layer,
302+
self.keras_non_clusterable_layer,
303+
self.custom_clusterable_layer
304+
])
305+
clustered_model = cluster.cluster_weights(model, **params)
306+
clustered_model.build(input_shape=(1, 28, 28, 1))
307+
308+
self.assertEqual(len(model.layers), len(clustered_model.layers))
309+
for layer, clustered_layer in zip(model.layers, clustered_model.layers):
310+
self._validate_clustered_layer(layer, clustered_layer)
311+
230312
def testClusterModelUnsupportedKerasLayerRaisesError(self):
231313
"""
232314
Verifies that attempting to cluster a model that contains an unsupported

tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py

Lines changed: 44 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def __init__(self,
5757
layer,
5858
number_of_clusters,
5959
cluster_centroids_init,
60+
preserve_sparsity=False,
6061
**kwargs):
6162
if not isinstance(layer, Layer):
6263
raise ValueError(
@@ -90,9 +91,11 @@ def __init__(self,
9091
)
9192
)
9293

93-
if number_of_clusters <= 1:
94+
limit_number_of_clusters = 2 if preserve_sparsity else 1
95+
if number_of_clusters <= limit_number_of_clusters:
9496
raise ValueError(
95-
"number_of_clusters must be greater than 1. Given: {}".format(
97+
"number_of_clusters must be greater than {}. Given: {}".format(
98+
limit_number_of_clusters,
9699
number_of_clusters
97100
)
98101
)
@@ -105,6 +108,12 @@ def __init__(self,
105108
# The number of cluster centroids
106109
self.number_of_clusters = number_of_clusters
107110

111+
# Whether to apply sparsity preservation or not
112+
self.preserve_sparsity = preserve_sparsity
113+
114+
# Stores the pairs of weight names and their respective sparsity masks
115+
self.sparsity_masks = {}
116+
108117
# Stores the pairs of weight names and references to their tensors
109118
self.ori_weights_vars_tf = {}
110119

@@ -187,7 +196,7 @@ def build(self, input_shape):
187196
centroid_initializer = clustering_centroids.CentroidsInitializerFactory.\
188197
get_centroid_initializer(
189198
self.cluster_centroids_init
190-
)(weight, self.number_of_clusters)
199+
)(weight, self.number_of_clusters, self.preserve_sparsity)
191200

192201
cluster_centroids = centroid_initializer.get_cluster_centroids()
193202

@@ -229,6 +238,16 @@ def build(self, input_shape):
229238
)
230239
)
231240

241+
if self.preserve_sparsity:
242+
# Get the clustered weights
243+
clustered_weights = self.clustering_impl[weight_name].get_clustered_weight(pulling_indices)
244+
245+
# Create the sparsity mask
246+
sparsity_mask = tf.cast(tf.math.not_equal(clustered_weights, 0), dtype=tf.float32)
247+
248+
# Store the sparsity mask for training
249+
self.sparsity_masks[weight_name] = sparsity_mask
250+
232251
# We store these pairs to easily update this variables later on
233252
self.ori_weights_vars_tf[weight_name] = self.add_weight(
234253
'{}{}'.format('ori_weights_vars_tf_', weight_name),
@@ -241,13 +260,20 @@ def build(self, input_shape):
241260
)
242261

243262
# We use currying here to get an updater which can be triggered at any time
244-
# in future and it would return the latest version of clustered weights
263+
# in the future and it would return the latest version of clustered weights
245264
def get_updater(for_weight_name):
246265
def fn():
247266
# Get the clustered weights
248267
pulling_indices = self.pulling_indices_tf[for_weight_name]
249-
clustered_weights = self.clustering_impl[for_weight_name].\
250-
get_clustered_weight(pulling_indices)
268+
clustered_weights = self.clustering_impl[for_weight_name].get_clustered_weight(pulling_indices)
269+
270+
if self.preserve_sparsity:
271+
# Get the sparsity mask
272+
sparsity_mask = self.sparsity_masks[for_weight_name]
273+
274+
# Apply the sparsity mask to the clustered weights
275+
clustered_weights = tf.math.multiply(clustered_weights, sparsity_mask)
276+
251277
return clustered_weights
252278

253279
return fn
@@ -269,19 +295,17 @@ def call(self, inputs):
269295
# since they are integers and not differentiable. Gradients won't flow back
270296
# through tf.argmin
271297
# Go through all tensors and replace them with their clustered copies.
272-
for weight_name in self.ori_weights_vars_tf:
298+
for weight_name, _ in self.clustered_vars:
299+
# Get the clustered weights
273300
pulling_indices = self.pulling_indices_tf[weight_name]
301+
clustered_weights = self.clustering_impl[weight_name].get_clustered_weight(pulling_indices)
274302

275-
# Update cluster associations
276-
pulling_indices.assign(tf.dtypes.cast(
277-
self.clustering_impl[weight_name].\
278-
get_pulling_indices(self.ori_weights_vars_tf[weight_name]),
279-
pulling_indices.dtype
280-
))
303+
if self.preserve_sparsity:
304+
# Get the sparsity mask
305+
sparsity_mask = self.sparsity_masks[weight_name]
281306

282-
clustered_weights = self.clustering_impl[weight_name].\
283-
get_clustered_weight_forward(pulling_indices,\
284-
self.ori_weights_vars_tf[weight_name])
307+
# Apply the sparsity mask to the clustered weights
308+
clustered_weights = tf.math.multiply(clustered_weights, sparsity_mask)
285309

286310
# Replace the weights with their clustered counterparts
287311
setattr(self.layer, weight_name, clustered_weights)
@@ -295,7 +319,8 @@ def get_config(self):
295319
base_config = super(ClusterWeights, self).get_config()
296320
config = {
297321
'number_of_clusters': self.number_of_clusters,
298-
'cluster_centroids_init': self.cluster_centroids_init
322+
'cluster_centroids_init': self.cluster_centroids_init,
323+
'preserve_sparsity': self.preserve_sparsity
299324
}
300325
return dict(list(base_config.items()) + list(config.items()))
301326

@@ -305,9 +330,11 @@ def from_config(cls, config, custom_objects=None):
305330

306331
number_of_clusters = config.pop('number_of_clusters')
307332
cluster_centroids_init = config.pop('cluster_centroids_init')
333+
preserve_sparsity = config.pop('preserve_sparsity')
308334
config['number_of_clusters'] = number_of_clusters
309335
config['cluster_centroids_init'] = cluster_config.CentroidInitialization(
310336
cluster_centroids_init)
337+
config['preserve_sparsity'] = preserve_sparsity
311338

312339
from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top
313340
layer = deserialize_layer(config.pop('layer'),

tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper_test.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,25 @@ def testCannotBeInitializedWithNumberOfClustersLessThanTwo(
131131
cluster_centroids_init=CentroidInitialization.LINEAR
132132
)
133133

134+
@parameterized.parameters(
135+
(0),
136+
(2),
137+
(-32)
138+
)
139+
def testCannotBeInitializedWithSparsityPreservationAndNumberOfClustersLessThanThree(
140+
self, number_of_clusters):
141+
"""
142+
Verifies that ClusterWeights cannot be initialized with less than three
143+
clusters when sparsity preservation is enabled.
144+
"""
145+
with self.assertRaises(ValueError):
146+
cluster_wrapper.ClusterWeights(
147+
layers.Dense(10),
148+
number_of_clusters=number_of_clusters,
149+
cluster_centroids_init=CentroidInitialization.LINEAR,
150+
preserve_sparsity=True
151+
)
152+
134153
def testCanBeInitializedWithAlreadyClusterableLayer(self):
135154
"""
136155
Verifies that ClusterWeights can be initialized with a custom clusterable

0 commit comments

Comments
 (0)