14
14
# ==============================================================================
15
15
"""Distributed clustering test."""
16
16
17
- from absl .testing import parameterized
17
+ import itertools
18
+ import unittest
19
+
18
20
import numpy as np
19
21
import tensorflow as tf
20
-
21
- from tensorflow_model_optimization .python .core .clustering .keras import cluster
22
- from tensorflow_model_optimization .python .core .clustering .keras import cluster_config
23
- from tensorflow_model_optimization .python .core .clustering .keras import cluster_wrapper
24
- from tensorflow_model_optimization .python .core .keras import test_utils as keras_test_utils
22
+ from absl .testing import parameterized
23
+ from tensorflow_model_optimization .python .core .clustering .keras import (
24
+ cluster , cluster_config , cluster_wrapper )
25
+ from tensorflow_model_optimization .python .core .clustering .keras .experimental import \
26
+ cluster as experimental_cluster
27
+ from tensorflow_model_optimization .python .core .keras import \
28
+ test_utils as keras_test_utils
25
29
26
30
keras = tf .keras
27
31
CentroidInitialization = cluster_config .CentroidInitialization
30
34
def _distribution_strategies ():
31
35
return [tf .distribute .MirroredStrategy ()]
32
36
37
+ def _clustering_strategies ():
38
+ return [
39
+ {
40
+ 'number_of_clusters' : 2 ,
41
+ 'cluster_centroids_init' : CentroidInitialization .LINEAR ,
42
+ 'preserve_sparsity' : False
43
+ },
44
+ {
45
+ 'number_of_clusters' : 3 ,
46
+ 'cluster_centroids_init' : CentroidInitialization .KMEANS_PLUS_PLUS ,
47
+ 'preserve_sparsity' : True
48
+ }
49
+ ]
33
50
34
51
class ClusterDistributedTest (tf .test .TestCase , parameterized .TestCase ):
35
52
"""Distributed tests for clustering."""
36
53
37
54
def setUp (self ):
38
55
super (ClusterDistributedTest , self ).setUp ()
39
- self .params = {
40
- 'number_of_clusters' : 2 ,
41
- 'cluster_centroids_init' : CentroidInitialization .LINEAR
42
- }
43
56
44
- @parameterized .parameters (_distribution_strategies ())
45
- def testClusterSimpleDenseModel (self , distribution ):
57
+ @parameterized .parameters (
58
+ * itertools .product (
59
+ _distribution_strategies (),
60
+ _clustering_strategies ()
61
+ )
62
+ )
63
+ def testClusterSimpleDenseModel (self , distribution , clustering ):
46
64
"""End-to-end test."""
47
65
with distribution .scope ():
48
- model = cluster .cluster_weights (
49
- keras_test_utils .build_simple_dense_model (), ** self . params )
66
+ model = experimental_cluster .cluster_weights (
67
+ keras_test_utils .build_simple_dense_model (), ** clustering )
50
68
model .compile (
51
69
loss = 'categorical_crossentropy' ,
52
70
optimizer = 'sgd' ,
@@ -64,9 +82,11 @@ def testClusterSimpleDenseModel(self, distribution):
64
82
weights_as_list = stripped_model .layers [0 ].kernel .numpy ().reshape (
65
83
- 1 ,).tolist ()
66
84
unique_weights = set (weights_as_list )
67
- self .assertLessEqual (len (unique_weights ), self . params [ ' number_of_clusters' ])
85
+ self .assertLessEqual (len (unique_weights ), clustering [ " number_of_clusters" ])
68
86
69
- @parameterized .parameters (_distribution_strategies ())
87
+ @parameterized .parameters (
88
+ _distribution_strategies ()
89
+ )
70
90
def testAssociationValuesPerReplica (self , distribution ):
71
91
"""Verifies that associations of weights are updated per replica."""
72
92
assert tf .distribute .get_replica_context () is not None
@@ -76,8 +96,9 @@ def testAssociationValuesPerReplica(self, distribution):
76
96
output_shape = (2 , 8 )
77
97
l = cluster_wrapper .ClusterWeights (
78
98
keras .layers .Dense (8 , input_shape = input_shape ),
79
- number_of_clusters = self .params ['number_of_clusters' ],
80
- cluster_centroids_init = self .params ['cluster_centroids_init' ])
99
+ number_of_clusters = 2 ,
100
+ cluster_centroids_init = CentroidInitialization .LINEAR
101
+ )
81
102
l .build (input_shape )
82
103
83
104
clusterable_weights = l .layer .get_clusterable_weights ()
0 commit comments