Skip to content

Commit 24a6ac0

Browse files
committed
Add Anti Zero-Drift functionality for Sparsity-Aware clustering (experimental)
* Created new experimental API for sparsity-aware clustering * Kept the original API implementation * Moved the new feature to a new experimental package, making the original implementation private * Updated the unit tests accordingly * Created init and BUILD files for the new experimental package
1 parent 0fdaf14 commit 24a6ac0

File tree

9 files changed

+262
-10
lines changed

9 files changed

+262
-10
lines changed

tensorflow_model_optimization/python/core/api/clustering/keras/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# ==============================================================================
1515
"""Module containing clustering code built on Keras abstractions."""
1616
# pylint: disable=g-bad-import-order
17+
from tensorflow_model_optimization.python.core.clustering.keras import experimental
18+
1719
from tensorflow_model_optimization.python.core.clustering.keras.cluster import cluster_scope
1820
from tensorflow_model_optimization.python.core.clustering.keras.cluster import cluster_weights
1921
from tensorflow_model_optimization.python.core.clustering.keras.cluster import strip_clustering
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Module containing experimental clustering code built on Keras abstractions."""
16+
from tensorflow_model_optimization.python.core.clustering.keras.experimental.cluster import cluster_weights

tensorflow_model_optimization/python/core/clustering/keras/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ py_library(
1212
srcs_version = "PY3",
1313
deps = [
1414
":cluster",
15+
"//tensorflow_model_optimization/python/core/clustering/keras/experimental",
1516
],
1617
)
1718

@@ -90,6 +91,7 @@ py_test(
9091
visibility = ["//visibility:public"],
9192
deps = [
9293
":cluster",
94+
"//tensorflow_model_optimization/python/core/clustering/keras/experimental:cluster",
9395
# tensorflow dep1,
9496
],
9597
)
@@ -146,6 +148,7 @@ py_test(
146148
":cluster",
147149
# tensorflow dep1,
148150
"//tensorflow_model_optimization/python/core/keras:compat",
151+
"//tensorflow_model_optimization/python/core/clustering/keras/experimental:cluster",
149152
],
150153
)
151154

tensorflow_model_optimization/python/core/clustering/keras/cluster.py

Lines changed: 94 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,7 @@ def cluster_weights(to_cluster,
8181
```python
8282
clustering_params = {
8383
'number_of_clusters': 8,
84-
'cluster_centroids_init':
85-
CentroidInitialization.DENSITY_BASED
84+
'cluster_centroids_init': CentroidInitialization.DENSITY_BASED
8685
}
8786
8887
clustered_model = cluster_weights(original_model, **clustering_params)
@@ -93,8 +92,7 @@ def cluster_weights(to_cluster,
9392
```python
9493
clustering_params = {
9594
'number_of_clusters': 8,
96-
'cluster_centroids_init':
97-
CentroidInitialization.DENSITY_BASED
95+
'cluster_centroids_init': CentroidInitialization.DENSITY_BASED
9896
}
9997
10098
model = keras.Sequential([
@@ -128,6 +126,98 @@ def cluster_weights(to_cluster,
128126
Returns:
129127
Layer or model modified to include clustering related metadata.
130128
129+
Raises:
130+
ValueError: if the keras layer is unsupported, or the keras model contains
131+
an unsupported layer.
132+
"""
133+
return _cluster_weights(to_cluster,
134+
number_of_clusters,
135+
cluster_centroids_init,
136+
preserve_sparsity=False,
137+
**kwargs)
138+
139+
140+
def _cluster_weights(to_cluster,
141+
number_of_clusters,
142+
cluster_centroids_init,
143+
preserve_sparsity,
144+
**kwargs):
145+
"""Modify a keras layer or model to be clustered during training (private method).
146+
147+
This function wraps a keras model or layer with clustering functionality
148+
which clusters the layer's weights during training. For examples, using
149+
this with number_of_clusters equals 8 will ensure that each weight tensor has
150+
no more than 8 unique values.
151+
152+
Before passing to the clustering API, a model should already be trained and
153+
show some acceptable performance on the testing/validation sets.
154+
155+
The function accepts either a single keras layer
156+
(subclass of `keras.layers.Layer`), list of keras layers or a keras model
157+
(instance of `keras.models.Model`) and handles them appropriately.
158+
159+
If it encounters a layer it does not know how to handle, it will throw an
160+
error. While clustering an entire model, even a single unknown layer would
161+
lead to an error.
162+
163+
Cluster a model:
164+
165+
```python
166+
clustering_params = {
167+
'number_of_clusters': 8,
168+
'cluster_centroids_init': CentroidInitialization.DENSITY_BASED,
169+
'preserve_sparsity': False
170+
}
171+
172+
clustered_model = cluster_weights(original_model, **clustering_params)
173+
```
174+
175+
Cluster a layer:
176+
177+
```python
178+
clustering_params = {
179+
'number_of_clusters': 8,
180+
'cluster_centroids_init': CentroidInitialization.DENSITY_BASED,
181+
'preserve_sparsity': False
182+
}
183+
184+
model = keras.Sequential([
185+
layers.Dense(10, activation='relu', input_shape=(100,)),
186+
cluster_weights(layers.Dense(2, activation='tanh'), **clustering_params)
187+
])
188+
```
189+
190+
Cluster a layer with sparsity preservation (experimental):
191+
192+
```python
193+
clustering_params = {
194+
'number_of_clusters': 8,
195+
'cluster_centroids_init': CentroidInitialization.DENSITY_BASED,
196+
'preserve_sparsity': True
197+
}
198+
199+
model = keras.Sequential([
200+
layers.Dense(10, activation='relu', input_shape=(100,)),
201+
cluster_weights(layers.Dense(2, activation='tanh'), **clustering_params)
202+
])
203+
```
204+
205+
Arguments:
206+
to_cluster: A single keras layer, list of keras layers, or a
207+
`tf.keras.Model` instance.
208+
number_of_clusters: the number of cluster centroids to form when
209+
clustering a layer/model. For example, if number_of_clusters=8 then only
210+
8 unique values will be used in each weight array.
211+
cluster_centroids_init: `tfmot.clustering.keras.CentroidInitialization`
212+
instance that determines how the cluster centroids will be initialized.
213+
preserve_sparsity (experimental): optional boolean value that determines whether or not
214+
sparsity preservation will be enforced during training.
215+
**kwargs: Additional keyword arguments to be passed to the keras layer.
216+
Ignored when to_cluster is not a keras layer.
217+
218+
Returns:
219+
Layer or model modified to include clustering related metadata.
220+
131221
Raises:
132222
ValueError: if the keras layer is unsupported, or the keras model contains
133223
an unsupported layer.

tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
2727
from tensorflow_model_optimization.python.core.keras import compat
2828

29+
from tensorflow_model_optimization.python.core.clustering.keras.experimental import cluster as experimental_cluster
30+
2931
keras = tf.keras
3032
layers = keras.layers
3133
test = tf.test
@@ -159,7 +161,7 @@ def testSparsityIsPreservedDuringTraining(self):
159161
"preserve_sparsity": True
160162
}
161163

162-
clustered_model = cluster.cluster_weights(original_model, **clustering_params)
164+
clustered_model = experimental_cluster.cluster_weights(original_model, **clustering_params)
163165

164166
stripped_model_before_tuning = cluster.strip_clustering(clustered_model)
165167
weights_before_tuning = stripped_model_before_tuning.get_weights()[0]

tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
from tensorflow_model_optimization.python.core.clustering.keras import clusterable_layer
2727
from tensorflow_model_optimization.python.core.clustering.keras import clustering_registry
2828

29+
from tensorflow_model_optimization.python.core.clustering.keras.experimental import cluster as experimental_cluster
30+
2931
keras = tf.keras
3032
errors_impl = tf.errors
3133
layers = keras.layers
@@ -120,7 +122,7 @@ def testClusterKerasClusterableLayerWithSparsityPreservation(self):
120122
"""
121123
preserve_sparsity_params = { 'preserve_sparsity': True }
122124
params = { **self.params, **preserve_sparsity_params }
123-
wrapped_layer = cluster.cluster_weights(self.keras_clusterable_layer, **params)
125+
wrapped_layer = experimental_cluster.cluster_weights(self.keras_clusterable_layer, **params)
124126

