Skip to content

Commit e4a5200

Browse files
Merge pull request #520 from MatteoArm:feature/sparsity_aware_clustering
PiperOrigin-RevId: 351486381
2 parents ef6aa21 + a12e20e commit e4a5200

File tree

13 files changed

+921
-438
lines changed

13 files changed

+921
-438
lines changed

tensorflow_model_optimization/python/core/api/clustering/keras/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# ==============================================================================
1515
"""Module containing clustering code built on Keras abstractions."""
1616
# pylint: disable=g-bad-import-order
17+
from tensorflow_model_optimization.python.core.clustering.keras import experimental
18+
1719
from tensorflow_model_optimization.python.core.clustering.keras.cluster import cluster_scope
1820
from tensorflow_model_optimization.python.core.clustering.keras.cluster import cluster_weights
1921
from tensorflow_model_optimization.python.core.clustering.keras.cluster import strip_clustering
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Module containing experimental clustering code built on Keras abstractions."""
16+
from tensorflow_model_optimization.python.core.clustering.keras.experimental.cluster import cluster_weights

tensorflow_model_optimization/python/core/clustering/keras/BUILD

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ package(default_visibility = [
22
"//tensorflow_model_optimization:__subpackages__",
33
])
44

5-
licenses(["notice"]) # Apache 2.0
5+
licenses(["notice"])
66

77
py_library(
88
name = "keras",
@@ -12,6 +12,7 @@ py_library(
1212
srcs_version = "PY3",
1313
deps = [
1414
":cluster",
15+
"//tensorflow_model_optimization/python/core/clustering/keras/experimental",
1516
],
1617
)
1718

@@ -92,6 +93,7 @@ py_test(
9293
deps = [
9394
":cluster",
9495
# tensorflow dep1,
96+
"//tensorflow_model_optimization/python/core/clustering/keras/experimental:cluster",
9597
],
9698
)
9799

@@ -146,6 +148,7 @@ py_test(
146148
deps = [
147149
":cluster",
148150
# tensorflow dep1,
151+
"//tensorflow_model_optimization/python/core/clustering/keras/experimental:cluster",
149152
"//tensorflow_model_optimization/python/core/keras:compat",
150153
],
151154
)

tensorflow_model_optimization/python/core/clustering/keras/cluster.py

Lines changed: 118 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,7 @@
1515
"""Clustering API functions for Keras models."""
1616

1717
from tensorflow import keras
18-
from tensorflow.keras import initializers
1918

20-
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
2119
from tensorflow_model_optimization.python.core.clustering.keras import cluster_wrapper
2220
from tensorflow_model_optimization.python.core.clustering.keras import clustering_centroids
2321

@@ -57,7 +55,7 @@ def cluster_weights(to_cluster,
5755
number_of_clusters,
5856
cluster_centroids_init,
5957
**kwargs):
60-
"""Modify a keras layer or model to be clustered during training.
58+
"""Modifies a keras layer or model to be clustered during training.
6159
6260
This function wraps a keras model or layer with clustering functionality
6361
which clusters the layer's weights during training. For example, using
@@ -80,8 +78,7 @@ def cluster_weights(to_cluster,
8078
```python
8179
clustering_params = {
8280
'number_of_clusters': 8,
83-
'cluster_centroids_init':
84-
CentroidInitialization.DENSITY_BASED
81+
'cluster_centroids_init': CentroidInitialization.DENSITY_BASED
8582
}
8683
8784
clustered_model = cluster_weights(original_model, **clustering_params)
@@ -92,8 +89,104 @@ def cluster_weights(to_cluster,
9289
```python
9390
clustering_params = {
9491
'number_of_clusters': 8,
95-
'cluster_centroids_init':
96-
CentroidInitialization.DENSITY_BASED
92+
'cluster_centroids_init': CentroidInitialization.DENSITY_BASED
93+
}
94+
95+
model = keras.Sequential([
96+
layers.Dense(10, activation='relu', input_shape=(100,)),
97+
cluster_weights(layers.Dense(2, activation='tanh'), **clustering_params)
98+
])
99+
```
100+
101+
Arguments:
102+
to_cluster: A single keras layer, list of keras layers, or a
103+
`tf.keras.Model` instance.
104+
number_of_clusters: the number of cluster centroids to form when
105+
clustering a layer/model. For example, if number_of_clusters=8 then only
106+
8 unique values will be used in each weight array.
107+
cluster_centroids_init: enum value that determines how the cluster
108+
centroids will be initialized.
109+
Can have the following values:
110+
1. RANDOM : centroids are sampled using the uniform distribution
111+
between the minimum and maximum weight values in a given layer
112+
2. DENSITY_BASED : density-based sampling. First, cumulative
113+
distribution function is built for weights, then y-axis is evenly
114+
spaced into number_of_clusters regions. After this the corresponding
115+
x values are obtained and used to initialize cluster centroids.
116+
3. LINEAR : cluster centroids are evenly spaced between the minimum
117+
and maximum values of a given weight
118+
**kwargs: Additional keyword arguments to be passed to the keras layer.
119+
Ignored when to_cluster is not a keras layer.
120+
121+
Returns:
122+
Layer or model modified to include clustering related metadata.
123+
124+
Raises:
125+
ValueError: if the keras layer is unsupported, or the keras model contains
126+
an unsupported layer.
127+
"""
128+
return _cluster_weights(
129+
to_cluster,
130+
number_of_clusters,
131+
cluster_centroids_init,
132+
preserve_sparsity=False,
133+
**kwargs)
134+
135+
136+
def _cluster_weights(to_cluster, number_of_clusters, cluster_centroids_init,
137+
preserve_sparsity, **kwargs):
138+
"""Modifies a keras layer or model to be clustered during training.
139+
140+
This function wraps a keras model or layer with clustering functionality
141+
which clusters the layer's weights during training. For example, using
142+
this with number_of_clusters equal to 8 will ensure that each weight tensor has
143+
no more than 8 unique values.
144+
145+
Before passing to the clustering API, a model should already be trained and
146+
show some acceptable performance on the testing/validation sets.
147+
148+
The function accepts either a single keras layer
149+
(subclass of `keras.layers.Layer`), list of keras layers or a keras model
150+
(instance of `keras.models.Model`) and handles them appropriately.
151+
152+
If it encounters a layer it does not know how to handle, it will throw an
153+
error. While clustering an entire model, even a single unknown layer would
154+
lead to an error.
155+
156+
Cluster a model:
157+
158+
```python
159+
clustering_params = {
160+
'number_of_clusters': 8,
161+
'cluster_centroids_init': CentroidInitialization.DENSITY_BASED,
162+
'preserve_sparsity': False
163+
}
164+
165+
clustered_model = cluster_weights(original_model, **clustering_params)
166+
```
167+
168+
Cluster a layer:
169+
170+
```python
171+
clustering_params = {
172+
'number_of_clusters': 8,
173+
'cluster_centroids_init': CentroidInitialization.DENSITY_BASED,
174+
'preserve_sparsity': False
175+
}
176+
177+
model = keras.Sequential([
178+
layers.Dense(10, activation='relu', input_shape=(100,)),
179+
cluster_weights(layers.Dense(2, activation='tanh'), **clustering_params)
180+
])
181+
```
182+
183+
Cluster a layer with sparsity preservation (experimental):
184+
185+
```python
186+
clustering_params = {
187+
'number_of_clusters': 8,
188+
'cluster_centroids_init': CentroidInitialization.DENSITY_BASED,
189+
'preserve_sparsity': True
97190
}
98191
99192
model = keras.Sequential([
@@ -110,6 +203,8 @@ def cluster_weights(to_cluster,
110203
8 unique values will be used in each weight array.
111204
cluster_centroids_init: `tfmot.clustering.keras.CentroidInitialization`
112205
instance that determines how the cluster centroids will be initialized.
206+
preserve_sparsity (experimental): optional boolean value that determines
207+
whether or not sparsity preservation will be enforced during training.
113208
**kwargs: Additional keyword arguments to be passed to the keras layer.
114209
Ignored when to_cluster is not a keras layer.
115210
@@ -120,33 +215,31 @@ def cluster_weights(to_cluster,
120215
ValueError: if the keras layer is unsupported, or the keras model contains
121216
an unsupported layer.
122217
"""
123-
if not clustering_centroids.CentroidsInitializerFactory.\
124-
init_is_supported(cluster_centroids_init):
125-
raise ValueError("Cluster centroid initialization {} not supported".\
126-
format(cluster_centroids_init))
218+
if not clustering_centroids.CentroidsInitializerFactory.init_is_supported(
219+
cluster_centroids_init):
220+
raise ValueError('Cluster centroid initialization {} not supported'.format(
221+
cluster_centroids_init))
127222

