|
15 | 15 | """Tests for keras ClusterWeights wrapper API."""
|
16 | 16 |
|
17 | 17 | import itertools
|
| 18 | +import tempfile |
| 19 | +import os |
| 20 | +import tensorflow as tf |
18 | 21 |
|
19 | 22 | from absl.testing import parameterized
|
20 | 23 | import tensorflow as tf
|
@@ -230,63 +233,44 @@ def assert_all_weights_associated(weights, centroid_index):
|
230 | 233 | # Weights should now be all clustered with the centroid 1
|
231 | 234 | assert_all_weights_associated(l.layer.kernel, centroid_index=1)
|
232 | 235 |
|
233 |
| - def testClusterReassociation2(self): |
234 |
| - """Verifies that the association of weights to cluster centroids are updated every iteration.""" |
235 |
| - |
| 236 | + def testSameWeightsAreReturnedBeforeAndAfterSerialisation(self): |
| 237 | + """Verify that the weights of a cluster_wrapper are the same |
| 238 | + before and after serialisation.""" |
236 | 239 | # Create a dummy layer for this test
|
237 | 240 | input_shape = (1, 2,)
|
238 |
| - l = cluster_wrapper.ClusterWeights( |
239 |
| - keras.layers.Dense(8, input_shape=input_shape), |
240 |
| - number_of_clusters=2, |
241 |
| - cluster_centroids_init=CentroidInitialization.LINEAR |
| 241 | + original_layer = cluster_wrapper.ClusterWeights( |
| 242 | + keras.layers.Dense(8, input_shape=input_shape), |
| 243 | + number_of_clusters=2, |
| 244 | + cluster_centroids_init=CentroidInitialization.LINEAR |
242 | 245 | )
|
243 | 246 | # Build a layer with the given shape
|
244 |
| - l.build(input_shape) |
245 |
| - |
246 |
| - # Get name of the clusterable weights |
247 |
| - clusterable_weights = l.layer.get_clusterable_weights() |
248 |
| - self.assertLen(clusterable_weights, 1) |
249 |
| - weights_name = clusterable_weights[0][0] |
250 |
| - self.assertEqual(weights_name, 'kernel') |
251 |
| - # Get cluster centroids |
252 |
| - centroids = l.cluster_centroids_tf[weights_name] |
253 |
| - |
254 |
| - # Calculate some statistics of the weights to set the centroids later on |
255 |
| - mean_weight = tf.reduce_mean(l.layer.kernel) |
256 |
| - min_weight = tf.reduce_min(l.layer.kernel) |
257 |
| - max_weight = tf.reduce_max(l.layer.kernel) |
258 |
| - max_dist = max_weight - min_weight |
259 |
| - |
260 |
| - def assert_all_weights_associated(weights, centroid_index): |
261 |
| - """Helper function to make sure that all weights are associated with one centroid.""" |
262 |
| - all_associated = tf.reduce_all( |
263 |
| - tf.equal( |
264 |
| - weights, |
265 |
| - tf.constant(centroids[centroid_index], shape=weights.shape) |
266 |
| - ) |
267 |
| - ) |
268 |
| - self.assertTrue(all_associated) |
269 |
| - |
270 |
| - # Set centroids so that all weights should be re-associated with centroid 0 |
271 |
| - centroids[0].assign(mean_weight) |
272 |
| - centroids[1].assign(mean_weight + 2.0 * max_dist) |
273 |
| - |
274 |
| - # Update associations of weights to centroids |
275 |
| - l.call(tf.ones(shape=input_shape)) |
276 |
| - |
277 |
| - # Weights should now be all clustered with the centroid 0 |
278 |
| - assert_all_weights_associated(l.layer.kernel, centroid_index=0) |
279 |
| - |
280 |
| - # Set centroids so that all weights should be re-associated with centroid 1 |
281 |
| - centroids[0].assign(mean_weight - 2.0 * max_dist) |
282 |
| - centroids[1].assign(mean_weight) |
283 |
| - |
284 |
| - # Update associations of weights to centroids |
285 |
| - l.call(tf.ones(shape=input_shape)) |
286 |
| - |
287 |
| - # Weights should now be all clustered with the centroid 1 |
288 |
| - assert_all_weights_associated(l.layer.kernel, centroid_index=1) |
289 |
| - |
| 247 | + original_layer.build(input_shape) |
| 248 | + |
| 249 | + # Save and load the layer in a temp directory |
| 250 | + with tempfile.TemporaryDirectory() as tmp_dir_name: |
| 251 | + keras_file = os.path.join(tmp_dir_name, 'keras_model') |
| 252 | + keras.models.save_model(original_layer, keras_file) |
| 253 | + with cluster.cluster_scope(): |
| 254 | + loaded_layer = keras.models.load_model(keras_file) |
| 255 | + |
| 256 | + def assertListOfVariablesAllEqual(l1, l2): |
| 257 | + assert len(l1) == len(l2), \ |
| 258 | + "lists l1 and l2 are not equal: \n l1={l1} \n l2={l2}".format( |
| 259 | + l1=[v.name for v in l1], |
| 260 | + l2=[v.name for v in l2]) |
| 261 | + |
| 262 | + name_to_var_from_l1 = {var.name: var for var in l1} |
| 263 | + for var2 in l2: |
| 264 | + arr1 = name_to_var_from_l1[var2.name].numpy() |
| 265 | + arr2 = var2.numpy() |
| 266 | + self.assertAllEqual(arr1, arr2) |
| 267 | + |
| 268 | + # Check that trainable_weights and non_trainable_weights are the same |
| 269 | + # in the original layer and loaded layer |
| 270 | + assertListOfVariablesAllEqual(original_layer.trainable_weights, |
| 271 | + loaded_layer.trainable_weights) |
| 272 | + assertListOfVariablesAllEqual(original_layer.non_trainable_weights, |
| 273 | + loaded_layer.non_trainable_weights) |
290 | 274 |
|
291 | 275 | if __name__ == '__main__':
|
292 | 276 | test.main()
|
0 commit comments