15
15
"""Clustering API functions for Keras models."""
16
16
17
17
from tensorflow import keras
18
- from tensorflow .keras import initializers
19
18
20
- from tensorflow_model_optimization .python .core .clustering .keras import cluster_config
21
19
from tensorflow_model_optimization .python .core .clustering .keras import cluster_wrapper
22
20
from tensorflow_model_optimization .python .core .clustering .keras import clustering_centroids
23
21
@@ -57,7 +55,7 @@ def cluster_weights(to_cluster,
57
55
number_of_clusters ,
58
56
cluster_centroids_init ,
59
57
** kwargs ):
60
- """Modify a keras layer or model to be clustered during training.
58
+ """Modifies a keras layer or model to be clustered during training.
61
59
62
60
This function wraps a keras model or layer with clustering functionality
63
61
which clusters the layer's weights during training. For examples, using
@@ -80,8 +78,7 @@ def cluster_weights(to_cluster,
80
78
```python
81
79
clustering_params = {
82
80
'number_of_clusters': 8,
83
- 'cluster_centroids_init':
84
- CentroidInitialization.DENSITY_BASED
81
+ 'cluster_centroids_init': CentroidInitialization.DENSITY_BASED
85
82
}
86
83
87
84
clustered_model = cluster_weights(original_model, **clustering_params)
@@ -92,8 +89,104 @@ def cluster_weights(to_cluster,
92
89
```python
93
90
clustering_params = {
94
91
'number_of_clusters': 8,
95
- 'cluster_centroids_init':
96
- CentroidInitialization.DENSITY_BASED
92
+ 'cluster_centroids_init': CentroidInitialization.DENSITY_BASED
93
+ }
94
+
95
+ model = keras.Sequential([
96
+ layers.Dense(10, activation='relu', input_shape=(100,)),
97
+ cluster_weights(layers.Dense(2, activation='tanh'), **clustering_params)
98
+ ])
99
+ ```
100
+
101
+ Arguments:
102
+ to_cluster: A single keras layer, list of keras layers, or a
103
+ `tf.keras.Model` instance.
104
+ number_of_clusters: the number of cluster centroids to form when
105
+ clustering a layer/model. For example, if number_of_clusters=8 then only
106
+ 8 unique values will be used in each weight array.
107
+ cluster_centroids_init: enum value that determines how the cluster
108
+ centroids will be initialized.
109
+ Can have following values:
110
+ 1. RANDOM : centroids are sampled using the uniform distribution
111
+ between the minimum and maximum weight values in a given layer
112
+ 2. DENSITY_BASED : density-based sampling. First, cumulative
113
+ distribution function is built for weights, then y-axis is evenly
114
+ spaced into number_of_clusters regions. After this the corresponding
115
+ x values are obtained and used to initialize clusters centroids.
116
+ 3. LINEAR : cluster centroids are evenly spaced between the minimum
117
+ and maximum values of a given weight
118
+ **kwargs: Additional keyword arguments to be passed to the keras layer.
119
+ Ignored when to_cluster is not a keras layer.
120
+
121
+ Returns:
122
+ Layer or model modified to include clustering related metadata.
123
+
124
+ Raises:
125
+ ValueError: if the keras layer is unsupported, or the keras model contains
126
+ an unsupported layer.
127
+ """
128
+ return _cluster_weights (
129
+ to_cluster ,
130
+ number_of_clusters ,
131
+ cluster_centroids_init ,
132
+ preserve_sparsity = False ,
133
+ ** kwargs )
134
+
135
+
136
+ def _cluster_weights (to_cluster , number_of_clusters , cluster_centroids_init ,
137
+ preserve_sparsity , ** kwargs ):
138
+ """Modifies a keras layer or model to be clustered during training.
139
+
140
+ This function wraps a keras model or layer with clustering functionality
141
+ which clusters the layer's weights during training. For examples, using
142
+ this with number_of_clusters equals 8 will ensure that each weight tensor has
143
+ no more than 8 unique values.
144
+
145
+ Before passing to the clustering API, a model should already be trained and
146
+ show some acceptable performance on the testing/validation sets.
147
+
148
+ The function accepts either a single keras layer
149
+ (subclass of `keras.layers.Layer`), list of keras layers or a keras model
150
+ (instance of `keras.models.Model`) and handles them appropriately.
151
+
152
+ If it encounters a layer it does not know how to handle, it will throw an
153
+ error. While clustering an entire model, even a single unknown layer would
154
+ lead to an error.
155
+
156
+ Cluster a model:
157
+
158
+ ```python
159
+ clustering_params = {
160
+ 'number_of_clusters': 8,
161
+ 'cluster_centroids_init': CentroidInitialization.DENSITY_BASED,
162
+ 'preserve_sparsity': False
163
+ }
164
+
165
+ clustered_model = cluster_weights(original_model, **clustering_params)
166
+ ```
167
+
168
+ Cluster a layer:
169
+
170
+ ```python
171
+ clustering_params = {
172
+ 'number_of_clusters': 8,
173
+ 'cluster_centroids_init': CentroidInitialization.DENSITY_BASED,
174
+ 'preserve_sparsity': False
175
+ }
176
+
177
+ model = keras.Sequential([
178
+ layers.Dense(10, activation='relu', input_shape=(100,)),
179
+ cluster_weights(layers.Dense(2, activation='tanh'), **clustering_params)
180
+ ])
181
+ ```
182
+
183
+ Cluster a layer with sparsity preservation (experimental):
184
+
185
+ ```python
186
+ clustering_params = {
187
+ 'number_of_clusters': 8,
188
+ 'cluster_centroids_init': CentroidInitialization.DENSITY_BASED,
189
+ 'preserve_sparsity': True
97
190
}
98
191
99
192
model = keras.Sequential([
@@ -110,6 +203,8 @@ def cluster_weights(to_cluster,
110
203
8 unique values will be used in each weight array.
111
204
cluster_centroids_init: `tfmot.clustering.keras.CentroidInitialization`
112
205
instance that determines how the cluster centroids will be initialized.
206
+ preserve_sparsity (experimental): optional boolean value that determines
207
+ whether or not sparsity preservation will be enforced during training.
113
208
**kwargs: Additional keyword arguments to be passed to the keras layer.
114
209
Ignored when to_cluster is not a keras layer.
115
210
@@ -120,33 +215,31 @@ def cluster_weights(to_cluster,
120
215
ValueError: if the keras layer is unsupported, or the keras model contains
121
216
an unsupported layer.
122
217
"""
123
- if not clustering_centroids .CentroidsInitializerFactory .\
124
- init_is_supported ( cluster_centroids_init ):
125
- raise ValueError (" Cluster centroid initialization {} not supported" .\
126
- format ( cluster_centroids_init ))
218
+ if not clustering_centroids .CentroidsInitializerFactory .init_is_supported (
219
+ cluster_centroids_init ):
220
+ raise ValueError (' Cluster centroid initialization {} not supported' . format (
221
+ cluster_centroids_init ))
127
222
128
223
def _add_clustering_wrapper (layer ):
129
-
130
- if (isinstance (layer , keras .Model )):
224
+ if isinstance (layer , keras .Model ):
131
225
# Check whether the model is a subclass.
132
226
# NB: This check is copied from keras.py file in tensorflow.
133
227
# There is no available public API to do this check.
228
+ # pylint: disable=protected-access
134
229
if (not layer ._is_graph_network and
135
230
not isinstance (layer , keras .models .Sequential )):
136
- raise ValueError (" Subclassed models are not supported currently." )
231
+ raise ValueError (' Subclassed models are not supported currently.' )
137
232
138
- return keras .models .clone_model (layer ,
139
- input_tensors = None ,
140
- clone_function = _add_clustering_wrapper )
233
+ return keras .models .clone_model (
234
+ layer , input_tensors = None , clone_function = _add_clustering_wrapper )
141
235
if isinstance (layer , cluster_wrapper .ClusterWeights ):
142
236
return layer
143
237
if isinstance (layer , InputLayer ):
144
238
return layer .__class__ .from_config (layer .get_config ())
145
239
146
- return cluster_wrapper .ClusterWeights (layer ,
147
- number_of_clusters ,
240
+ return cluster_wrapper .ClusterWeights (layer , number_of_clusters ,
148
241
cluster_centroids_init ,
149
- ** kwargs )
242
+ preserve_sparsity , ** kwargs )
150
243
151
244
def _wrap_list (layers ):
152
245
output = []
@@ -166,7 +259,7 @@ def _wrap_list(layers):
166
259
167
260
168
261
def strip_clustering (model ):
169
- """Strip clustering wrappers from the model.
262
+ """Strips clustering wrappers from the model.
170
263
171
264
Once a model has been clustered, this method can be used
172
265
to restore the original model with the clustered weights.
@@ -198,16 +291,17 @@ def strip_clustering(model):
198
291
199
292
def _strip_clustering_wrapper (layer ):
200
293
if isinstance (layer , keras .Model ):
201
- return keras .models .clone_model (layer ,
202
- input_tensors = None ,
203
- clone_function = _strip_clustering_wrapper )
294
+ return keras .models .clone_model (
295
+ layer , input_tensors = None , clone_function = _strip_clustering_wrapper )
204
296
elif isinstance (layer , cluster_wrapper .ClusterWeights ):
205
297
if not hasattr (layer .layer , '_batch_input_shape' ) and \
206
298
hasattr (layer , '_batch_input_shape' ):
299
+ # pylint: disable=protected-access
207
300
layer .layer ._batch_input_shape = layer ._batch_input_shape
208
301
209
302
# We reset both arrays of weights, so that we can guarantee the correct
210
303
# order of newly created weights
304
+ # pylint: disable=protected-access
211
305
layer .layer ._trainable_weights = []
212
306
layer .layer ._non_trainable_weights = []
213
307
for i in range (len (layer .restore )):
0 commit comments