Skip to content

Commit 0f6dd5a

Browse files
Merge pull request #531 from Ruomei:toupstream/distribution_cluster_indices
PiperOrigin-RevId: 333337676
2 parents 270435d + 76f2d4e commit 0f6dd5a

File tree

5 files changed

+216
-2
lines changed

5 files changed

+216
-2
lines changed

tensorflow_model_optimization/python/core/clustering/keras/BUILD

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,3 +150,15 @@ py_test(
150150
"//tensorflow_model_optimization/python/core/keras:compat",
151151
],
152152
)
153+
154+
# Test target for cluster_distributed_test.py (the distributed clustering test).
# NOTE(review): cluster_distributed_test.py also imports cluster_config,
# cluster_wrapper, numpy and absl.testing.parameterized; confirm those are
# reachable through the deps listed below (directly or transitively).
py_test(
155+
name = "cluster_distributed_test",
156+
srcs = ["cluster_distributed_test.py"],
157+
python_version = "PY3",
158+
visibility = ["//visibility:public"],
159+
deps = [
160+
":cluster",
161+
# tensorflow dep1,
162+
"//tensorflow_model_optimization/python/core/keras:test_utils",
163+
],
164+
)

tensorflow_model_optimization/python/core/clustering/keras/cluster.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def cluster_scope():
4848
"""
4949
return CustomObjectScope(
5050
{
51-
'ClusterWeights' : cluster_wrapper.ClusterWeights
51+
'ClusterWeights': cluster_wrapper.ClusterWeights
5252
}
5353
)
5454

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Distributed clustering test."""
16+
17+
from absl.testing import parameterized
18+
import numpy as np
19+
import tensorflow as tf
20+
21+
from tensorflow_model_optimization.python.core.keras import test_utils as keras_test_utils
22+
from tensorflow_model_optimization.python.core.clustering.keras import cluster
23+
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
24+
from tensorflow_model_optimization.python.core.clustering.keras import cluster_wrapper
25+
26+
# Short alias for the Keras API bundled with TensorFlow.
keras = tf.keras
27+
# Alias for cluster_config.CentroidInitialization; the tests below use its
# LINEAR member as the centroid initialization setting.
CentroidInitialization = cluster_config.CentroidInitialization
28+
29+
30+
def _distribution_strategies():
  """Returns the list of distribution strategies the tests are run under.

  Currently the list holds a single `tf.distribute.MirroredStrategy`.
  """
  strategies = [tf.distribute.MirroredStrategy()]
  return strategies
34+
35+
36+
class ClusterDistributedTest(tf.test.TestCase, parameterized.TestCase):
  """Distributed tests for clustering.

  Every test is parameterized over the strategies returned by
  `_distribution_strategies()`.
  """

  def setUp(self):
    """Stores the clustering parameters shared by all tests."""
    super(ClusterDistributedTest, self).setUp()
    self.params = {
        "number_of_clusters": 2,
        "cluster_centroids_init": CentroidInitialization.LINEAR
    }

  @parameterized.parameters(_distribution_strategies())
  def testClusterSimpleDenseModel(self, distribution):
    """End-to-end test: cluster, train, predict and strip a dense model.

    Args:
      distribution: the `tf.distribute` strategy whose scope is used to build
        and compile the clustered model.
    """
    with distribution.scope():
      model = cluster.cluster_weights(
          keras_test_utils.build_simple_dense_model(), **self.params)
      model.compile(
          loss='categorical_crossentropy',
          optimizer='sgd',
          metrics=['accuracy'])

    model.summary()
    model.fit(
        np.random.rand(20, 10),
        keras.utils.to_categorical(np.random.randint(5, size=(20, 1)), 5),
        epochs=1,
        batch_size=20)
    model.predict(np.random.rand(20, 10))

    # After stripping, the first weight tensor may hold at most
    # `number_of_clusters` distinct values.
    stripped_model = cluster.strip_clustering(model)
    weights_as_list = stripped_model.get_weights()[0].reshape(-1,).tolist()
    unique_weights = set(weights_as_list)
    self.assertLessEqual(
        len(unique_weights), self.params["number_of_clusters"])

  @parameterized.parameters(_distribution_strategies())
  def testAssociationValuesPerReplica(self, distribution):
    """Verifies that associations of weights are updated per replica.

    The centroids are moved twice so that all weights must first be
    associated with centroid 0 and then with centroid 1; after each move the
    cluster indices on every replica are checked.

    Args:
      distribution: the `tf.distribute` strategy under which the clustered
        layer is built and updated.
    """
    # Use unittest assertions rather than bare `assert`, which is stripped
    # when Python runs with -O.
    self.assertIsNotNone(tf.distribute.get_replica_context())
    with distribution.scope():
      self.assertIsNone(tf.distribute.get_replica_context())
      input_shape = (1, 2)
      output_shape = (2, 8)
      clustered_layer = cluster_wrapper.ClusterWeights(
          keras.layers.Dense(8, input_shape=input_shape),
          number_of_clusters=self.params["number_of_clusters"],
          cluster_centroids_init=self.params["cluster_centroids_init"]
      )
      clustered_layer.build(input_shape)

      clusterable_weights = clustered_layer.layer.get_clusterable_weights()
      self.assertEqual(len(clusterable_weights), 1)
      weights_name = clusterable_weights[0][0]
      self.assertEqual(weights_name, 'kernel')

      # Kernel statistics used to place the centroids far enough apart that
      # every weight is closest to the same centroid.
      mean_weight = tf.reduce_mean(clustered_layer.layer.kernel)
      min_weight = tf.reduce_min(clustered_layer.layer.kernel)
      max_weight = tf.reduce_max(clustered_layer.layer.kernel)
      max_dist = max_weight - min_weight

      def assert_all_cluster_indices(per_replica, indices_val):
        """Asserts each replica's indices tensor is filled with indices_val."""
        # Guard against a vacuous pass when no replica results are returned.
        self.assertGreater(len(per_replica), 0)
        for replica_indices in per_replica:
          # Build the expected tensor from `indices_val` directly, so any
          # index value is supported (not only 0 and 1).
          expected = tf.constant(
              indices_val, dtype=replica_indices.dtype, shape=output_shape)
          all_equal = tf.reduce_all(tf.equal(replica_indices, expected))
          self.assertTrue(all_equal)

      def update_fn(v, val):
        return v.assign(val)

      def recluster_with_centroids(centroid_values):
        """Assigns new centroids, re-runs the layer, returns local indices."""
        new_centroids = tf.Variable(
            centroid_values, aggregation=tf.VariableAggregation.MEAN)
        distribution.extended.update(
            clustered_layer.cluster_centroids_tf[weights_name],
            update_fn,
            args=(new_centroids,))
        # Calling the layer refreshes the weight-to-centroid associations.
        clustered_layer.call(tf.ones(shape=input_shape))
        clst_indices = clustered_layer.pulling_indices_tf[weights_name]
        return distribution.experimental_local_results(clst_indices)

      # Centroid 0 sits at the mean and centroid 1 far above the weights, so
      # every weight should be associated with centroid 0.
      per_replica = recluster_with_centroids(
          [mean_weight, mean_weight + 2.0 * max_dist])
      assert_all_cluster_indices(per_replica, 0)

      # Centroid 1 sits at the mean and centroid 0 far below the weights, so
      # every weight should be associated with centroid 1.
      per_replica = recluster_with_centroids(
          [mean_weight - 2.0 * max_dist, mean_weight])
      assert_all_cluster_indices(per_replica, 1)