128223
def _add_clustering_wrapper(layer):
129-
130-
if (isinstance(layer, keras.Model)):
224+
if isinstance(layer, keras.Model):
131225
# Check whether the model is a subclass.
132226
# NB: This check is copied from keras.py file in tensorflow.
133227
# There is no available public API to do this check.
228+
# pylint: disable=protected-access
134229
if (not layer._is_graph_network and
135230
not isinstance(layer, keras.models.Sequential)):
136-
raise ValueError("Subclassed models are not supported currently.")
231+
raise ValueError('Subclassed models are not supported currently.')
137232

138-
return keras.models.clone_model(layer,
139-
input_tensors=None,
140-
clone_function=_add_clustering_wrapper)
233+
return keras.models.clone_model(
234+
layer, input_tensors=None, clone_function=_add_clustering_wrapper)
141235
if isinstance(layer, cluster_wrapper.ClusterWeights):
142236
return layer
143237
if isinstance(layer, InputLayer):
144238
return layer.__class__.from_config(layer.get_config())
145239

146-
return cluster_wrapper.ClusterWeights(layer,
147-
number_of_clusters,
240+
return cluster_wrapper.ClusterWeights(layer, number_of_clusters,
148241
cluster_centroids_init,
149-
**kwargs)
242+
preserve_sparsity, **kwargs)
150243

