Skip to content

Commit 4f9e194

Browse files
committed
Adding a unit test to check all needed weights are returned by cluster_wrapper.
Change-Id: I60b0c901b4d7b699cc44848481518df09d36a97a
1 parent e4a5200 commit 4f9e194

File tree

1 file changed

+37
-53
lines changed

1 file changed

+37
-53
lines changed

tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper_test.py

Lines changed: 37 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
"""Tests for keras ClusterWeights wrapper API."""
1616

1717
import itertools
18+
import tempfile
19+
import os
20+
import tensorflow as tf
1821

1922
from absl.testing import parameterized
2023
import tensorflow as tf
@@ -230,63 +233,44 @@ def assert_all_weights_associated(weights, centroid_index):
230233
# Weights should now be all clustered with the centroid 1
231234
assert_all_weights_associated(l.layer.kernel, centroid_index=1)
232235

233-
def testClusterReassociation2(self):
234-
"""Verifies that the association of weights to cluster centroids are updated every iteration."""
235-
236+
def testSameWeightsAreReturnedBeforeAndAfterSerialisation(self):
237+
"""Verify that the weights of a cluster_wrapper are the same
238+
before and after serialisation."""
236239
# Create a dummy layer for this test
237240
input_shape = (1, 2,)
238-
l = cluster_wrapper.ClusterWeights(
239-
keras.layers.Dense(8, input_shape=input_shape),
240-
number_of_clusters=2,
241-
cluster_centroids_init=CentroidInitialization.LINEAR
241+
original_layer = cluster_wrapper.ClusterWeights(
242+
keras.layers.Dense(8, input_shape=input_shape),
243+
number_of_clusters=2,
244+
cluster_centroids_init=CentroidInitialization.LINEAR
242245
)
243246
# Build a layer with the given shape
244-
l.build(input_shape)
245-
246-
# Get name of the clusterable weights
247-
clusterable_weights = l.layer.get_clusterable_weights()
248-
self.assertLen(clusterable_weights, 1)
249-
weights_name = clusterable_weights[0][0]
250-
self.assertEqual(weights_name, 'kernel')
251-
# Get cluster centroids
252-
centroids = l.cluster_centroids_tf[weights_name]
253-
254-
# Calculate some statistics of the weights to set the centroids later on
255-
mean_weight = tf.reduce_mean(l.layer.kernel)
256-
min_weight = tf.reduce_min(l.layer.kernel)
257-
max_weight = tf.reduce_max(l.layer.kernel)
258-
max_dist = max_weight - min_weight
259-
260-
def assert_all_weights_associated(weights, centroid_index):
261-
"""Helper function to make sure that all weights are associated with one centroid."""
262-
all_associated = tf.reduce_all(
263-
tf.equal(
264-
weights,
265-
tf.constant(centroids[centroid_index], shape=weights.shape)
266-
)
267-
)
268-
self.assertTrue(all_associated)
269-
270-
# Set centroids so that all weights should be re-associated with centroid 0
271-
centroids[0].assign(mean_weight)
272-
centroids[1].assign(mean_weight + 2.0 * max_dist)
273-
274-
# Update associations of weights to centroids
275-
l.call(tf.ones(shape=input_shape))
276-
277-
# Weights should now be all clustered with the centroid 0
278-
assert_all_weights_associated(l.layer.kernel, centroid_index=0)
279-
280-
# Set centroids so that all weights should be re-associated with centroid 1
281-
centroids[0].assign(mean_weight - 2.0 * max_dist)
282-
centroids[1].assign(mean_weight)
283-
284-
# Update associations of weights to centroids
285-
l.call(tf.ones(shape=input_shape))
286-
287-
# Weights should now be all clustered with the centroid 1
288-
assert_all_weights_associated(l.layer.kernel, centroid_index=1)
289-
247+
original_layer.build(input_shape)
248+
249+
# Save and load the layer in a temp directory
250+
with tempfile.TemporaryDirectory() as tmp_dir_name:
251+
keras_file = os.path.join(tmp_dir_name, 'keras_model')
252+
keras.models.save_model(original_layer, keras_file)
253+
with cluster.cluster_scope():
254+
loaded_layer = keras.models.load_model(keras_file)
255+
256+
def assertListOfVariablesAllEqual(l1, l2):
257+
assert len(l1) == len(l2), \
258+
"lists l1 and l2 are not equal: \n l1={l1} \n l2={l2}".format(
259+
l1=[v.name for v in l1],
260+
l2=[v.name for v in l2])
261+
262+
name_to_var_from_l1 = {var.name: var for var in l1}
263+
for var2 in l2:
264+
arr1 = name_to_var_from_l1[var2.name].numpy()
265+
arr2 = var2.numpy()
266+
self.assertAllEqual(arr1, arr2)
267+
268+
# Check that trainable_weights and non_trainable_weights are the same
269+
# in the original layer and loaded layer
270+
assertListOfVariablesAllEqual(original_layer.trainable_weights,
271+
loaded_layer.trainable_weights)
272+
assertListOfVariablesAllEqual(original_layer.non_trainable_weights,
273+
loaded_layer.non_trainable_weights)
290274

291275
if __name__ == '__main__':
292276
test.main()

0 commit comments

Comments
 (0)