Skip to content

Commit 92486bf

Browse files
Merge pull request #597 from johan-gras:up1
PiperOrigin-RevId: 368159207
2 parents eec0e4c + 4f9e194 commit 92486bf

File tree

1 file changed

+35
-50
lines changed

1 file changed

+35
-50
lines changed

tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper_test.py

Lines changed: 35 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
"""Tests for keras ClusterWeights wrapper API."""
1616

1717
import itertools
18+
import os
19+
import tempfile
1820

1921
from absl.testing import parameterized
2022
import tensorflow as tf
@@ -230,63 +232,46 @@ def assert_all_weights_associated(weights, centroid_index):
230232
# Weights should now be all clustered with the centroid 1
231233
assert_all_weights_associated(l.layer.kernel, centroid_index=1)
232234

233-
def testClusterReassociation2(self):
234-
"""Verifies that the association of weights to cluster centroids are updated every iteration."""
235-
235+
def testSameWeightsAreReturnedBeforeAndAfterSerialisation(self):
236+
"""Verify weights of cluster_wrapper are the same after serialisation."""
236237
# Create a dummy layer for this test
237238
input_shape = (1, 2,)
238-
l = cluster_wrapper.ClusterWeights(
239+
original_layer = cluster_wrapper.ClusterWeights(
239240
keras.layers.Dense(8, input_shape=input_shape),
240241
number_of_clusters=2,
241242
cluster_centroids_init=CentroidInitialization.LINEAR
242243
)
243244
# Build a layer with the given shape
244-
l.build(input_shape)
245-
246-
# Get name of the clusterable weights
247-
clusterable_weights = l.layer.get_clusterable_weights()
248-
self.assertLen(clusterable_weights, 1)
249-
weights_name = clusterable_weights[0][0]
250-
self.assertEqual(weights_name, 'kernel')
251-
# Get cluster centroids
252-
centroids = l.cluster_centroids_tf[weights_name]
253-
254-
# Calculate some statistics of the weights to set the centroids later on
255-
mean_weight = tf.reduce_mean(l.layer.kernel)
256-
min_weight = tf.reduce_min(l.layer.kernel)
257-
max_weight = tf.reduce_max(l.layer.kernel)
258-
max_dist = max_weight - min_weight
259-
260-
def assert_all_weights_associated(weights, centroid_index):
261-
"""Helper function to make sure that all weights are associated with one centroid."""
262-
all_associated = tf.reduce_all(
263-
tf.equal(
264-
weights,
265-
tf.constant(centroids[centroid_index], shape=weights.shape)
266-
)
267-
)
268-
self.assertTrue(all_associated)
269-
270-
# Set centroids so that all weights should be re-associated with centroid 0
271-
centroids[0].assign(mean_weight)
272-
centroids[1].assign(mean_weight + 2.0 * max_dist)
273-
274-
# Update associations of weights to centroids
275-
l.call(tf.ones(shape=input_shape))
276-
277-
# Weights should now be all clustered with the centroid 0
278-
assert_all_weights_associated(l.layer.kernel, centroid_index=0)
279-
280-
# Set centroids so that all weights should be re-associated with centroid 1
281-
centroids[0].assign(mean_weight - 2.0 * max_dist)
282-
centroids[1].assign(mean_weight)
283-
284-
# Update associations of weights to centroids
285-
l.call(tf.ones(shape=input_shape))
286-
287-
# Weights should now be all clustered with the centroid 1
288-
assert_all_weights_associated(l.layer.kernel, centroid_index=1)
289-
245+
original_layer.build(input_shape)
246+
model = keras.Sequential([original_layer])
247+
248+
# Save and load the layer in a temp directory
249+
with tempfile.TemporaryDirectory() as tmp_dir_name:
250+
keras_file = os.path.join(tmp_dir_name, 'keras_model')
251+
keras.models.save_model(model, keras_file)
252+
with cluster.cluster_scope():
253+
loaded_layer = keras.models.load_model(keras_file).layers[0]
254+
255+
def assert_list_of_variables_all_equal(l1, l2):
256+
self.assertLen(
257+
l1, len(l2),
258+
'lists l1 and l2 are not equal: \n l1={l1} \n l2={l2}'.format(
259+
l1=[v.name for v in l1],
260+
l2=[v.name for v in l2]))
261+
262+
name_to_var_from_l1 = {var.name: var for var in l1}
263+
for var2 in l2:
264+
self.assertIn(var2.name, name_to_var_from_l1)
265+
arr1 = name_to_var_from_l1[var2.name].numpy()
266+
arr2 = var2.numpy()
267+
self.assertAllEqual(arr1, arr2)
268+
269+
# Check that trainable_weights and non_trainable_weights are the same
270+
# in the original layer and loaded layer
271+
assert_list_of_variables_all_equal(original_layer.trainable_weights,
272+
loaded_layer.trainable_weights)
273+
assert_list_of_variables_all_equal(original_layer.non_trainable_weights,
274+
loaded_layer.non_trainable_weights)
290275

291276
if __name__ == '__main__':
292277
test.main()

0 commit comments

Comments
 (0)