1414# ==============================================================================
1515"""Distributed clustering test."""
1616
17- from absl .testing import parameterized
17+ import itertools
18+ import unittest
19+
1820import numpy as np
1921import tensorflow as tf
20-
21- from tensorflow_model_optimization .python .core .clustering .keras import cluster
22- from tensorflow_model_optimization .python .core .clustering .keras import cluster_config
23- from tensorflow_model_optimization .python .core .clustering .keras import cluster_wrapper
24- from tensorflow_model_optimization .python .core .keras import test_utils as keras_test_utils
22+ from absl .testing import parameterized
23+ from tensorflow_model_optimization .python .core .clustering .keras import (
24+ cluster , cluster_config , cluster_wrapper )
25+ from tensorflow_model_optimization .python .core .clustering .keras .experimental import \
26+ cluster as experimental_cluster
27+ from tensorflow_model_optimization .python .core .keras import \
28+ test_utils as keras_test_utils
2529
2630keras = tf .keras
2731CentroidInitialization = cluster_config .CentroidInitialization
3034def _distribution_strategies ():
3135 return [tf .distribute .MirroredStrategy ()]
3236
37+ def _clustering_strategies ():
38+ return [
39+ {
40+ 'number_of_clusters' : 2 ,
41+ 'cluster_centroids_init' : CentroidInitialization .LINEAR ,
42+ 'preserve_sparsity' : False
43+ },
44+ {
45+ 'number_of_clusters' : 3 ,
46+ 'cluster_centroids_init' : CentroidInitialization .KMEANS_PLUS_PLUS ,
47+ 'preserve_sparsity' : True
48+ }
49+ ]
3350
3451class ClusterDistributedTest (tf .test .TestCase , parameterized .TestCase ):
3552 """Distributed tests for clustering."""
3653
3754 def setUp (self ):
3855 super (ClusterDistributedTest , self ).setUp ()
39- self .params = {
40- 'number_of_clusters' : 2 ,
41- 'cluster_centroids_init' : CentroidInitialization .LINEAR
42- }
4356
44- @parameterized .parameters (_distribution_strategies ())
45- def testClusterSimpleDenseModel (self , distribution ):
57+ @parameterized .parameters (
58+ * itertools .product (
59+ _distribution_strategies (),
60+ _clustering_strategies ()
61+ )
62+ )
63+ def testClusterSimpleDenseModel (self , distribution , clustering ):
4664 """End-to-end test."""
4765 with distribution .scope ():
48- model = cluster .cluster_weights (
49- keras_test_utils .build_simple_dense_model (), ** self . params )
66+ model = experimental_cluster .cluster_weights (
67+ keras_test_utils .build_simple_dense_model (), ** clustering )
5068 model .compile (
5169 loss = 'categorical_crossentropy' ,
5270 optimizer = 'sgd' ,
@@ -64,9 +82,11 @@ def testClusterSimpleDenseModel(self, distribution):
6482 weights_as_list = stripped_model .layers [0 ].kernel .numpy ().reshape (
6583 - 1 ,).tolist ()
6684 unique_weights = set (weights_as_list )
67- self .assertLessEqual (len (unique_weights ), self . params [ ' number_of_clusters' ])
85+ self .assertLessEqual (len (unique_weights ), clustering [ " number_of_clusters" ])
6886
69- @parameterized .parameters (_distribution_strategies ())
87+ @parameterized .parameters (
88+ _distribution_strategies ()
89+ )
7090 def testAssociationValuesPerReplica (self , distribution ):
7191 """Verifies that associations of weights are updated per replica."""
7292 assert tf .distribute .get_replica_context () is not None
@@ -76,8 +96,9 @@ def testAssociationValuesPerReplica(self, distribution):
7696 output_shape = (2 , 8 )
7797 l = cluster_wrapper .ClusterWeights (
7898 keras .layers .Dense (8 , input_shape = input_shape ),
79- number_of_clusters = self .params ['number_of_clusters' ],
80- cluster_centroids_init = self .params ['cluster_centroids_init' ])
99+ number_of_clusters = 2 ,
100+ cluster_centroids_init = CentroidInitialization .LINEAR
101+ )
81102 l .build (input_shape )
82103
83104 clusterable_weights = l .layer .get_clusterable_weights ()
0 commit comments