2424
2525@dataclasses .dataclass
2626class WeightRepr :
27- """Dataclass that wraps `tf.keras.layers.Layer.add_weight` parameters."""
28- name : Any = None
29- shape : Any = None
30- dtype : Any = None
31- initializer : Any = None
32- regularizer : Any = None
33- trainable : Any = None
34- constraint : Any = None
35- partitioner : Any = None
36- use_resource : Any = None
37- synchronization : Any = tf .VariableSynchronization .AUTO
38- aggregation : Any = tf .compat .v1 .VariableAggregation .NONE
39-
40-
41- class WeightCompressionAlgorithm (metaclass = abc .ABCMeta ):
27+ args : Any = None
28+ kwargs : Any = None
29+
30+
31+ class WeightCompressor (metaclass = abc .ABCMeta ):
4232 """Interface for weight compression algorithm that acts on a per-layer basis.
4333
4434 This allows both options of either decompressing during inference or
@@ -48,93 +38,127 @@ class WeightCompressionAlgorithm(metaclass=abc.ABCMeta):
4838 This interface is a purely functional one.
4939 """
5040
51- @ abc . abstractmethod
52- def init_training_weights_repr (
53- self , pretrained_weight : tf .Tensor ) -> List [WeightRepr ]:
54- """Create training weight representations for initializing layer variables .
41+ # TODO(tfmot): Consider separating this from the algorithm API to support custom layers.
42+ def get_compressible_weights (
43+ self , original_layer : tf .keras . layers . Layer ) -> List [tf . Variable ]:
44+ """Define compressible weights for each layer.
5545
5646 Args:
57- pretrained_weight : tf.Tensor of a pretrained weight of a layer that will
58- be compressed eventually .
47+ original_layer: tf.keras.layers.Layer representing a layer from the
48+ original model.
5949
6050 Returns:
61- A list of `WeightRepr`, a container for arguments to
62- `tf.keras.layers.Layer.add_weight`for each tf.Variable to create.
51+ List of compressible weights for the given layer.
6352 """
53+ del original_layer
54+ return []
6455
65- def compress (self , * training_weights : tf .Tensor ) -> List [tf .Tensor ]:
66- """Define the operations to compress a single weight after training.
67-
68- 'Compress' can refer to making the weight more amenable to compression
69- or actually compress the weight.
56+ @abc .abstractmethod
57+ def init_training_weights (
58+ self , pretrained_weight : tf .Tensor ):
59+ """Initialize training weights for the compressible weight.
7060
71- The default is an identity.
61+ It calls the `add_training_weight` to add a training weight for a given
62+ `pretrained_weight`. A `pretrained_weight` can have multiple training
63+ weights. We initialize the training weights for each compressible
64+ weight by just calling this function for each.
7265
7366 Args:
74- *training_weights: tf.Tensors representing all variables used during
75- training, for a single compressible weight, in the order returned in
76- `init_training_weights_repr`.
77-
78- Returns:
79- List of tf.Tensors to set to compressed or more compressible form.
67+ pretrained_weight: tf.Tensor of a pretrained weight of a layer that will
68+ be compressed eventually.
8069 """
81- return list (training_weights )
8270
83- def decompress (self , * compressed_weights : tf .Tensor ) -> tf .Tensor :
84- """Define the operations to decompress a single weight’s compressed form during inference.
71+ def add_training_weight (
72+ self , * args , ** kwargs ):
73+ """Add a training weight for the compressible weight.
8574
86- The default is an identity. TODO(): actually isn't.
75+ When this method is called from the `init_training_weights`, this adds
76+ training weights for the pretrained_weight that is the input of the
77+ `init_training_weights`.
8778
8879 Args:
89- *compressed_weights: tf.Tensors representing a single weight’s compressed
90- form, coming from what’s returned in `compress`.
91-
92- Returns:
93- A tf.Tensor representing the decompressed `compressed_weights`.
80+ *args: Passed through to training_model.add_weight.
81+ **kwargs: Passed through to training_model.add_weight.
9482 """
95- return compressed_weights [0 ]
83+ weight_repr = WeightRepr (args = args , kwargs = kwargs )
84+ if hasattr (self , 'weight_reprs' ):
85+ self .weight_reprs .append (weight_repr )
86+ else :
87+ self .weight_reprs = [weight_repr ]
9688
9789 @abc .abstractmethod
98- def training (self , * training_weights : tf .Tensor ) -> tf .Tensor :
99- """Define a piece of the forward pass during training, which operates on a single compressible weight.
90+ def project_training_weights (
91+ self , * training_weights : tf .Tensor ) -> tf .Tensor :
92+ """Define a piece of the forward pass during training.
10093
101- TODO(tfmot): throw this error .
94+ It operates on a single compressible weight.
10295 The default throws an error when training occurs.
10396
10497 Args:
10598 *training_weights: tf.Tensors representing any variables used during
10699 training, for a single compressible weight, in the order returned in
107- `init_training_weights_repr `.
100+ `init_training_weights`.
108101
109102 Returns:
110103 tf.Tensor to set the compressible weight to.
111104 """
112105
113- # TODO(tfmot): Consider separate from algorithm API for custom layer supports.
114- def get_compressible_weights (
115- self , original_layer : tf .keras .layers .Layer ) -> List [str ]:
116- """Define compressible weights for each layer.
106+ def update_training_weight (
107+ self , training_weight : tf .Tensor , tensor : tf .Tensor ):
108+ """Update a training weight to a given tensor value.
109+
110+ This method is for cases where a training weight should be set to a
111+ specific value rather than updated by the model optimizer. It will throw an
112+ error if it can't find the training weight.
117113
118114 Args:
119- original_layer: tf.keras.layers.Layer representing a layer from the
120- original model.
115+ training_weight: tf.Tensor representing a training weight.
116+ tensor: tf.Tensor representing a value to be assigned to the training
117+ weight.
118+ """
119+
120+ def compress_training_weights (
121+ self , * training_weights : tf .Tensor ) -> List [tf .Tensor ]:
122+ """Define the operations to compress a single weight’s training form.
123+
124+ 'compress_training_weights' can refer to making the weight more amenable to
125+ compression or actually compress the weight.
126+
127+ The default is an identity.
128+
129+ Args:
130+ *training_weights: tf.Tensors representing all variables used during
131+ training, for a single compressible weight, in the order returned in
132+ `init_training_weights`.
121133
122134 Returns:
123- List of atrribute names as string representing list of compressible
124- weights for the given layer. (e.g. return value ['kernel'] means
125- layer.kernel is compressible.)
135+ List of tf.Tensors to set to compressed or more compressible form.
136+ """
137+ return list (training_weights )
138+
139+ @abc .abstractmethod
140+ def decompress_weights (
141+ self , * compressed_weights : tf .Tensor ) -> tf .Tensor :
142+ """Define the operations to decompress a single weight’s compressed form.
143+
144+ This method is abstract, so there is no default; subclasses must implement it.
145+
146+ Args:
147+ *compressed_weights: tf.Tensors representing a single weight’s compressed
148+ form, coming from what’s returned in `compress_training_weights`.
149+
150+ Returns:
151+ A tf.Tensor representing the decompressed `compressed_weights`.
126152 """
127- del original_layer
128- return []
129153
130154
131155def create_layer_for_training (
132156 layer : tf .keras .layers .Layer ,
133- algorithm : WeightCompressionAlgorithm ) -> tf .keras .layers .Layer :
157+ algorithm : WeightCompressor ) -> tf .keras .layers .Layer :
134158 return optimize .create_layer_for_training (layer , algorithm )
135159
136160
137161def create_layer_for_inference (
138162 layer_for_training : tf .keras .layers .Layer ,
139- algorithm : WeightCompressionAlgorithm ) -> tf .keras .layers .Layer :
163+ algorithm : WeightCompressor ) -> tf .keras .layers .Layer :
140164 return optimize .create_layer_for_inference (layer_for_training , algorithm )
0 commit comments