Skip to content

Commit 9381edc

Browse files
teijeongtensorflower-gardener
authored andcommitted
Fix formatting for clustering codes
* Ran automated formatting fix * Eliminated pylint errors PiperOrigin-RevId: 370812580
1 parent c59c33b commit 9381edc

File tree

11 files changed

+441
-568
lines changed

11 files changed

+441
-568
lines changed

tensorflow_model_optimization/python/core/clustering/keras/cluster.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -299,8 +299,10 @@ def _strip_clustering_wrapper(layer):
299299
layer.update_clustered_weights_associations()
300300

301301
# Construct a list of weights to initialize the clean layer
302-
updated_weights = layer.layer.get_weights() # non clusterable weights only
303-
for position_variable, weight_name in layer.position_original_weights.items():
302+
# non-clusterable weights only
303+
updated_weights = layer.layer.get_weights()
304+
for (position_variable,
305+
weight_name) in layer.position_original_weights.items():
304306
# Add the clustered weights at the correct position
305307
clustered_weight = getattr(layer.layer, weight_name)
306308
updated_weights.insert(position_variable, clustered_weight)

tensorflow_model_optimization/python/core/clustering/keras/cluster_distributed_test.py

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,17 @@
1818
import numpy as np
1919
import tensorflow as tf
2020

21-
from tensorflow_model_optimization.python.core.keras import test_utils as keras_test_utils
2221
from tensorflow_model_optimization.python.core.clustering.keras import cluster
2322
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
2423
from tensorflow_model_optimization.python.core.clustering.keras import cluster_wrapper
24+
from tensorflow_model_optimization.python.core.keras import test_utils as keras_test_utils
2525

2626
keras = tf.keras
2727
CentroidInitialization = cluster_config.CentroidInitialization
2828

2929

3030
def _distribution_strategies():
31-
return [
32-
tf.distribute.MirroredStrategy()
33-
]
31+
return [tf.distribute.MirroredStrategy()]
3432

3533

3634
class ClusterDistributedTest(tf.test.TestCase, parameterized.TestCase):
@@ -39,8 +37,8 @@ class ClusterDistributedTest(tf.test.TestCase, parameterized.TestCase):
3937
def setUp(self):
4038
super(ClusterDistributedTest, self).setUp()
4139
self.params = {
42-
"number_of_clusters": 2,
43-
"cluster_centroids_init": CentroidInitialization.LINEAR
40+
'number_of_clusters': 2,
41+
'cluster_centroids_init': CentroidInitialization.LINEAR
4442
}
4543

4644
@parameterized.parameters(_distribution_strategies())
@@ -63,9 +61,10 @@ def testClusterSimpleDenseModel(self, distribution):
6361
model.predict(np.random.rand(20, 10))
6462

6563
stripped_model = cluster.strip_clustering(model)
66-
weights_as_list = stripped_model.layers[0].kernel.numpy().reshape(-1,).tolist()
64+
weights_as_list = stripped_model.layers[0].kernel.numpy().reshape(
65+
-1,).tolist()
6766
unique_weights = set(weights_as_list)
68-
self.assertLessEqual(len(unique_weights), self.params["number_of_clusters"])
67+
self.assertLessEqual(len(unique_weights), self.params['number_of_clusters'])
6968

7069
@parameterized.parameters(_distribution_strategies())
7170
def testAssociationValuesPerReplica(self, distribution):
@@ -77,13 +76,12 @@ def testAssociationValuesPerReplica(self, distribution):
7776
output_shape = (2, 8)
7877
l = cluster_wrapper.ClusterWeights(
7978
keras.layers.Dense(8, input_shape=input_shape),
80-
number_of_clusters=self.params["number_of_clusters"],
81-
cluster_centroids_init=self.params["cluster_centroids_init"]
82-
)
79+
number_of_clusters=self.params['number_of_clusters'],
80+
cluster_centroids_init=self.params['cluster_centroids_init'])
8381
l.build(input_shape)
8482

