Skip to content
This repository was archived by the owner on Jul 10, 2025. It is now read-only.

Commit 248f918

Browse files
committed
Update tfmot compression api RFC: change method naming & way to initialize the training weights.
1 parent e286275 commit 248f918

File tree

3 files changed

+28
-27
lines changed

3 files changed

+28
-27
lines changed

rfcs/20201221-tfmot-compression-api.md

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ We provide the tutorial for [SVD](https://en.wikipedia.org/wiki/Singular_value_d
5858
def __init__(self, params):
5959
self.params = params
6060

61-
def init_training_weights_repr(
62-
self, pretrained_weight: tf.Tensor) -> List[algorithm.WeightRepr]:
61+
def init_training_weights(
62+
self, pretrained_weight: tf.Tensor):
6363
"""Init function from pre-trained model case."""
6464
rank = self.params.rank
6565

@@ -69,18 +69,16 @@ We provide the tutorial for [SVD](https://en.wikipedia.org/wiki/Singular_value_d
6969
else:
7070
raise NotImplementedError('Only for dimension=2 is supported.')
7171

72-
return [
73-
algorithm.WeightRepr(
74-
name='u',
75-
shape=u.shape,
76-
dtype=u.dtype,
77-
initializer=tf.keras.initializers.Constant(u)),
78-
algorithm.WeightRepr(
79-
name='sv',
80-
shape=sv.shape,
81-
dtype=sv.dtype,
82-
initializer=tf.keras.initializers.Constant(sv))
83-
]
72+
self.add_training_weight(
73+
name='u',
74+
shape=u.shape,
75+
dtype=u.dtype,
76+
initializer=tf.keras.initializers.Constant(u)),
77+
self.add_training_weight(
78+
name='sv',
79+
shape=sv.shape,
80+
dtype=sv.dtype,
81+
initializer=tf.keras.initializers.Constant(sv))
8482

8583
def project_training_weights(self, u: tf.Tensor, sv: tf.Tensor) -> tf.Tensor:
8684
return tf.matmul(u, sv)
@@ -158,7 +156,7 @@ We also want to provide an example of well-known compression algorithms. Here’
158156

159157
This is an API for a layer weight based compression algorithm.
160158

161-
First, we start from a pre-trained model which the model developer has. And then convert the pre-trained model to training phase model for compression fine-tuning training. During the convert to training phase model, We call `init_training_weights_repr` for each tensor that we want to compress which is specified from the `get_compressible_weights` method.
159+
First, we start from a pre-trained model which the model developer has. And then convert the pre-trained model to training phase model for compression fine-tuning training. During the convert to training phase model, We call `init_training_weights` for each tensor that we want to compress which is specified from the `get_compressible_weights` method.
162160

163161
During the training phase, `project_training_weights` method is called for each training step. After fine-tuning training for compression is finished, we convert the training phase model to a compressed model. We only call the `compress_training_weights` function once for each compressible tensor for converting.
164162

@@ -191,17 +189,21 @@ class WeightCompressionAlgorithm(metaclass=abc.ABCMeta):
191189
"""
192190

193191
@abc.abstractmethod
194-
def init_training_weights_repr(
195-
self, pretrained_weight: tf.Tensor) -> List[WeightRepr]:
196-
"""Create training weight representations for initializing layer variables.
192+
def init_training_weights(
193+
self, pretrained_weight: tf.Tensor):
194+
"""Initialize training weights for the training model. It calls the `add_training_weight` method several times to add training weights.
197195
198196
Args:
199197
pretrained_weight: tf.Tensor of a pretrained weight of a layer that will
200198
be compressed eventually.
199+
"""
201200

202-
Returns:
203-
A list of `WeightRepr`, a container for arguments to
204-
`tf.keras.layers.Layer.add_weight`for each tf.Variable to create.
201+
def add_training_weight(
202+
self, *args, **kwargs):
203+
"""Add training weight for the training model. This method is called from `init_training_weights`.
204+
205+
Args:
206+
*args, **kwargs: args and kwargs for training_model.add_weight.
205207
"""
206208

207209
@abc.abstractmethod
@@ -212,7 +214,7 @@ class WeightCompressionAlgorithm(metaclass=abc.ABCMeta):
212214
Args:
213215
*training_weights: tf.Tensors representing any variables used during
214216
training, for a single compressible weight, in the order returned in
215-
`init_training_weights_repr`.
217+
`init_training_weights`.
216218
217219
Returns:
218220
tf.Tensor to set the compressible weight to.
@@ -242,7 +244,7 @@ class WeightCompressionAlgorithm(metaclass=abc.ABCMeta):
242244
Args:
243245
*training_weights: tf.Tensors representing all variables used during
244246
training, for a single compressible weight, in the order returned in
245-
`init_training_weights_repr`.
247+
`init_training_weights`.
246248
247249
Returns:
248250
List of tf.Tensors to set to compressed or more compressible form.
@@ -384,16 +386,16 @@ Now we'll explain when each method is called and how many that method called for
384386
`get_compressible_weights` is called when we want to get a list of variables that we will apply compression.
385387
When we try to compress the pre-trained model, we just call this method for each layer in the pre-trained model. The number of the method calling is (# of layers).
386388

387-
1. `init_training_weights_repr`
389+
1. `init_training_weights`
388390
<p align="center">
389-
<img src=20201221-tfmot-compression-api/init_training_weights_repr.png />
391+
<img src=20201221-tfmot-compression-api/init_training_weights.png />
390392
</p>
391393

392394
```python
393395
training_model = optimize_training(model, params)
394396
```
395397

396-
`init_training_weights_repr` is called when we initialize the cloned training model from the pre-trained model. `optimize_training` method basically clones the model to create a training model for compression, wrapping compressible layers by the training wrapper to create training weights. The number of the method calling is (# of compressible weights).
398+
`init_training_weights` is called when we initialize the cloned training model from the pre-trained model. `optimize_training` method basically clones the model to create a training model for compression, wrapping compressible layers by the training wrapper to create training weights. The number of the method calling is (# of compressible weights).
397399

398400
1. `project_training_weights`
399401
<p align="center">
@@ -443,4 +445,3 @@ Note that every trainable variable that they want to train should be in training
443445

444446
It's not easy to find the bug there. Usually we get tensorflow bug messages with huge stack traces. We have to provide some bug messages for this API layer.
445447

446-
40.8 KB
Loading
-40.2 KB
Binary file not shown.

0 commit comments

Comments
 (0)