125127
self._validate_clustered_layer(self.keras_clusterable_layer, wrapped_layer)
126128

@@ -184,7 +186,7 @@ def testClusterCustomClusterableLayerWithSparsityPreservation(self):
184186
"""
185187
preserve_sparsity_params = { 'preserve_sparsity': True }
186188
params = { **self.params, **preserve_sparsity_params }
187-
wrapped_layer = cluster.cluster_weights(self.custom_clusterable_layer, **params)
189+
wrapped_layer = experimental_cluster.cluster_weights(self.custom_clusterable_layer, **params)
188190
self.model.add(wrapped_layer)
189191
self.model.build(input_shape=(10, 1))
190192

@@ -230,7 +232,7 @@ def testClusterSequentialModelSelectivelyWithSparsityPreservation(self):
230232
preserve_sparsity_params = { 'preserve_sparsity': True }
231233
params = { **self.params, **preserve_sparsity_params }
232234
clustered_model = keras.Sequential()
233-
clustered_model.add(cluster.cluster_weights(self.keras_clusterable_layer, **params))
235+
clustered_model.add(experimental_cluster.cluster_weights(self.keras_clusterable_layer, **params))
234236
clustered_model.add(self.keras_clusterable_layer)
235237
clustered_model.build(input_shape=(1, 10))
236238

@@ -263,7 +265,7 @@ def testClusterFunctionalModelSelectivelyWithSparsityPreservation(self):
263265
params = { **self.params, **preserve_sparsity_params }
264266
i1 = keras.Input(shape=(10,))
265267
i2 = keras.Input(shape=(10,))
266-
x1 = cluster.cluster_weights(layers.Dense(10), **params)(i1)
268+
x1 = experimental_cluster.cluster_weights(layers.Dense(10), **params)(i1)
267269
x2 = layers.Dense(10)(i2)
268270
outputs = layers.Add()([x1, x2])
269271
clustered_model = keras.Model(inputs=[i1, i2], outputs=outputs)
@@ -302,7 +304,7 @@ def testClusterModelValidLayersSuccessfulWithSparsityPreservation(self):
302304
self.keras_non_clusterable_layer,
303305
self.custom_clusterable_layer
304306
])
305-
clustered_model = cluster.cluster_weights(model, **params)
307+
clustered_model = experimental_cluster.cluster_weights(model, **params)
306308
clustered_model.build(input_shape=(1, 28, 28, 1))
307309

308310
self.assertEqual(len(model.layers), len(clustered_model.layers))
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package(default_visibility = [
2+
"//tensorflow_model_optimization:__subpackages__",
3+
])
4+
5+
licenses(["notice"]) # Apache 2.0
6+
7+
exports_files(["LICENSE"])
8+
9+
py_library(
10+
name = "experimental",
11+
srcs = [
12+
"__init__.py",
13+
],
14+
srcs_version = "PY3",
15+
deps = [
16+
":cluster",
17+
],
18+
)
19+
20+
py_library(
21+
name = "cluster",
22+
srcs = ["cluster.py"],
23+
srcs_version = "PY3",
24+
visibility = ["//visibility:public"],
25+
deps = [
26+
"//tensorflow_model_optimization/python/core/clustering/keras:cluster",
27+
],
28+
)

tensorflow_model_optimization/python/core/clustering/keras/experimental/__init__.py

Whitespace-only changes.
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Experimental clustering API functions for Keras models."""
16+
17+
from tensorflow_model_optimization.python.core.clustering.keras.cluster import _cluster_weights
18+
19+
20+
def cluster_weights(to_cluster,
21+
number_of_clusters,
22+
cluster_centroids_init,
23+
preserve_sparsity,
24+
**kwargs):
25+
"""Modify a keras layer or model to be clustered during training (experimental).
26+
27+
This function wraps a keras model or layer with clustering functionality
28+
which clusters the layer's weights during training. For examples, using
29+
this with number_of_clusters equals 8 will ensure that each weight tensor has
30+
no more than 8 unique values.
31+
32+
Before passing to the clustering API, a model should already be trained and
33+
show some acceptable performance on the testing/validation sets.
34+
35+
The function accepts either a single keras layer
36+
(subclass of `keras.layers.Layer`), list of keras layers or a keras model
37+
(instance of `keras.models.Model`) and handles them appropriately.
38+
39+
If it encounters a layer it does not know how to handle, it will throw an
40+
error. While clustering an entire model, even a single unknown layer would
41+
lead to an error.
42+
43+
Cluster a model:
44+
45+
```python
46+
clustering_params = {
47+
'number_of_clusters': 8,
48+
'cluster_centroids_init': CentroidInitialization.DENSITY_BASED,
49+
'preserve_sparsity': False
50+
}
51+
52+
clustered_model = cluster_weights(original_model, **clustering_params)
53+
```
54+
55+
Cluster a layer:
56+
57+
```python
58+
clustering_params = {
59+
'number_of_clusters': 8,
60+
'cluster_centroids_init': CentroidInitialization.DENSITY_BASED,
61+
'preserve_sparsity': False
62+
}
63+
64+
model = keras.Sequential([
65+
layers.Dense(10, activation='relu', input_shape=(100,)),
66+
cluster_weights(layers.Dense(2, activation='tanh'), **clustering_params)
67+
])
68+
```
69+
70+
Cluster a layer with sparsity preservation:
71+
72+
```python
73+
clustering_params = {
74+
'number_of_clusters': 8,
75+
'cluster_centroids_init': CentroidInitialization.DENSITY_BASED,
76+
'preserve_sparsity': True
77+
}
78+
79+
model = keras.Sequential([
80+
layers.Dense(10, activation='relu', input_shape=(100,)),
81+
cluster_weights(layers.Dense(2, activation='tanh'), **clustering_params)
82+
])
83+
```
84+
85+
Arguments:
86+
to_cluster: A single keras layer, list of keras layers, or a
87+
`tf.keras.Model` instance.
88+
number_of_clusters: the number of cluster centroids to form when
89+
clustering a layer/model. For example, if number_of_clusters=8 then only
90+
8 unique values will be used in each weight array.
91+
cluster_centroids_init: `tfmot.clustering.keras.CentroidInitialization`
92+
instance that determines how the cluster centroids will be initialized.
93+
preserve_sparsity: optional boolean value that determines whether or not
94+
sparsity preservation will be enforced during training.
95+
**kwargs: Additional keyword arguments to be passed to the keras layer.
96+
Ignored when to_cluster is not a keras layer.
97+
98+
Returns:
99+
Layer or model modified to include clustering related metadata.
100+
101+
Raises:
102+
ValueError: if the keras layer is unsupported, or the keras model contains
103+
an unsupported layer.
104+
"""
105+
return _cluster_weights(to_cluster,
106+
number_of_clusters,
107+
cluster_centroids_init,
108+
preserve_sparsity,
109+
**kwargs)

0 commit comments

Comments
 (0)