|
15 | 15 | """Tests for keras ClusterWeights wrapper API.""" |
16 | 16 |
|
17 | 17 | import itertools |
| 18 | +import os |
| 19 | +import tempfile |
18 | 20 |
|
19 | 21 | from absl.testing import parameterized |
20 | 22 | import tensorflow as tf |
@@ -230,63 +232,46 @@ def assert_all_weights_associated(weights, centroid_index): |
230 | 232 | # Weights should now be all clustered with the centroid 1 |
231 | 233 | assert_all_weights_associated(l.layer.kernel, centroid_index=1) |
232 | 234 |
|
233 | | - def testClusterReassociation2(self): |
234 | | - """Verifies that the association of weights to cluster centroids are updated every iteration.""" |
235 | | - |
| 235 | + def testSameWeightsAreReturnedBeforeAndAfterSerialisation(self): |
| 236 | + """Verify weights of cluster_wrapper are the same after serialisation.""" |
236 | 237 | # Create a dummy layer for this test |
237 | 238 | input_shape = (1, 2,) |
238 | | - l = cluster_wrapper.ClusterWeights( |
| 239 | + original_layer = cluster_wrapper.ClusterWeights( |
239 | 240 | keras.layers.Dense(8, input_shape=input_shape), |
240 | 241 | number_of_clusters=2, |
241 | 242 | cluster_centroids_init=CentroidInitialization.LINEAR |
242 | 243 | ) |
243 | 244 | # Build a layer with the given shape |
244 | | - l.build(input_shape) |
245 | | - |
246 | | - # Get name of the clusterable weights |
247 | | - clusterable_weights = l.layer.get_clusterable_weights() |
248 | | - self.assertLen(clusterable_weights, 1) |
249 | | - weights_name = clusterable_weights[0][0] |
250 | | - self.assertEqual(weights_name, 'kernel') |
251 | | - # Get cluster centroids |
252 | | - centroids = l.cluster_centroids_tf[weights_name] |
253 | | - |
254 | | - # Calculate some statistics of the weights to set the centroids later on |
255 | | - mean_weight = tf.reduce_mean(l.layer.kernel) |
256 | | - min_weight = tf.reduce_min(l.layer.kernel) |
257 | | - max_weight = tf.reduce_max(l.layer.kernel) |
258 | | - max_dist = max_weight - min_weight |
259 | | - |
260 | | - def assert_all_weights_associated(weights, centroid_index): |
261 | | - """Helper function to make sure that all weights are associated with one centroid.""" |
262 | | - all_associated = tf.reduce_all( |
263 | | - tf.equal( |
264 | | - weights, |
265 | | - tf.constant(centroids[centroid_index], shape=weights.shape) |
266 | | - ) |
267 | | - ) |
268 | | - self.assertTrue(all_associated) |
269 | | - |
270 | | - # Set centroids so that all weights should be re-associated with centroid 0 |
271 | | - centroids[0].assign(mean_weight) |
272 | | - centroids[1].assign(mean_weight + 2.0 * max_dist) |
273 | | - |
274 | | - # Update associations of weights to centroids |
275 | | - l.call(tf.ones(shape=input_shape)) |
276 | | - |
277 | | - # Weights should now be all clustered with the centroid 0 |
278 | | - assert_all_weights_associated(l.layer.kernel, centroid_index=0) |
279 | | - |
280 | | - # Set centroids so that all weights should be re-associated with centroid 1 |
281 | | - centroids[0].assign(mean_weight - 2.0 * max_dist) |
282 | | - centroids[1].assign(mean_weight) |
283 | | - |
284 | | - # Update associations of weights to centroids |
285 | | - l.call(tf.ones(shape=input_shape)) |
286 | | - |
287 | | - # Weights should now be all clustered with the centroid 1 |
288 | | - assert_all_weights_associated(l.layer.kernel, centroid_index=1) |
289 | | - |
| 245 | + original_layer.build(input_shape) |
| 246 | + model = keras.Sequential([original_layer]) |
| 247 | + |
| 248 | + # Save and load the layer in a temp directory |
| 249 | + with tempfile.TemporaryDirectory() as tmp_dir_name: |
| 250 | + keras_file = os.path.join(tmp_dir_name, 'keras_model') |
| 251 | + keras.models.save_model(model, keras_file) |
| 252 | + with cluster.cluster_scope(): |
| 253 | + loaded_layer = keras.models.load_model(keras_file).layers[0] |
| 254 | + |
| 255 | + def assert_list_of_variables_all_equal(l1, l2): |
| 256 | + self.assertLen( |
| 257 | + l1, len(l2), |
| 258 | + 'lists l1 and l2 are not equal: \n l1={l1} \n l2={l2}'.format( |
| 259 | + l1=[v.name for v in l1], |
| 260 | + l2=[v.name for v in l2])) |
| 261 | + |
| 262 | + name_to_var_from_l1 = {var.name: var for var in l1} |
| 263 | + for var2 in l2: |
| 264 | + self.assertIn(var2.name, name_to_var_from_l1) |
| 265 | + arr1 = name_to_var_from_l1[var2.name].numpy() |
| 266 | + arr2 = var2.numpy() |
| 267 | + self.assertAllEqual(arr1, arr2) |
| 268 | + |
| 269 | + # Check that trainable_weights and non_trainable_weights are the same |
| 270 | + # in the original layer and loaded layer |
| 271 | + assert_list_of_variables_all_equal(original_layer.trainable_weights, |
| 272 | + loaded_layer.trainable_weights) |
| 273 | + assert_list_of_variables_all_equal(original_layer.non_trainable_weights, |
| 274 | + loaded_layer.non_trainable_weights) |
290 | 275 |
|
291 | 276 | if __name__ == '__main__': |
292 | 277 | test.main() |
0 commit comments