8583
clusterable_weights = l.layer.get_clusterable_weights()
86-
self.assertEqual(len(clusterable_weights), 1)
84+
self.assertLen(clusterable_weights, 1)
8785
weights_name = clusterable_weights[0][0]
8886
self.assertEqual(weights_name, 'kernel')
8987
centroids1 = l.cluster_centroids[weights_name]
@@ -101,18 +99,14 @@ def assert_all_cluster_indices(per_replica, indices_val):
10199
val_tensor = tf.dtypes.cast(
102100
tf.zeros(shape=output_shape), per_replica[0].dtype)
103101
for i in range(0, len(per_replica)):
104-
all_equal = tf.reduce_all(
105-
tf.equal(
106-
per_replica[i], val_tensor
107-
)
108-
)
102+
all_equal = tf.reduce_all(tf.equal(per_replica[i], val_tensor))
109103
self.assertTrue(all_equal)
110104

111105
def update_fn(v, val):
112106
return v.assign(val)
113107

114-
initial_val = tf.Variable([mean_weight, mean_weight + 2.0 * max_dist], \
115-
aggregation=tf.VariableAggregation.MEAN)
108+
initial_val = tf.Variable([mean_weight, mean_weight + 2.0 * max_dist],
109+
aggregation=tf.VariableAggregation.MEAN)
116110

117111
centroids1 = distribution.extended.update(
118112
centroids1, update_fn, args=(initial_val,))
@@ -122,8 +116,8 @@ def update_fn(v, val):
122116
per_replica = distribution.experimental_local_results(clst_indices)
123117
assert_all_cluster_indices(per_replica, 0)
124118

125-
second_val = tf.Variable([mean_weight - 2.0 * max_dist, mean_weight], \
126-
aggregation=tf.VariableAggregation.MEAN)
119+
second_val = tf.Variable([mean_weight - 2.0 * max_dist, mean_weight],
120+
aggregation=tf.VariableAggregation.MEAN)
127121
centroids2 = l.cluster_centroids[weights_name]
128122
centroids2 = distribution.extended.update(
129123
centroids2, update_fn, args=(second_val,))
@@ -133,5 +127,6 @@ def update_fn(v, val):
133127
per_replica = distribution.experimental_local_results(clst_indices)
134128
assert_all_cluster_indices(per_replica, 1)
135129

130+
136131
if __name__ == '__main__':
137132
tf.test.main()

tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py

Lines changed: 38 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,8 @@ def testValuesRemainClusteredAfterTraining(self):
152152

153153
@keras_parameterized.run_all_keras_modes
154154
def testSparsityIsPreservedDuringTraining(self):
155-
"""Set a specific random seed to ensure that we get some null weights to test sparsity preservation with."""
155+
# Set a specific random seed to ensure that we get some null weights to
156+
# test sparsity preservation with.
156157
tf.random.set_seed(1)
157158

158159
# Verifies that training a clustered model does not destroy the sparsity of
@@ -187,7 +188,8 @@ def testSparsityIsPreservedDuringTraining(self):
187188
stripped_model_after_tuning = cluster.strip_clustering(clustered_model)
188189
weights_after_tuning = stripped_model_after_tuning.layers[0].kernel
189190
non_zero_weight_indices_after_tuning = np.nonzero(weights_after_tuning)
190-
weights_as_list_after_tuning = weights_after_tuning.numpy().reshape(-1,).tolist()
191+
weights_as_list_after_tuning = weights_after_tuning.numpy().reshape(
192+
-1,).tolist()
191193
unique_weights_after_tuning = set(weights_as_list_after_tuning)
192194

193195
# Check that the null weights stayed the same before and after tuning.
@@ -245,15 +247,15 @@ def testEndToEndDeepLayer(self):
245247

246248
def clusters_check(stripped_model):
247249
# inner dense layer
248-
weights_as_list = stripped_model.submodules[1].trainable_weights[0].\
249-
numpy().flatten()
250+
weights_as_list = (
251+
stripped_model.submodules[1].trainable_weights[0].numpy().flatten())
250252
unique_weights = set(weights_as_list)
251253
self.assertLessEqual(
252254
len(unique_weights), self.params["number_of_clusters"])
253255

254256
# outer dense layer
255-
weights_as_list = stripped_model.submodules[4].trainable_weights[0].\
256-
numpy().flatten()
257+
weights_as_list = (
258+
stripped_model.submodules[4].trainable_weights[0].numpy().flatten())
257259
unique_weights = set(weights_as_list)
258260
self.assertLessEqual(
259261
len(unique_weights), self.params["number_of_clusters"])
@@ -276,23 +278,22 @@ def testEndToEndDeepLayer2(self):
276278

