This repository was archived by the owner on Jul 10, 2025. It is now read-only.
@@ -160,9 +160,9 @@ This is an API for a layer weight based compression algorithm.
First, we start from a pre-trained model that the model developer already has, and convert it to a training-phase model for compression fine-tuning. During this conversion, we call `init_training_weights_repr` for each tensor that we want to compress, as specified by the `get_compressible_weights` method.
-During the training phase, the `fake_decompress` method is called for each training step. After fine-tuning training for compression is finished, we convert the training phase model to a compressed model. We call the `compress` function only once for each compressible tensor during this conversion.
+During the training phase, the `project_training_weights` method is called for each training step. After fine-tuning training for compression is finished, we convert the training phase model to a compressed model. We call the `compress_training_weights` function only once for each compressible tensor during this conversion.
-The compressed model contains the `decompress` function in the graph. It’s possible to call `decompress` for each inference step. To improve performance, we’ll cache the decompressed weight, depending on flags, if we have enough space.
+The compressed model contains the `decompress_weights` function in the graph. It’s possible to call `decompress_weights` for each inference step. To improve performance, we’ll cache the decompressed weight, depending on flags, if we have enough space.
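
The lifecycle described above can be sketched with a toy algorithm. This is only an illustration, not the RFC's implementation: the method names come from the text, but the signatures, the `RankOneCompression` class, the `outer` helper, and the rank-1 scheme are assumptions, and plain Python lists stand in for tensors.

```python
def outer(u, v):
    """Outer product of two vectors, returned as a list of rows."""
    return [[ui * vj for vj in v] for ui in u]


class RankOneCompression:
    """Hypothetical toy algorithm: store a matrix W as two vectors (u, v), W ~ u v^T."""

    def init_training_weights_repr(self, pretrained_weight):
        # Called once per compressible tensor when the training model is built.
        # Toy choice: take v from the first row (assumed non-zero), then pick
        # the least-squares u for that v.
        v = list(pretrained_weight[0])
        norm = sum(x * x for x in v)
        u = [sum(a * b for a, b in zip(row, v)) / norm
             for row in pretrained_weight]
        return [u, v]

    def project_training_weights(self, u, v):
        # Called at every training step; a product of the training weights,
        # so it would be differentiable in a real framework.
        return outer(u, v)

    def compress_training_weights(self, u, v):
        # Called once per compressible tensor when converting to the
        # compressed model.
        return [tuple(u), tuple(v)]

    def decompress_weights(self, u, v):
        # Part of the compressed model; may run at every inference step.
        return outer(u, v)


algorithm = RankOneCompression()
pretrained = [[1.0, 2.0], [2.0, 4.0], [3.0, 6.0]]  # exactly rank 1

training_weights = algorithm.init_training_weights_repr(pretrained)
for _ in range(3):  # stand-in for the fine-tuning loop (no gradient updates here)
    projected = algorithm.project_training_weights(*training_weights)
compressed = algorithm.compress_training_weights(*training_weights)
restored = algorithm.decompress_weights(*compressed)
```

Because the example matrix is exactly rank 1, `restored` reproduces `pretrained`; for a general matrix this toy scheme would only approximate it, and fine-tuning would be expected to adjust `u` and `v`.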
```python
class WeightCompressionAlgorithm(metaclass=abc.ABCMeta):
@@ -205,7 +205,7 @@ class WeightCompressionAlgorithm(metaclass=abc.ABCMeta):
"""Define the operations to decompress a single weight’s compressed form during inference.
The default is an identity.
@@ -395,38 +395,38 @@ Now we'll explain when each method is called and how many that method called for
`init_training_weights_repr` is called when we initialize the cloned training model from the pre-trained model. The `optimize_training` method basically clones the model to create a training model for compression, wrapping compressible layers with the training wrapper to create training weights. The number of calls is (# of compressible weights).
-`fake_decompress` is called while the training model for the compression algorithm is training. Usually this method is part of the training model. It recovers the original weight from the training weights, and should be differentiable. This method enables you to use the original graph to compute the model output, but train the training weights of the training model. For each training step, this method is called for every compressible weight. The number of calls is (# of compressible weights) * (training steps).
+`project_training_weights` is called while the training model for the compression algorithm is training. Usually this method is part of the training model. It recovers the original weight from the training weights, and should be differentiable. This method enables you to use the original graph to compute the model output, but train the training weights of the training model. For each training step, this method is called for every compressible weight. The number of calls is (# of compressible weights) * (training steps).
-`compress` is called when we convert the training model to the compressed model. The number of calls is (# of compressible weights).
+`compress_training_weights` is called when we convert the training model to the compressed model. The number of calls is (# of compressible weights).
-`decompress` is called when we do inference on a compressed model. Usually this method is part of the compressed model. For each compressible weight, it decompresses the weight into a form that can be used on the original graph. The number of calls is basically (# of compressible weights) * (# of inferences). To improve performance, the output value of this method can be cached.
+`decompress_weights` is called when we do inference on a compressed model. Usually this method is part of the compressed model. For each compressible weight, it decompresses the weight into a form that can be used on the original graph. The number of calls is basically (# of compressible weights) * (# of inferences). To improve performance, the output value of this method can be cached.
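
The call counts listed above can be written down as simple arithmetic. The helper below is a hypothetical sketch, not part of the proposed API; the function name, its parameters, and the assumption that caching reduces `decompress_weights` to one call per weight are all illustrative.

```python
def expected_call_counts(num_compressible_weights, training_steps,
                         inference_steps, cache_decompressed=False):
    """Return the expected number of calls per method, following the text above."""
    n = num_compressible_weights
    # Assumption: with caching enabled, each weight is decompressed at most once.
    if cache_decompressed:
        decompress_rounds = min(inference_steps, 1)
    else:
        decompress_rounds = inference_steps
    return {
        "init_training_weights_repr": n,
        "project_training_weights": n * training_steps,
        "compress_training_weights": n,
        "decompress_weights": n * decompress_rounds,
    }


# Example: 4 compressible weights, 1000 training steps, 50 inference calls.
counts = expected_call_counts(4, 1000, 50)
cached = expected_call_counts(4, 1000, 50, cache_decompressed=True)
```

In this example `project_training_weights` dominates the cost during training, which is why it runs inside the training graph, while caching would bring `decompress_weights` down from 200 calls to 4.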
## Questions and Discussion Topics
@@ -443,3 +443,4 @@ Note that every trainable variable that they want to train should be in training
It's not easy to find the bug there. Usually we get TensorFlow error messages with huge stack traces. We have to provide some error messages for this API layer.