Skip to content

Commit a396089

Browse files
Merge pull request #382 from arovir01:toupstream/new_clustering_params
PiperOrigin-RevId: 317739663
2 parents c791831 + 2e70a9e commit a396089

File tree

9 files changed

+138
-52
lines changed

9 files changed

+138
-52
lines changed

tensorflow_model_optimization/python/core/clustering/keras/BUILD

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,20 @@ py_library(
2323
srcs_version = "PY2AND3",
2424
visibility = ["//visibility:public"],
2525
deps = [
26+
":cluster_config",
2627
":cluster_wrapper",
2728
":clustering_centroids",
2829
":clustering_registry",
2930
],
3031
)
3132

33+
py_library(
34+
name = "cluster_config",
35+
srcs = ["cluster_config.py"],
36+
srcs_version = "PY2AND3",
37+
visibility = ["//visibility:public"],
38+
)
39+
3240
py_library(
3341
name = "clustering_registry",
3442
srcs = ["clustering_registry.py"],
@@ -51,13 +59,19 @@ py_library(
5159
srcs = ["clustering_centroids.py"],
5260
srcs_version = "PY2AND3",
5361
visibility = ["//visibility:public"],
62+
deps = [
63+
":cluster_config",
64+
],
5465
)
5566

5667
py_library(
5768
name = "cluster_wrapper",
5869
srcs = ["cluster_wrapper.py"],
5970
srcs_version = "PY2AND3",
6071
visibility = ["//visibility:public"],
72+
deps = [
73+
":cluster_config",
74+
],
6175
)
6276

6377
py_test(

tensorflow_model_optimization/python/core/clustering/keras/cluster.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@
1717
from tensorflow import keras
1818
from tensorflow.keras import initializers
1919

20-
from tensorflow_model_optimization.python.core.clustering.keras import clustering_centroids
20+
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
2121
from tensorflow_model_optimization.python.core.clustering.keras import cluster_wrapper
22+
from tensorflow_model_optimization.python.core.clustering.keras import clustering_centroids
2223

2324
k = keras.backend
2425
CustomObjectScope = keras.utils.CustomObjectScope
@@ -47,7 +48,7 @@ def cluster_scope():
4748
"""
4849
return CustomObjectScope(
4950
{
50-
'ClusterWeights': cluster_wrapper.ClusterWeights
51+
'ClusterWeights' : cluster_wrapper.ClusterWeights
5152
}
5253
)
5354

@@ -79,7 +80,8 @@ def cluster_weights(to_cluster,
7980
```python
8081
clustering_params = {
8182
'number_of_clusters': 8,
82-
'cluster_centroids_init': 'density-based'
83+
'cluster_centroids_init':
84+
cluster_config.CentroidInitialization.DENSITY_BASED
8385
}
8486
8587
clustered_model = cluster_weights(original_model, **clustering_params)
@@ -90,7 +92,8 @@ def cluster_weights(to_cluster,
9092
```python
9193
clustering_params = {
9294
'number_of_clusters': 8,
93-
'cluster_centroids_init': 'density-based'
95+
'cluster_centroids_init':
96+
cluster_config.CentroidInitialization.DENSITY_BASED
9497
}
9598
9699
model = keras.Sequential([
@@ -105,15 +108,16 @@ def cluster_weights(to_cluster,
105108
number_of_clusters: the number of cluster centroids to form when
106109
clustering a layer/model. For example, if number_of_clusters=8 then only
107110
8 unique values will be used in each weight array.
108-
cluster_centroids_init: how to initialize the cluster centroids.
111+
cluster_centroids_init: enum value that determines how the cluster
112+
centroids will be initialized.
109113
Can have following values:
110-
1. 'random' : centroids are sampled using the uniform distribution
114+
1. RANDOM : centroids are sampled using the uniform distribution
111115
between the minimum and maximum weight values in a given layer
112-
2. 'density-based' : density-based sampling. First, cumulative
116+
2. DENSITY_BASED : density-based sampling. First, cumulative
113117
distribution function is built for weights, then y-axis is evenly
114118
spaced into number_of_clusters regions. After this the corresponding x
115119
values are obtained and used to initialize clusters centroids.
116-
3. 'linear' : cluster centroids are evenly spaced between the minimum
120+
3. LINEAR : cluster centroids are evenly spaced between the minimum
117121
and maximum values of a given weight
118122
**kwargs: Additional keyword arguments to be passed to the keras layer.
119123
Ignored when to_cluster is not a keras layer.
@@ -127,8 +131,8 @@ def cluster_weights(to_cluster,
127131
"""
128132
if not clustering_centroids.CentroidsInitializerFactory.\
129133
init_is_supported(cluster_centroids_init):
130-
raise ValueError("cluster centroids can only be one of three values: "
131-
"random, density-based, linear")
134+
raise ValueError("Cluster centroid initialization {} not supported".\
135+
format(cluster_centroids_init))
132136

133137
def _add_clustering_wrapper(layer):
134138
if isinstance(layer, cluster_wrapper.ClusterWeights):

tensorflow_model_optimization/python/core/clustering/keras/cluster_config.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Configuration classes for clustering."""
16+
17+
import enum
18+
19+
20+
class CentroidInitialization(str, enum.Enum):
21+
LINEAR = "LINEAR"
22+
RANDOM = "RANDOM"
23+
DENSITY_BASED = "DENSITY_BASED"

tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,15 @@
1919

2020
from absl.testing import parameterized
2121
from tensorflow.python.keras import keras_parameterized
22+
2223
from tensorflow_model_optimization.python.core.clustering.keras import cluster
24+
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
2325

2426
keras = tf.keras
2527
layers = keras.layers
2628
test = tf.test
2729

30+
CentroidInitialization = cluster_config.CentroidInitialization
2831

2932
class ClusterIntegrationTest(test.TestCase, parameterized.TestCase):
3033
"""Integration tests for clustering."""
@@ -43,7 +46,7 @@ def testValuesRemainClusteredAfterTraining(self):
4346
clustered_model = cluster.cluster_weights(
4447
original_model,
4548
number_of_clusters=number_of_clusters,
46-
cluster_centroids_init='linear'
49+
cluster_centroids_init=CentroidInitialization.LINEAR
4750
)
4851

4952
clustered_model.compile(

tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from tensorflow.python.keras import keras_parameterized
2222

2323
from tensorflow_model_optimization.python.core.clustering.keras import cluster
24+
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
2425
from tensorflow_model_optimization.python.core.clustering.keras import cluster_wrapper
2526
from tensorflow_model_optimization.python.core.clustering.keras import clusterable_layer
2627
from tensorflow_model_optimization.python.core.clustering.keras import clustering_registry
@@ -76,7 +77,8 @@ def setUp(self):
7677
self.model = keras.Sequential()
7778
self.params = {
7879
'number_of_clusters': 8,
79-
'cluster_centroids_init': 'density-based'
80+
'cluster_centroids_init':
81+
cluster_config.CentroidInitialization.DENSITY_BASED
8082
}
8183

8284
def _build_clustered_layer_model(self, layer):

tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from tensorflow.keras import initializers
2020

21+
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
2122
from tensorflow_model_optimization.python.core.clustering.keras import clusterable_layer
2223
from tensorflow_model_optimization.python.core.clustering.keras import clustering_centroids
2324
from tensorflow_model_optimization.python.core.clustering.keras import clustering_registry
@@ -273,7 +274,8 @@ def from_config(cls, config, custom_objects=None):
273274
number_of_clusters = config.pop('number_of_clusters')
274275
cluster_centroids_init = config.pop('cluster_centroids_init')
275276
config['number_of_clusters'] = number_of_clusters
276-
config['cluster_centroids_init'] = cluster_centroids_init
277+
config['cluster_centroids_init'] = cluster_config.CentroidInitialization(
278+
cluster_centroids_init)
277279

278280
from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top
279281
layer = deserialize_layer(config.pop('layer'),

tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper_test.py

Lines changed: 50 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from absl.testing import parameterized
2222

2323
from tensorflow_model_optimization.python.core.clustering.keras import cluster
24+
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
2425
from tensorflow_model_optimization.python.core.clustering.keras import cluster_wrapper
2526
from tensorflow_model_optimization.python.core.clustering.keras import clusterable_layer
2627
from tensorflow_model_optimization.python.core.clustering.keras import clustering_registry
@@ -30,7 +31,7 @@
3031
layers = keras.layers
3132
test = tf.test
3233

33-
layers = keras.layers
34+
CentroidInitialization = cluster_config.CentroidInitialization
3435
ClusterRegistry = clustering_registry.ClusteringRegistry
3536
ClusteringLookupRegistry = clustering_registry.ClusteringLookupRegistry
3637

@@ -55,28 +56,34 @@ def testCannotBeInitializedWithNonLayerObject(self):
5556
not an instance of keras.layers.Layer.
5657
"""
5758
with self.assertRaises(ValueError):
58-
cluster_wrapper.ClusterWeights({
59-
'this': 'is not a Layer instance'
60-
}, number_of_clusters=13, cluster_centroids_init='linear')
59+
cluster_wrapper.ClusterWeights(
60+
{'this': 'is not a Layer instance'},
61+
number_of_clusters=13,
62+
cluster_centroids_init=CentroidInitialization.LINEAR
63+
)
6164

6265
def testCannotBeInitializedWithNonClusterableLayer(self):
6366
"""
6467
Verifies that ClusterWeights cannot be initialized with a non-clusterable
6568
custom layer.
6669
"""
6770
with self.assertRaises(ValueError):
68-
cluster_wrapper.ClusterWeights(NonClusterableLayer(10),
69-
number_of_clusters=13,
70-
cluster_centroids_init='linear')
71+
cluster_wrapper.ClusterWeights(
72+
NonClusterableLayer(10),
73+
number_of_clusters=13,
74+
cluster_centroids_init=CentroidInitialization.LINEAR
75+
)
7176

7277
def testCanBeInitializedWithClusterableLayer(self):
7378
"""
7479
Verifies that ClusterWeights can be initialized with a built-in clusterable
7580
layer.
7681
"""
77-
l = cluster_wrapper.ClusterWeights(layers.Dense(10),
78-
number_of_clusters=13,
79-
cluster_centroids_init='linear')
82+
l = cluster_wrapper.ClusterWeights(
83+
layers.Dense(10),
84+
number_of_clusters=13,
85+
cluster_centroids_init=CentroidInitialization.LINEAR
86+
)
8087
self.assertIsInstance(l, cluster_wrapper.ClusterWeights)
8188

8289
def testCannotBeInitializedWithNonIntegerNumberOfClusters(self):
@@ -85,19 +92,23 @@ def testCannotBeInitializedWithNonIntegerNumberOfClusters(self):
8592
provided for the number of clusters.
8693
"""
8794
with self.assertRaises(ValueError):
88-
cluster_wrapper.ClusterWeights(layers.Dense(10),
89-
number_of_clusters="13",
90-
cluster_centroids_init='linear')
95+
cluster_wrapper.ClusterWeights(
96+
layers.Dense(10),
97+
number_of_clusters="13",
98+
cluster_centroids_init=CentroidInitialization.LINEAR
99+
)
91100

92101
def testCannotBeInitializedWithFloatNumberOfClusters(self):
93102
"""
94103
Verifies that ClusterWeights cannot be initialized with a decimal value
95104
provided for the number of clusters.
96105
"""
97106
with self.assertRaises(ValueError):
98-
cluster_wrapper.ClusterWeights(layers.Dense(10),
99-
number_of_clusters=13.4,
100-
cluster_centroids_init='linear')
107+
cluster_wrapper.ClusterWeights(
108+
layers.Dense(10),
109+
number_of_clusters=13.4,
110+
cluster_centroids_init=CentroidInitialization.LINEAR
111+
)
101112

102113
@parameterized.parameters(
103114
(0),
@@ -111,34 +122,47 @@ def testCannotBeInitializedWithNumberOfClustersLessThanTwo(
111122
clusters.
112123
"""
113124
with self.assertRaises(ValueError):
114-
cluster_wrapper.ClusterWeights(layers.Dense(10),
115-
number_of_clusters=number_of_clusters,
116-
cluster_centroids_init='linear')
125+
cluster_wrapper.ClusterWeights(
126+
layers.Dense(10),
127+
number_of_clusters=number_of_clusters,
128+
cluster_centroids_init=CentroidInitialization.LINEAR
129+
)
117130

118131
def testCanBeInitializedWithAlreadyClusterableLayer(self):
119132
"""
120133
Verifies that ClusterWeights can be initialized with a custom clusterable
121134
layer.
122135
"""
123136
layer = AlreadyClusterableLayer(10)
124-
l = cluster_wrapper.ClusterWeights(layer,
125-
number_of_clusters=13,
126-
cluster_centroids_init='linear')
137+
l = cluster_wrapper.ClusterWeights(
138+
layer,
139+
number_of_clusters=13,
140+
cluster_centroids_init=CentroidInitialization.LINEAR
141+
)
127142
self.assertIsInstance(l, cluster_wrapper.ClusterWeights)
128143

129144
def testIfLayerHasBatchShapeClusterWeightsMustHaveIt(self):
130145
"""
131146
Verifies that the ClusterWeights instance created from a layer that has
132147
a batch shape attribute, will also have this attribute.
133148
"""
134-
l = cluster_wrapper.ClusterWeights(layers.Dense(10, input_shape=(10,)),
135-
number_of_clusters=13,
136-
cluster_centroids_init='linear')
149+
l = cluster_wrapper.ClusterWeights(
150+
layers.Dense(10, input_shape=(10,)),
151+
number_of_clusters=13,
152+
cluster_centroids_init=CentroidInitialization.LINEAR
153+
)
137154
self.assertTrue(hasattr(l, '_batch_input_shape'))
138155

139156
# Makes it easier to test all possible parameters combinations.
140157
@parameterized.parameters(
141-
*itertools.product(range(2, 16, 4), ('linear', 'random', 'density-based'))
158+
*itertools.product(
159+
range(2, 16, 4),
160+
(
161+
CentroidInitialization.LINEAR,
162+
CentroidInitialization.RANDOM,
163+
CentroidInitialization.DENSITY_BASED
164+
)
165+
)
142166
)
143167
def testValuesAreClusteredAfterStripping(self,
144168
number_of_clusters,

tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@
1818
import six
1919
import tensorflow as tf
2020

21-
k = tf.keras.backend
21+
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
2222

23+
k = tf.keras.backend
24+
CentroidInitialization = cluster_config.CentroidInitialization
2325

2426
@six.add_metaclass(abc.ABCMeta)
2527
class AbstractCentroidsInitialisation:
@@ -187,9 +189,10 @@ class CentroidsInitializerFactory:
187189
reflect new methods available.
188190
"""
189191
_initialisers = {
190-
'linear': LinearCentroidsInitialisation,
191-
'random': RandomCentroidsInitialisation,
192-
'density-based': DensityBasedCentroidsInitialisation
192+
CentroidInitialization.LINEAR : LinearCentroidsInitialisation,
193+
CentroidInitialization.RANDOM : RandomCentroidsInitialisation,
194+
CentroidInitialization.DENSITY_BASED :
195+
DensityBasedCentroidsInitialisation
193196
}
194197

195198
@classmethod
@@ -199,9 +202,11 @@ def init_is_supported(cls, init_method):
199202
@classmethod
200203
def get_centroid_initializer(cls, init_method):
201204
"""
202-
:param init_method: a string representation of the init methods requested
203-
:return: A concrete implementation of AbstractCentroidsInitialisation
204-
:raises: ValueError if the string representation is not recognised
205+
:param init_method: a CentroidInitialization value representing the init
206+
method requested
207+
:return: A concrete implementation of AbstractCentroidsInitialisation
208+
:raises: ValueError if the requested centroid initialization method is not
209+
recognised
205210
"""
206211
if not cls.init_is_supported(init_method):
207212
raise ValueError(

0 commit comments

Comments
 (0)