1. Change the API class name (`WeightCompressionAlgorithm` -> `WeightCompressor`).
2. Change indentation (remove indent for fenced code blocks).
3. Add more comments on the `init_training_weights` method.
@@ -48,100 +48,100 @@ Our API also provides guidelines for testing and benchmark. For now, we only hav
### Tutorials and Examples
We provide a tutorial for the [SVD](https://en.wikipedia.org/wiki/Singular_value_decomposition) compression algorithm that shows how to implement SVD using the TFMOT compression API in Colab. This tutorial includes:
* Algorithm developer side.

1. The algorithm developer implementing the SVD algorithm uses the `WeightCompressor` class.

```python
class SVD(algorithm.WeightCompressor):
  """SVD compression module config."""

  def __init__(self, params):
    self.params = params

  def init_training_weights(
      self, pretrained_weight: tf.Tensor):
    """Init function from pre-trained model case."""
    rank = self.params.rank

    # Dense layer: factorize the 2-D kernel into u and sv.
    if len(pretrained_weight.shape) == 2:
      u, sv = tf_svd_factorization_2d(pretrained_weight, rank)
      # u and sv are then registered via add_training_weight
      # (elided in this hunk).
    else:
      raise NotImplementedError('Only dimension=2 is supported.')
```
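The snippet relies on a `tf_svd_factorization_2d` helper whose implementation is not shown in this hunk. A minimal sketch of what such a helper could look like, assuming it returns a rank-truncated factorization via `tf.linalg.svd`:

```python
import tensorflow as tf

def tf_svd_factorization_2d(matrix: tf.Tensor, rank: int):
  """Returns (u, sv) such that u @ sv is a rank-`rank` approximation."""
  # tf.linalg.svd returns singular values s and unitary u, v with
  # matrix == u @ diag(s) @ v^T.
  s, u, v = tf.linalg.svd(matrix)
  # Truncate to the top-`rank` singular values.
  u_r = u[:, :rank]                                            # (rows, rank)
  sv_r = tf.linalg.diag(s[:rank]) @ tf.transpose(v[:, :rank])  # (rank, cols)
  return u_r, sv_r
```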
We also want to provide examples of well-known compression algorithms. Here's the list of algorithms we have to provide at a minimum:
* [Weight clustering](https://arxiv.org/abs/1510.00149): One of the most famous compression algorithms; it can be used widely.
@@ -163,7 +163,7 @@ During the training phase, `project_training_weights` method is called for each
The compressed model contains the `decompress_weights` function in its graph. It's possible to call `decompress_weights` at each inference step. To improve performance, we'll cache the decompressed weights, depending on flags, if we have enough space.
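A hedged sketch of what such flag-dependent caching could look like (the wrapper class and flag name here are illustrative, not part of the proposed API):

```python
class DecompressedWeightCache:
  """Illustrative cache around an algorithm's decompress_weights."""

  def __init__(self, algorithm, compressed_weights, cache_enabled=True):
    self._algorithm = algorithm
    self._compressed_weights = compressed_weights
    self._cache_enabled = cache_enabled
    self._cached = None

  def get(self):
    if not self._cache_enabled:
      # No space budget: decompress on every inference step.
      return self._algorithm.decompress_weights(*self._compressed_weights)
    if self._cached is None:
      # Decompress once and reuse the result on later steps.
      self._cached = self._algorithm.decompress_weights(
          *self._compressed_weights)
    return self._cached
```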
"""Interface for weight compression algorithm that acts on a per-layer basis.
This allows both options of either decompressing during inference or
@@ -191,7 +191,12 @@ class WeightCompressionAlgorithm(metaclass=abc.ABCMeta):
@abc.abstractmethod
def init_training_weights(
    self, pretrained_weight: tf.Tensor):
  """Initialize training weights for the compressible weight.

  It calls `add_training_weight` to add a training weight for a given
  `pretrained_weight`. A `pretrained_weight` can have multiple training
  weights. We initialize the training weights for each compressible
  weight by calling this function once for each.

  Args:
    pretrained_weight: tf.Tensor of a pretrained weight of a layer that will
@@ -200,7 +205,11 @@ class WeightCompressionAlgorithm(metaclass=abc.ABCMeta):
def add_training_weight(
    self, *args, **kwargs):
  """Add a training weight for the compressible weight.

  When this method is called from `init_training_weights`, it adds a
  training weight for the `pretrained_weight` that is the input of
  `init_training_weights`.

  Args:
    *args, **kwargs: args and kwargs for training_model.add_weight.
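To make the interplay concrete, here is a hedged sketch of how SVD's `init_training_weights` might call `add_training_weight` once per factor; the keyword arguments shown are assumptions that simply get forwarded to `training_model.add_weight`:

```python
def init_training_weights(self, pretrained_weight: tf.Tensor):
  """Registers two training weights (u and sv) for one pretrained weight."""
  u, sv = tf_svd_factorization_2d(pretrained_weight, self.params.rank)

  # Each call forwards its kwargs to training_model.add_weight.
  self.add_training_weight(
      name='u', shape=u.shape, dtype=u.dtype,
      initializer=lambda shape, dtype: u)
  self.add_training_weight(
      name='sv', shape=sv.shape, dtype=sv.dtype,
      initializer=lambda shape, dtype: sv)
```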
@@ -270,7 +279,7 @@ class WeightCompressionAlgorithm(metaclass=abc.ABCMeta):
#### Model compression algorithm API
Some compression algorithms require training weights or compressed weights that are shared across layers (e.g., a lookup table for weight clustering).
We decided to support the per-layer-variable compression algorithm API first, because:
* Most use cases can be covered by the `WeightCompressor` class API.
* It is hard to support a sequential model: a weight shared across layers would have to be placed somewhere outside of the sequential model.
### User Impact
@@ -297,7 +306,7 @@ We’ll provide examples of compression algorithms using the API in this design,
This API is a standalone project that depends only on TensorFlow.
### Engineering Impact
The TF-MOT team will maintain this API code. For the initial release, we will publicize the `WeightCompressor` class, which algorithm developers inherit to implement their own compression algorithms; the WrapperLayer methods to access the original layer; and model-clone-based default converter functions that help model developers implement their own algorithm-specific APIs.
### Platforms and Environments
For the initial release, we've targeted the TF 2.0 Keras model. After compression, the compressed model can be deployed to servers as a TF model, to mobile/embedded environments as a TFLite model, and to the web in tf.js format.
@@ -313,9 +322,9 @@ Compressed models can be converted to TF model, TFLite model, and tf.js format.
This is an API design doc. Engineering details will be determined in the future.
For a better explanation of this API, here's the step-by-step usage documentation:
### Step-by-step usage documentation of the `WeightCompressor` class methods.
The `WeightCompressor` class has 5 abstract methods. The following explanation shows when these methods are called and used.
`get_compressible_weights` is called when we want to get the list of variables to which we will apply compression.
When we try to compress the pre-trained model, we just call this method for each layer in the pre-trained model. The number of method calls is (# of layers).
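As a hedged illustration (the exact layer attributes consulted are an assumption), an implementation for SVD on Dense kernels might look like:

```python
def get_compressible_weights(self, original_layer: tf.keras.layers.Layer):
  """Lists this layer's variables that compression will be applied to."""
  # Minimal policy: only the 2-D kernel of built Dense layers is compressible.
  if isinstance(original_layer, tf.keras.layers.Dense):
    return [original_layer.kernel]
  return []
```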
`init_training_weights` is called when we initialize the cloned training model from the pre-trained model. The `optimize_training` method basically clones the model to create a training model for compression, wrapping compressible layers with the training wrapper to create training weights. The number of method calls is (# of compressible weights).
`project_training_weights` is called while the training model for the compression algorithm is training. Usually this method is part of the training model graph. It recovers the original weight from the training weights, and it should be differentiable. This method enables you to use the original graph to compute the model output while training the training weights of the training model. For each training step, this method is called for every compressible weight, so the number of method calls is (# of compressible weights) * (# of training steps).
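A hedged SVD sketch of this method (the per-weight signature is an assumption):

```python
def project_training_weights(self, u: tf.Tensor, sv: tf.Tensor) -> tf.Tensor:
  """Differentiably rebuilds the original-shaped weight from u and sv."""
  # u @ sv has the pretrained kernel's shape, so the original graph can
  # consume it while gradients flow back into u and sv during training.
  return tf.matmul(u, sv)
```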
`compress_training_weights` is called when we convert the training model to the compressed model. The number of method calls is (# of compressible weights).
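For SVD, the training weights are already in compressed form, so a hedged sketch can simply pass them through:

```python
def compress_training_weights(self, u: tf.Tensor, sv: tf.Tensor):
  """Returns the tensors to store in the compressed model."""
  # The low-rank factors u and sv are themselves the compressed
  # representation; no further transformation is needed.
  return [u, sv]
```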
`decompress_weights` is called when we run inference on the compressed model. Usually this method is part of the compressed model graph. For each compressible weight, it decompresses the weight into a form that can be used in the original graph. The number of method calls is (# of compressible weights) * (# of inferences). To improve performance, the output value of this method can be cached.
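And its inverse at inference time, again as a hedged SVD sketch:

```python
def decompress_weights(self, u: tf.Tensor, sv: tf.Tensor) -> tf.Tensor:
  """Rebuilds the original-shaped weight inside the compressed model."""
  # Same product as project_training_weights, but executed in the
  # compressed model's graph (and potentially cached, as noted above).
  return tf.matmul(u, sv)
```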
## Questions and Discussion Topics
@@ -444,4 +453,3 @@ Note that every trainable variable that they want to train should be in training
### Error message & Debugging tools.
It's not easy to find bugs there: usually we get TensorFlow error messages with huge stack traces. We have to provide helpful error messages at this API layer.