# Weight clustering

This document provides an overview of weight clustering to help you determine how it fits your use case.

- To dive right into an end-to-end example, see the [weight clustering example](clustering_example.ipynb).
- To quickly find the APIs you need for your use case, see the [weight clustering comprehensive guide](clustering_comprehensive_guide.ipynb).
## Overview

Clustering, or weight sharing, reduces the number of unique weight values in a model, leading to benefits for deployment. It first groups the weights of each layer into *N* clusters, then shares the centroid value of each cluster among all the weights belonging to that cluster.
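
As an illustration, a minimal NumPy sketch of this idea (not the toolkit's actual implementation; the function name and parameters here are hypothetical) clusters a weight matrix with simple k-means and replaces every weight with its cluster's centroid:

```python
import numpy as np

def cluster_layer_weights(weights, n_clusters, n_iter=50):
    """Toy k-means weight sharing: replace each weight with its cluster centroid."""
    flat = weights.ravel()
    # Initialize centroids linearly between the min and max weight values.
    centroids = np.linspace(flat.min(), flat.max(), n_clusters)
    for _ in range(n_iter):
        # Assign each weight to its nearest centroid.
        assignments = np.abs(flat[:, None] - centroids[None, :]).argmin(axis=1)
        # Move each centroid to the mean of the weights assigned to it.
        for k in range(n_clusters):
            members = flat[assignments == k]
            if members.size:
                centroids[k] = members.mean()
    # Share the centroid value among all weights in the same cluster.
    return centroids[assignments].reshape(weights.shape)

w = np.random.default_rng(0).normal(size=(64, 64)).astype(np.float32)
clustered = cluster_layer_weights(w, n_clusters=16)
print(np.unique(clustered).size)  # at most 16 unique values remain
```

The layer's shape and accuracy-relevant structure are preserved; only the number of distinct values shrinks, which is what makes the weights compressible.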

This technique brings improvements via model compression. Future framework support can unlock memory footprint improvements that can make a crucial difference for deploying deep learning models on embedded systems with limited resources.

We have experimented with clustering across vision and speech tasks. We've seen up to 5x improvements in model compression with minimal loss of accuracy, as demonstrated by the [results](#results) presented below.

Please note that clustering provides reduced benefits for convolution and dense layers that precede a batch normalization layer, as well as in combination with per-axis post-training quantization.

### API compatibility matrix

Users can apply clustering with the following APIs:

* Model building: `tf.keras` with only Sequential and Functional models
* TensorFlow versions: TF 1.x for versions 1.14+ and 2.x.
  * `tf.compat.v1` with a TF 2.X package and `tf.compat.v2` with a TF 1.X
    package are not supported.
* TensorFlow execution mode: both graph and eager

## Results

### Image classification

<table>
  <tr>
    <th rowspan="2">Model</th>
    <th colspan="2">Original</th>
    <th colspan="4">Clustered</th>
  </tr>
  <tr>
    <th>Top-1 accuracy (%)</th>
    <th>Size of compressed .tflite (MB)</th>
    <th>Configuration</th>
    <th># of clusters</th>
    <th>Top-1 accuracy (%)</th>
    <th>Size of compressed .tflite (MB)</th>
  </tr>
  <tr>
    <td rowspan="3">MobileNetV1</td>
    <td rowspan="3">71.02</td>
    <td rowspan="3">14.96</td>
  </tr>
  <tr>
    <td>Selective (last 3 Conv2D layers)</td>
    <td>256, 256, 32</td>
    <td>70.62</td>
    <td>8.42</td>
  </tr>
  <tr>
    <td>Full (all Conv2D layers)</td>
    <td>64</td>
    <td>66.07</td>
    <td>2.98</td>
  </tr>
  <tr>
    <td rowspan="3">MobileNetV2</td>
    <td rowspan="3">72.29</td>
    <td rowspan="3">12.90</td>
  </tr>
  <tr>
    <td>Selective (last 3 Conv2D layers)</td>
    <td>256, 256, 32</td>
    <td>72.31</td>
    <td>7.00</td>
  </tr>
  <tr>
    <td>Full (all Conv2D layers)</td>
    <td>32</td>
    <td>69.33</td>
    <td>2.60</td>
  </tr>
</table>

The models were trained and tested on ImageNet.

### Keyword spotting

<table>
  <tr>
    <th rowspan="2">Model</th>
    <th colspan="2">Original</th>
    <th colspan="4">Clustered</th>
  </tr>
  <tr>
    <th>Top-1 accuracy (%)</th>
    <th>Size of compressed .tflite (MB)</th>
    <th>Configuration</th>
    <th># of clusters</th>
    <th>Top-1 accuracy (%)</th>
    <th>Size of compressed .tflite (MB)</th>
  </tr>
  <tr>
    <td>DS-CNN-L</td>
    <td>95.03</td>
    <td>1.5</td>
    <td>Full</td>
    <td>32</td>
    <td>94.71</td>
    <td>0.3</td>
  </tr>
</table>

The models were trained and tested on SpeechCommands v0.02.

NOTE: *Size of compressed .tflite* refers to the size of the zipped .tflite file obtained from the model via the following process:
1. Serialize the Keras model into a .h5 file
2. Convert the .h5 file into a .tflite file using `TFLiteConverter.from_keras_model_file()`
3. Compress the .tflite file into a zip
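
To see why the zipped file shrinks, compare the compressed size of raw versus crudely clustered weights. This is a hedged stand-in for the .tflite/zip pipeline above, using `zlib` on raw float buffers rather than an actual converted model:

```python
import zlib
import numpy as np

rng = np.random.default_rng(42)
weights = rng.normal(size=100_000).astype(np.float32)

# Crude clustering: snap each weight to the nearest of 32 shared values.
edges = np.linspace(weights.min(), weights.max(), 32)
clustered = edges[np.abs(weights[:, None] - edges[None, :]).argmin(axis=1)]

# Compress both byte buffers; fewer unique values means more redundancy.
raw_size = len(zlib.compress(weights.tobytes()))
clustered_size = len(zlib.compress(clustered.astype(np.float32).tobytes()))
print(raw_size, clustered_size)  # the clustered buffer compresses far better
```

The accuracy cost of that snapping is what the training-time clustering in this toolkit is designed to minimize.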

## Examples

In addition to the [Weight clustering in Keras example](clustering_example.ipynb), see the following examples:

* Cluster the weights of a CNN model trained on the MNIST handwritten digit classification dataset:
[code](https://github.com/tensorflow/model-optimization/blob/master/tensorflow_model_optimization/python/examples/clustering/keras/mnist/mnist_cnn.py)

The weight clustering implementation is based on the *Deep Compression: Compressing Deep Neural Networks With Pruning, Trained Quantization and Huffman Coding* [paper](https://arxiv.org/abs/1510.00149). See chapter 3, titled *Trained Quantization and Weight Sharing*.