
Commit 226ca71

Merge pull request #301 from arovir01:toupstream/clustering_guide
PiperOrigin-RevId: 318281266
2 parents 888621b + 874120b commit 226ca71

1 file changed (+125 −0): tensorflow_model_optimization/g3doc/guide/clustering
# Weight clustering

This document provides an overview of weight clustering to help you determine how it fits with your use case.

- To dive right into an end-to-end example, see the [weight clustering example](clustering_example.ipynb).
- To quickly find the APIs you need for your use case, see the [weight clustering comprehensive guide](clustering_comprehensive_guide.ipynb).

## Overview

Clustering, or weight sharing, reduces the number of unique weight values in a model, leading to benefits for deployment. It first groups the weights of each layer into *N* clusters, then shares the cluster's centroid value for all the weights belonging to the cluster.
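As a toy illustration of the idea (not the actual library implementation), the following NumPy sketch snaps a weight array onto *N* linearly spaced centroids; all names here are illustrative:

```python
import numpy as np

def cluster_weights_demo(weights, n_clusters):
    """Toy weight sharing: replace each weight with its nearest centroid."""
    # Linearly spaced centroids between the min and max weight value.
    centroids = np.linspace(weights.min(), weights.max(), n_clusters)
    # Index of the nearest centroid for every weight.
    idx = np.abs(weights[..., None] - centroids).argmin(axis=-1)
    # Share the centroid value across all weights in the same cluster.
    return centroids[idx]

w = np.random.randn(4, 4).astype(np.float32)
clustered = cluster_weights_demo(w, n_clusters=8)
print(np.unique(clustered).size)  # at most 8 unique values remain
```

After this step the layer stores at most *N* distinct values, which is what makes the serialized weights highly compressible.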

This technique brings improvements via model compression. Future framework support can unlock memory footprint improvements that can make a crucial difference for deploying deep learning models on embedded systems with limited resources.

We have experimented with clustering across vision and speech tasks. We've seen up to 5x improvements in model compression with minimal loss of accuracy, as demonstrated by the [results](#results) presented below.

Please note that clustering provides reduced benefits for convolution and dense layers that precede a batch normalization layer, as well as when combined with per-axis post-training quantization.

### API compatibility matrix

Users can apply clustering with the following APIs:

* Model building: `tf.keras` with only Sequential and Functional models
* TensorFlow versions: TF 1.x (versions 1.14+) and 2.x.
  * `tf.compat.v1` with a TF 2.X package and `tf.compat.v2` with a TF 1.X package are not supported.
* TensorFlow execution mode: both graph and eager

## Results

### Image classification

<table>
<tr>
<th rowspan="2">Model</th>
<th colspan="2">Original</th>
<th colspan="4">Clustered</th>
</tr>
<tr>
<th>Top-1 accuracy (%)</th>
<th>Size of compressed .tflite (MB)</th>
<th>Configuration</th>
<th># of clusters</th>
<th>Top-1 accuracy (%)</th>
<th>Size of compressed .tflite (MB)</th>
</tr>
<tr>
<td rowspan="3">MobileNetV1</td>
<td rowspan="3">71.02</td>
<td rowspan="3">14.96</td>
</tr>
<tr>
<td>Selective (last 3 Conv2D layers)</td>
<td>256, 256, 32</td>
<td>70.62</td>
<td>8.42</td>
</tr>
<tr>
<td>Full (all Conv2D layers)</td>
<td>64</td>
<td>66.07</td>
<td>2.98</td>
</tr>
<tr>
<td rowspan="3">MobileNetV2</td>
<td rowspan="3">72.29</td>
<td rowspan="3">12.90</td>
</tr>
<tr>
<td>Selective (last 3 Conv2D layers)</td>
<td>256, 256, 32</td>
<td>72.31</td>
<td>7.00</td>
</tr>
<tr>
<td>Full (all Conv2D layers)</td>
<td>32</td>
<td>69.33</td>
<td>2.60</td>
</tr>
</table>

The models were trained and tested on ImageNet.

### Keyword spotting

<table>
<tr>
<th rowspan="2">Model</th>
<th colspan="2">Original</th>
<th colspan="4">Clustered</th>
</tr>
<tr>
<th>Top-1 accuracy (%)</th>
<th>Size of compressed .tflite (MB)</th>
<th>Configuration</th>
<th># of clusters</th>
<th>Top-1 accuracy (%)</th>
<th>Size of compressed .tflite (MB)</th>
</tr>
<tr>
<td>DS-CNN-L</td>
<td>95.03</td>
<td>1.5</td>
<td>Full</td>
<td>32</td>
<td>94.71</td>
<td>0.3</td>
</tr>
</table>

The models were trained and tested on SpeechCommands v0.02.

NOTE: *Size of compressed .tflite* refers to the size of the zipped .tflite file obtained from the model via the following process:
1. Serialize the Keras model into an .h5 file
2. Convert the .h5 file into a .tflite file using `TFLiteConverter.from_keras_model_file()`
3. Compress the .tflite file into a zip
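Steps 2 and 3 can be sketched as below. Note this is an illustration, not the exact measurement script: the guide's `TFLiteConverter.from_keras_model_file()` is the TF 1.x entry point, so this sketch uses the TF 2.x `from_keras_model()` equivalent, and the untrained placeholder model stands in for a trained, clustered one:

```python
import os
import tempfile
import zipfile

import tensorflow as tf

# Placeholder model standing in for a trained, clustered model.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation='relu'),
    tf.keras.layers.Dense(4),
])
model.build(input_shape=(None, 8))

tmp = tempfile.mkdtemp()
tflite_path = os.path.join(tmp, 'model.tflite')
zip_path = os.path.join(tmp, 'model.tflite.zip')

# Step 2: convert the Keras model to a .tflite flatbuffer.
tflite_model = tf.lite.TFLiteConverter.from_keras_model(model).convert()
with open(tflite_path, 'wb') as f:
    f.write(tflite_model)

# Step 3: zip the .tflite file. Clustered weights compress well here
# because the repeated centroid values are highly redundant.
with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as z:
    z.write(tflite_path, os.path.basename(tflite_path))

print(os.path.getsize(tflite_path), os.path.getsize(zip_path))
```

Comparing the two printed sizes for a clustered versus an unclustered model is what yields the "Size of compressed .tflite" numbers above.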

## Examples

In addition to the [Weight clustering in Keras example](clustering_example.ipynb), see the following examples:

* Cluster the weights of a CNN model trained on the MNIST handwritten digit classification dataset:
  [code](https://github.com/tensorflow/model-optimization/blob/master/tensorflow_model_optimization/python/examples/clustering/keras/mnist/mnist_cnn.py)

The weight clustering implementation is based on the *Deep Compression: Compressing Deep Neural Networks With Pruning, Trained Quantization and Huffman Coding* [paper](https://arxiv.org/abs/1510.00149). See chapter 3, titled *Trained Quantization and Weight Sharing*.