137+
# Run the tests through tf.test.main() when this module is executed directly.
if __name__ == '__main__':
138+
tf.test.main()

tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,8 @@ def build(self, input_shape):
222222
shape=pulling_indices.shape,
223223
dtype=tf.int32,
224224
trainable=False,
225+
synchronization=tf.VariableSynchronization.ON_READ,
226+
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
225227
initializer=initializers.Constant(
226228
value=k.batch_get_value([pulling_indices])[0]
227229
)
@@ -254,7 +256,7 @@ def fn():
254256
# This loop stores pairs of weight names and how to restore them
255257
for ct, weight in enumerate(self.layer.weights):
256258
name = self._weight_name(weight.name)
257-
full_name = self.layer.name + "/" + name
259+
full_name = '{}/{}'.format(self.layer.name, name)
258260
if ct in self.gone_variables:
259261
# Again, not sure if this is needed
260262
weight_name = clusterable_weights_to_variables[name]

tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper_test.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,5 +254,67 @@ def assert_all_weights_associated(weights, centroid_index):
254254
assert_all_weights_associated(l.layer.kernel, centroid_index=1)
255255

256256

257+
def testClusterReassociation(self):
  """Verifies that the association of weights to cluster centroids is
  updated every iteration.
  """
  # Dummy two-cluster dense layer, built with the given input shape.
  input_shape = (1, 2,)
  clustered_layer = cluster_wrapper.ClusterWeights(
      keras.layers.Dense(8, input_shape=input_shape),
      number_of_clusters=2,
      cluster_centroids_init=CentroidInitialization.LINEAR
  )
  clustered_layer.build(input_shape)

  # The wrapped layer is expected to expose exactly one clusterable weight,
  # named 'kernel'.
  clusterable_weights = clustered_layer.layer.get_clusterable_weights()
  self.assertEqual(len(clusterable_weights), 1)
  weights_name = clusterable_weights[0][0]
  self.assertEqual(weights_name, 'kernel')
  centroids = clustered_layer.cluster_centroids_tf[weights_name]

  # Kernel statistics used to position the centroids below.
  kernel_mean = tf.reduce_mean(clustered_layer.layer.kernel)
  kernel_min = tf.reduce_min(clustered_layer.layer.kernel)
  kernel_max = tf.reduce_max(clustered_layer.layer.kernel)
  kernel_range = kernel_max - kernel_min

  def assert_all_weights_associated(weights, centroid_index):
    """Checks that every weight equals the centroid at `centroid_index`."""
    expected = tf.constant(centroids[centroid_index], shape=weights.shape)
    self.assertTrue(tf.reduce_all(tf.equal(weights, expected)))

  # Each scenario gives the values for centroid 0 and centroid 1, followed by
  # the index of the centroid all weights should end up associated with.
  scenarios = (
      (kernel_mean, kernel_mean + 2.0 * kernel_range, 0),
      (kernel_mean - 2.0 * kernel_range, kernel_mean, 1),
  )
  for first_centroid, second_centroid, expected_index in scenarios:
    centroids[0].assign(first_centroid)
    centroids[1].assign(second_centroid)

    # Update associations of weights to centroids.
    clustered_layer.call(tf.ones(shape=input_shape))

    # Read the kernel through the layer again after the call before checking.
    assert_all_weights_associated(
        clustered_layer.layer.kernel, centroid_index=expected_index)
318+
257319
# Run the tests through test.main() when this module is executed directly.
if __name__ == '__main__':
258320
test.main()

0 commit comments

Comments
 (0)