Skip to content

Commit a12e20e

Browse files
committed
Add Anti Zero-Drift functionality for Sparsity-Aware clustering (experimental)
* Set the random seed in the sparsity preservation test to a fixed value to make sure that some of the weights are zero
1 parent 864de3c commit a12e20e

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,9 @@ def testValuesRemainClusteredAfterTraining(self):
148148

149149
@keras_parameterized.run_all_keras_modes
150150
def testSparsityIsPreservedDuringTraining(self):
151+
""" Set a specific random seed to ensure that we get some null weights to test sparsity preservation with. """
152+
tf.random.set_seed(1)
153+
151154
"""Verifies that training a clustered model does not destroy the sparsity of the weights."""
152155
original_model = keras.Sequential([
153156
layers.Dense(5, input_shape=(5,)),

tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,8 @@ def get_updater(for_weight_name):
265265
def fn():
266266
# Get the clustered weights
267267
pulling_indices = self.pulling_indices_tf[for_weight_name]
268-
clustered_weights = self.clustering_impl[for_weight_name].get_clustered_weight(pulling_indices)
268+
clustered_weights = self.clustering_impl[for_weight_name].\
269+
get_clustered_weight(pulling_indices)
269270

270271
if self.preserve_sparsity:
271272
# Get the sparsity mask
@@ -295,10 +296,20 @@ def call(self, inputs):
295296
# since they are integers and not differentiable. Gradients won't flow back
296297
# through tf.argmin
297298
# Go through all tensors and replace them with their clustered copies.
298-
for weight_name, _ in self.clustered_vars:
299-
# Get the clustered weights
299+
for weight_name in self.ori_weights_vars_tf:
300300
pulling_indices = self.pulling_indices_tf[weight_name]
301-
clustered_weights = self.clustering_impl[weight_name].get_clustered_weight(pulling_indices)
301+
302+
# Update cluster associations
303+
pulling_indices.assign(tf.dtypes.cast(
304+
self.clustering_impl[weight_name].\
305+
get_pulling_indices(self.ori_weights_vars_tf[weight_name]),
306+
pulling_indices.dtype
307+
))
308+
309+
# Get the clustered weights
310+
clustered_weights = self.clustering_impl[weight_name].\
311+
get_clustered_weight_forward(pulling_indices,\
312+
self.ori_weights_vars_tf[weight_name])
302313

303314
if self.preserve_sparsity:
304315
# Get the sparsity mask

0 commit comments

Comments
 (0)