151244
def _wrap_list(layers):
152245
output = []
@@ -166,7 +259,7 @@ def _wrap_list(layers):
166259

167260

168261
def strip_clustering(model):
169-
"""Strip clustering wrappers from the model.
262+
"""Strips clustering wrappers from the model.
170263
171264
Once a model has been clustered, this method can be used
172265
to restore the original model with the clustered weights.
@@ -198,16 +291,17 @@ def strip_clustering(model):
198291

199292
def _strip_clustering_wrapper(layer):
200293
if isinstance(layer, keras.Model):
201-
return keras.models.clone_model(layer,
202-
input_tensors=None,
203-
clone_function=_strip_clustering_wrapper)
294+
return keras.models.clone_model(
295+
layer, input_tensors=None, clone_function=_strip_clustering_wrapper)
204296
elif isinstance(layer, cluster_wrapper.ClusterWeights):
205297
if not hasattr(layer.layer, '_batch_input_shape') and\
206298
hasattr(layer, '_batch_input_shape'):
299+
# pylint: disable=protected-access
207300
layer.layer._batch_input_shape = layer._batch_input_shape
208301

209302
# We reset both arrays of weights, so that we can guarantee the correct
210303
# order of newly created weights
304+
# pylint: disable=protected-access
211305
layer.layer._trainable_weights = []
212306
layer.layer._non_trainable_weights = []
213307
for i in range(len(layer.restore)):

0 commit comments

Comments
 (0)