This repository was archived by the owner on Jul 10, 2025. It is now read-only.

Commit e286275

Update tfmot compression api RFC: change method naming.
1 parent e68882d · commit e286275

File tree

8 files changed: +19 −18 lines


rfcs/20201221-tfmot-compression-api.md

Lines changed: 19 additions & 18 deletions
````diff
@@ -82,7 +82,7 @@ We provide the tutorial for [SVD](https://en.wikipedia.org/wiki/Singular_value_decomposition)
             initializer=tf.keras.initializers.Constant(sv))
     ]

-  def fake_decompress(self, u: tf.Tensor, sv: tf.Tensor) -> tf.Tensor:
+  def project_training_weights(self, u: tf.Tensor, sv: tf.Tensor) -> tf.Tensor:
     return tf.matmul(u, sv)

   def get_compressible_weights(
````
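
For context, a minimal sketch of the factorization behind the SVD tutorial fragment above, assuming `sv` stands for `diag(s) @ vᵀ`. The shapes and helper code are illustrative, not taken from the RFC's elided `init_training_weights_repr` implementation; the point is only that `tf.matmul(u, sv)` rebuilds a low-rank approximation of the original kernel, which is what the renamed `project_training_weights` returns during training:

```python
import tensorflow as tf

# Hypothetical 2-D kernel and rank; not from the RFC.
w = tf.random.normal([64, 32])
rank = 4

s, u, v = tf.linalg.svd(w)                        # w ≈ u @ diag(s) @ vᵀ
u_r = u[:, :rank]                                 # [64, rank]
sv_r = tf.matmul(tf.linalg.diag(s[:rank]),        # fold singular values into vᵀ
                 v[:, :rank], transpose_b=True)   # [rank, 32]

w_approx = tf.matmul(u_r, sv_r)                   # what project_training_weights(u, sv) computes
print(tf.norm(w - w_approx) / tf.norm(w))         # relative approximation error
```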
````diff
@@ -160,9 +160,9 @@ This is an API for a layer weight based compression algorithm.

 First, we start from a pre-trained model which the model developer has. And then convert the pre-trained model to training phase model for compression fine-tuning training. During the convert to training phase model, We call `init_training_weights_repr` for each tensor that we want to compress which is specified from the `get_compressible_weights` method.

-During the training phase, `fake_decompress` method is called for each training step. After fine-tuning training for compression is finished, we convert the training phase model to a compressed model. We only call the `compress` function once for each compressible tensor for converting.
+During the training phase, `project_training_weights` method is called for each training step. After fine-tuning training for compression is finished, we convert the training phase model to a compressed model. We only call the `compress_training_weights` function once for each compressible tensor for converting.

-Compressed model contains the `decompress` function in the graph. It’s possible to call the `decompress` for each inference step. To improve performance, we’ll cache the decompressed one depending on flags if we have enough space.
+Compressed model contains the `decompress_weights` function in the graph. It’s possible to call the `decompress_weights` for each inference step. To improve performance, we’ll cache the decompressed one depending on flags if we have enough space.

 ```python
 class WeightCompressionAlgorithm(metaclass=abc.ABCMeta):
````
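
To make the lifecycle described in the hunk above concrete, here is a hedged end-to-end sketch assembled from the snippets that appear later in this diff. `optimize_training` and `optimize_inference` are the conversion entry points the RFC refers to; their exact signatures, and names such as `algorithm`, `params`, and `pretrained_model`, are assumptions for illustration rather than the RFC's final API:

```python
# Hypothetical driver; signatures and argument names are assumptions.
algorithm = MyWeightCompressionAlgorithm()      # a concrete WeightCompressionAlgorithm subclass

# Clone the pre-trained model into its training form;
# init_training_weights_repr runs once per compressible weight here.
training_model = optimize_training(pretrained_model, algorithm)

# project_training_weights runs for every compressible weight on every training step.
training_model.fit(x_train, y_train, epochs=2)

# compress_training_weights runs once per compressible weight during conversion.
compressed_model = optimize_inference(training_model, params)

# decompress_weights runs at inference time, optionally cached.
compressed_model.evaluate(x_test, y_test, verbose=2)
```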
````diff
@@ -205,7 +205,7 @@ class WeightCompressionAlgorithm(metaclass=abc.ABCMeta):
     """

   @abc.abstractmethod
-  def fake_decompress(self, *training_weights: tf.Tensor) -> tf.Tensor:
+  def project_training_weights(self, *training_weights: tf.Tensor) -> tf.Tensor:
     """Define a piece of the forward pass during training, which operates on a single compressible weight.
     The default throws an error when training occurs.

````
````diff
@@ -218,23 +218,23 @@ class WeightCompressionAlgorithm(metaclass=abc.ABCMeta):
       tf.Tensor to set the compressible weight to.
     """

-  def update_training_weights(self, index, tensor: tf.Tensor):
+  def update_training_weight(self, index: integer, tensor: tf.Tensor):
     """Update a training weight on an index to a given tensor value.

     This method is for the case that training weight should update to specific
     value not from the model optimizer. It'll throws an error if it can't
     find the training weight.

     Args:
-      index: integer indicate index of training weight to update.
+      index: integer indicates index of training weight to update.
       tensor: tf.Tensor to update specific training weight.
     """

   @abc.abstractmethod
-  def compress(self, *training_weights: tf.Tensor) -> List[tf.Tensor]:
+  def compress_training_weights(self, *training_weights: tf.Tensor) -> List[tf.Tensor]:
     """Define the operations to compress a single weight’s training form after training.

-    'Compress' can refer to making the weight more amenable to compression
+    'compress_training_weights' can refer to making the weight more amenable to compression
     or actually compress the weight.

     The default is an identity.
````
````diff
@@ -249,7 +249,7 @@ class WeightCompressionAlgorithm(metaclass=abc.ABCMeta):
     """

   @abc.abstractmethod
-  def decompress(self, *compressed_weights: tf.Tensor) -> tf.Tensor:
+  def decompress_weights(self, *compressed_weights: tf.Tensor) -> tf.Tensor:
     """Define the operations to decompress a single weight’s compressed form during inference.

     The default is an identity.
````
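
As a concrete illustration of the renamed hooks, here is a minimal sketch of a subclass in the spirit of the RFC's SVD tutorial. The class name, constructor, and method bodies are assumptions; only the method names and rough signatures come from the diff above, and the weight-initialization and `get_compressible_weights` parts are omitted because their full signatures are not shown in these hunks:

```python
from typing import List

import tensorflow as tf


class SVDCompression(WeightCompressionAlgorithm):  # base class as defined in the RFC above
  """Hypothetical SVD-based algorithm written against the renamed API."""

  def __init__(self, rank: int):
    self.rank = rank

  def project_training_weights(self, u: tf.Tensor, sv: tf.Tensor) -> tf.Tensor:
    # During training, rebuild a full-size (approximate) kernel for the forward pass.
    return tf.matmul(u, sv)

  def compress_training_weights(self, u: tf.Tensor, sv: tf.Tensor) -> List[tf.Tensor]:
    # After fine-tuning, the factors themselves are the compressed representation.
    return [u, sv]

  def decompress_weights(self, u: tf.Tensor, sv: tf.Tensor) -> tf.Tensor:
    # At inference time, recover the kernel the original graph expects.
    return tf.matmul(u, sv)
```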
````diff
@@ -395,38 +395,38 @@ Now we'll explain when each method is called and how many that method called for

 `init_training_weights_repr` is called when we initialize the cloned training model from the pre-trained model. `optimize_training` method basically clones the model to create a training model for compression, wrapping compressible layers by the training wrapper to create training weights. The number of the method calling is (# of compressible weights).

-1. `fake_decompress`
+1. `project_training_weights`
 <p align="center">
-<img src=20201221-tfmot-compression-api/fake_decompress.png />
+<img src=20201221-tfmot-compression-api/project_training_weights.png />
 </p>

 ```python
 training_model.fit(x_train, y_train, epochs=2)
 ```

-`fake_decompress` is called when the training model for the compression algorithm is training. Usually this method function is a part of the training model. It recovers the original weight from the training weights, and should be differentiable. This method enables you to use the original graph to compute the model output, but train the training weights of the training model. For each training step, this method is called for every compressible weight. The number of the method calling is (# of compressible weights) * (training steps).
+`project_training_weights` is called when the training model for the compression algorithm is training. Usually this method function is a part of the training model. It recovers the original weight from the training weights, and should be differentiable. This method enables you to use the original graph to compute the model output, but train the training weights of the training model. For each training step, this method is called for every compressible weight. The number of the method calling is (# of compressible weights) * (training steps).

-1. `compress`
+1. `compress_training_weights`
 <p align="center">
-<img src=20201221-tfmot-compression-api/compress.png />
+<img src=20201221-tfmot-compression-api/compress_training_weights.png />
 </p>

 ```python
 compressed_model = optimize_inference(training_model, params)
 ```

-`compress` is called when we convert the training model to the compressed model. The number of the method calling is (# of compressible weights).
+`compress_training_weights` is called when we convert the training model to the compressed model. The number of the method calling is (# of compressible weights).

-1. `decompress`
+1. `decompress_weights`
 <p align="center">
-<img src=20201221-tfmot-compression-api/decompress.png />
+<img src=20201221-tfmot-compression-api/decompress_weights.png />
 </p>

 ```python
 compressed_model.evaluate(x_test, y_test, verbose=2)
 ```

-`decompress` is called when we do inference on a compressed model. Usually this method function is a part of a compressed model. This method decompresses the weight that can be used on the original graph for each compressible weight. Basically the number of this method called is (# of compressible weights) * (# of inference). To improve performance, the output value of this method can be cached.
+`decompress_weights` is called when we do inference on a compressed model. Usually this method function is a part of a compressed model. This method decompresses the weight that can be used on the original graph for each compressible weight. Basically the number of this method called is (# of compressible weights) * (# of inference). To improve performance, the output value of this method can be cached.

 ## Questions and Discussion Topics
````
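The call counts stated in the walkthrough above can be summarized with a small back-of-the-envelope script. The concrete numbers here are hypothetical; only the (# of compressible weights), (training steps), and (# of inference) factors come from the RFC text:

```python
# Hypothetical sizes; only the formulas come from the walkthrough above.
num_compressible_weights = 4
training_steps = 1_000
inference_runs = 100

calls = {
    "init_training_weights_repr": num_compressible_weights,                 # once per weight, converting to the training form
    "project_training_weights": num_compressible_weights * training_steps,  # every training step
    "compress_training_weights": num_compressible_weights,                  # once per weight, converting to the compressed form
    "decompress_weights": num_compressible_weights * inference_runs,        # per inference, unless cached
}

for name, count in calls.items():
  print(f"{name}: {count}")
```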

````diff
@@ -443,3 +443,4 @@ Note that every trainable variable that they want to train should be in training

 It's not easy to find the bug there. Usually we get tensorflow bug messages with huge stack traces. We have to provide some bug messages for this API layer.

+
````
The remaining 7 changed files are binary PNG diagrams (the renamed figures referenced in the hunks above); binary contents are not shown.