277279
def clusters_check(stripped_model):
278280
# first inner dense layer
279-
weights_as_list = stripped_model.submodules[1].trainable_weights[0].\
280-
numpy().flatten()
281+
weights_as_list = (
282+
stripped_model.submodules[1].trainable_weights[0].numpy().flatten())
281283
unique_weights = set(weights_as_list)
282284
self.assertLessEqual(
283285
len(unique_weights), self.params["number_of_clusters"])
284286

285287
# second inner dense layer
286-
weights_as_list = stripped_model.submodules[4].\
287-
trainable_weights[0].\
288-
numpy().flatten()
288+
weights_as_list = (
289+
stripped_model.submodules[4].trainable_weights[0].numpy().flatten())
289290
unique_weights = set(weights_as_list)
290291
self.assertLessEqual(
291292
len(unique_weights), self.params["number_of_clusters"])
292293

293294
# outer dense layer
294-
weights_as_list = stripped_model.submodules[7].trainable_weights[0].\
295-
numpy().flatten()
295+
weights_as_list = (
296+
stripped_model.submodules[7].trainable_weights[0].numpy().flatten())
296297
unique_weights = set(weights_as_list)
297298
self.assertLessEqual(
298299
len(unique_weights), self.params["number_of_clusters"])
@@ -301,51 +302,46 @@ def clusters_check(stripped_model):
301302

302303
@keras_parameterized.run_all_keras_modes
303304
def testWeightsAreLearningDuringClustering(self):
304-
"""Verifies that training a clustered model does update
305-
original_weights, clustered_centroids and bias."""
306-
original_model = keras.Sequential([
307-
layers.Dense(5, input_shape=(5,))
308-
])
305+
"""Verifies that weights are updated during training a clustered model.
306+
307+
Training a clustered model should update original_weights,
308+
clustered_centroids and bias.
309+
"""
310+
original_model = keras.Sequential([layers.Dense(5, input_shape=(5,))])
309311

310312
clustered_model = cluster.cluster_weights(original_model, **self.params)
311313

312314
clustered_model.compile(
313-
loss=keras.losses.categorical_crossentropy,
314-
optimizer="adam",
315-
metrics=["accuracy"],
315+
loss=keras.losses.categorical_crossentropy,
316+
optimizer="adam",
317+
metrics=["accuracy"],
316318
)
317319

318320
class CheckWeightsCallback(keras.callbacks.Callback):
321+
319322
def on_train_batch_begin(self, batch, logs=None):
320323
# Save weights before batch
321324
self.original_weight_kernel = (
322-
self.model.layers[0].original_clusterable_weights['kernel'].numpy()
323-
)
325+
self.model.layers[0].original_clusterable_weights["kernel"].numpy())
324326
self.cluster_centroids_kernel = (
325-
self.model.layers[0].cluster_centroids['kernel'].numpy()
326-
)
327-
self.bias = (
328-
self.model.layers[0].layer.bias.numpy()
329-
)
327+
self.model.layers[0].cluster_centroids["kernel"].numpy())
328+
self.bias = (self.model.layers[0].layer.bias.numpy())
330329

331330
def on_train_batch_end(self, batch, logs=None):
332331
# Check weights are different after batch
333332
assert not np.array_equal(
334-
self.original_weight_kernel,
335-
self.model.layers[0].original_clusterable_weights['kernel'].numpy()
336-
)
333+
self.original_weight_kernel,
334+
self.model.layers[0].original_clusterable_weights["kernel"].numpy())
337335
assert not np.array_equal(
338-
self.cluster_centroids_kernel,
339-
self.model.layers[0].cluster_centroids['kernel'].numpy()
340-
)
341-
assert not np.array_equal(
342-
self.bias,
343-
self.model.layers[0].layer.bias.numpy()
344-
)
345-
346-
clustered_model.fit(x=self.dataset_generator(),
347-
steps_per_epoch=5,
348-
callbacks=[CheckWeightsCallback()])
336+
self.cluster_centroids_kernel,
337+
self.model.layers[0].cluster_centroids["kernel"].numpy())
338+
assert not np.array_equal(self.bias,
339+
self.model.layers[0].layer.bias.numpy())
340+
341+
clustered_model.fit(
342+
x=self.dataset_generator(),
343+
steps_per_epoch=5,
344+
callbacks=[CheckWeightsCallback()])
349345

350346

351347
if __name__ == "__main__":

0 commit comments

Comments
 (0)