Skip to content

Commit 27b267a

Browse files
Merge pull request #806 from MohamedNourArm:toupstream/update_cluster_guide
PiperOrigin-RevId: 393948588
2 parents eefc66a + 8168535 commit 27b267a

File tree

1 file changed

+57
-2
lines changed

1 file changed

+57
-2
lines changed

tensorflow_model_optimization/g3doc/guide/clustering/clustering_comprehensive_guide.ipynb

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@
214214
"\n",
215215
"clustering_params = {\n",
216216
" 'number_of_clusters': 3,\n",
217-
" 'cluster_centroids_init': CentroidInitialization.DENSITY_BASED\n",
217+
" 'cluster_centroids_init': CentroidInitialization.KMEANS_PLUS_PLUS\n",
218218
"}\n",
219219
"\n",
220220
"model = setup_model()\n",
@@ -279,6 +279,61 @@
279279
"clustered_model.summary()"
280280
]
281281
},
282+
{
283+
"cell_type": "markdown",
284+
"metadata": {
285+
"id": "WcFrw1dHmxTr"
286+
},
287+
"source": [
288+
"### Cluster custom Keras layer or specify which weights of layer to cluster\n",
289+
"\n",
290+
"`tfmot.clustering.keras.ClusterableLayer` serves two use cases:\n",
291+
"1. Cluster any layer that is not supported natively, including a custom Keras layer.\n",
292+
"2. Specify which weights of a supported layer are to be clustered.\n",
293+
"\n",
294+
"For an example, the API defaults to only clustering the kernel of the\n",
295+
"`Dense` layer. The example below shows how to modify it to also cluster the bias. Note that when deriving from the keras layer, you need to override the function `get_clusterable_weights`, where you specify the name of the trainable variable to be clustered and the trainable variable itself. For example, if you return an empty list [], then no weights will be clusterable.\n",
296+
"\n",
297+
"**Common mistake:** clustering the bias usually harms model accuracy too much."
298+
]
299+
},
300+
{
301+
"cell_type": "code",
302+
"execution_count": null,
303+
"metadata": {
304+
"id": "73iboQ7MmxTs"
305+
},
306+
"outputs": [],
307+
"source": [
308+
"class MyDenseLayer(tf.keras.layers.Dense, tfmot.clustering.keras.ClusterableLayer):\n",
309+
"\n",
310+
" def get_clusterable_weights(self):\n",
311+
" # Cluster kernel and bias. This is just an example, clustering\n",
312+
" # bias usually hurts model accuracy.\n",
313+
" return [('kernel', self.kernel), ('bias', self.bias)]\n",
314+
"\n",
315+
"# Use `cluster_weights` to make the `MyDenseLayer` layer train with clustering as usual.\n",
316+
"model_for_clustering = tf.keras.Sequential([\n",
317+
" tfmot.clustering.keras.cluster_weights(MyDenseLayer(20, input_shape=[input_dim]), **clustering_params),\n",
318+
" tf.keras.layers.Flatten()\n",
319+
"])\n",
320+
"\n",
321+
"model_for_clustering.summary()"
322+
]
323+
},
324+
{
325+
"cell_type": "markdown",
326+
"metadata": {
327+
"id": "SYlWPXEWmxTs"
328+
},
329+
"source": [
330+
"You may also use `tfmot.clustering.keras.ClusterableLayer` to cluster a keras custom layer. To do this, you extend `tf.keras.Layer` as usual and implement the `__init__`, `call`, and `build` functions, but you also need to extend the `clusterable_layer.ClusterableLayer` class and implement:\n",
331+
"1. `get_clusterable_weights`, where you specify the weight kernel to be clustered, as shown above.\n",
332+
"2. `get_clusterable_algorithm`, where you specify the clustering algorithm for the weight tensor. This is because you need to specify how the custom layer weights are shaped for clustering. The returned clustering algorithm class should be derived from the `clustering_algorithm.ClusteringAlgorithm` class and the function `get_pulling_indices` should be overwritten. An example of this function, which supports weights of ranks 1D, 2D, and 3D, can be found [here]( https://github.com/tensorflow/model-optimization/blob/18e87d262e536c9a742aef700880e71b47a7f768/tensorflow_model_optimization/python/core/clustering/keras/clustering_algorithm.py#L62).\n",
333+
"\n",
334+
"An example of this use case can be found [here](https://github.com/tensorflow/model-optimization/blob/master/tensorflow_model_optimization/python/core/clustering/keras/mnist_clusterable_layer_test.py)."
335+
]
336+
},
282337
{
283338
"cell_type": "markdown",
284339
"metadata": {
@@ -338,7 +393,7 @@
338393
"source": [
339394
"For your specific use case, there are tips you can consider:\n",
340395
"\n",
341-
"* Centroid initialization plays a key role in the final optimized model accuracy. In general, linear initialization outperforms density and random initialization since it does not tend to miss large weights. However, density initialization has been observed to give better accuracy for the case of using very few clusters on weights with bimodal distributions.\n",
396+
"* Centroid initialization plays a key role in the final optimized model accuracy. In general, kmeans++ initialization outperforms linear, density and random initialization. When not using kmeans++, linear initialization tends to outperform density and random initialization, since it does not tend to miss large weights. However, density initialization has been observed to give better accuracy for the case of using very few clusters on weights with bimodal distributions.\n",
342397
"\n",
343398
"* Set a learning rate that is lower than the one used in training when fine-tuning the clustered model.\n",
344399
"\n",

0 commit comments

Comments
 (0)