Skip to content

Commit 28b68e3

Browse files
committed
Fix for sparsity-preserving clustering
* small fix to stop centroids from drifting * unit test coverage for fix
1 parent d942a15 commit 28b68e3

File tree

3 files changed

+44
-18
lines changed

3 files changed

+44
-18
lines changed

tensorflow_model_optimization/python/core/clustering/keras/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ py_strict_test(
202202
# numpy dep1,
203203
# tensorflow dep1,
204204
"//tensorflow_model_optimization/python/core/keras:test_utils",
205+
"//tensorflow_model_optimization/python/core/clustering/keras/experimental:cluster",
205206
],
206207
)
207208

tensorflow_model_optimization/python/core/clustering/keras/cluster_distributed_test.py

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,18 @@
1414
# ==============================================================================
1515
"""Distributed clustering test."""
1616

17-
from absl.testing import parameterized
17+
import itertools
18+
import unittest
19+
1820
import numpy as np
1921
import tensorflow as tf
20-
21-
from tensorflow_model_optimization.python.core.clustering.keras import cluster
22-
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
23-
from tensorflow_model_optimization.python.core.clustering.keras import cluster_wrapper
24-
from tensorflow_model_optimization.python.core.keras import test_utils as keras_test_utils
22+
from absl.testing import parameterized
23+
from tensorflow_model_optimization.python.core.clustering.keras import (
24+
cluster, cluster_config, cluster_wrapper)
25+
from tensorflow_model_optimization.python.core.clustering.keras.experimental import \
26+
cluster as experimental_cluster
27+
from tensorflow_model_optimization.python.core.keras import \
28+
test_utils as keras_test_utils
2529

2630
keras = tf.keras
2731
CentroidInitialization = cluster_config.CentroidInitialization
@@ -30,23 +34,37 @@
3034
def _distribution_strategies():
3135
return [tf.distribute.MirroredStrategy()]
3236

37+
def _clustering_strategies():
38+
return [
39+
{
40+
'number_of_clusters': 2,
41+
'cluster_centroids_init': CentroidInitialization.LINEAR,
42+
'preserve_sparsity': False
43+
},
44+
{
45+
'number_of_clusters': 3,
46+
'cluster_centroids_init': CentroidInitialization.KMEANS_PLUS_PLUS,
47+
'preserve_sparsity': True
48+
}
49+
]
3350

3451
class ClusterDistributedTest(tf.test.TestCase, parameterized.TestCase):
3552
"""Distributed tests for clustering."""
3653

3754
def setUp(self):
3855
super(ClusterDistributedTest, self).setUp()
39-
self.params = {
40-
'number_of_clusters': 2,
41-
'cluster_centroids_init': CentroidInitialization.LINEAR
42-
}
4356

44-
@parameterized.parameters(_distribution_strategies())
45-
def testClusterSimpleDenseModel(self, distribution):
57+
@parameterized.parameters(
58+
*itertools.product(
59+
_distribution_strategies(),
60+
_clustering_strategies()
61+
)
62+
)
63+
def testClusterSimpleDenseModel(self, distribution, clustering):
4664
"""End-to-end test."""
4765
with distribution.scope():
48-
model = cluster.cluster_weights(
49-
keras_test_utils.build_simple_dense_model(), **self.params)
66+
model = experimental_cluster.cluster_weights(
67+
keras_test_utils.build_simple_dense_model(), **clustering)
5068
model.compile(
5169
loss='categorical_crossentropy',
5270
optimizer='sgd',
@@ -64,9 +82,11 @@ def testClusterSimpleDenseModel(self, distribution):
6482
weights_as_list = stripped_model.layers[0].kernel.numpy().reshape(
6583
-1,).tolist()
6684
unique_weights = set(weights_as_list)
67-
self.assertLessEqual(len(unique_weights), self.params['number_of_clusters'])
85+
self.assertLessEqual(len(unique_weights), clustering["number_of_clusters"])
6886

69-
@parameterized.parameters(_distribution_strategies())
87+
@parameterized.parameters(
88+
_distribution_strategies()
89+
)
7090
def testAssociationValuesPerReplica(self, distribution):
7191
"""Verifies that associations of weights are updated per replica."""
7292
assert tf.distribute.get_replica_context() is not None
@@ -76,8 +96,9 @@ def testAssociationValuesPerReplica(self, distribution):
7696
output_shape = (2, 8)
7797
l = cluster_wrapper.ClusterWeights(
7898
keras.layers.Dense(8, input_shape=input_shape),
79-
number_of_clusters=self.params['number_of_clusters'],
80-
cluster_centroids_init=self.params['cluster_centroids_init'])
99+
number_of_clusters=2,
100+
cluster_centroids_init=CentroidInitialization.LINEAR
101+
)
81102
l.build(input_shape)
82103

83104
clusterable_weights = l.layer.get_clusterable_weights()

tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,10 @@ def update_clustered_weights_associations(self):
215215
pulling_indices, original_weight))
216216

217217
if self.preserve_sparsity:
218+
# Re-discover the sparsity masks to avoid drifting
219+
self.sparsity_masks[weight_name] = (
220+
tf.cast(tf.math.not_equal(clustered_weights, 0), dtype=tf.float32)
221+
)
218222
# Apply the sparsity mask to the clustered weights
219223
clustered_weights = tf.math.multiply(clustered_weights,
220224
self.sparsity_masks[weight_name])

0 commit comments

Comments
 (0)