24
24
25
25
@dataclasses .dataclass
26
26
class WeightRepr :
27
- """Dataclass that wraps `tf.keras.layers.Layer.add_weight` parameters."""
28
- name : Any = None
29
- shape : Any = None
30
- dtype : Any = None
31
- initializer : Any = None
32
- regularizer : Any = None
33
- trainable : Any = None
34
- constraint : Any = None
35
- partitioner : Any = None
36
- use_resource : Any = None
37
- synchronization : Any = tf .VariableSynchronization .AUTO
38
- aggregation : Any = tf .compat .v1 .VariableAggregation .NONE
39
-
40
-
41
- class WeightCompressionAlgorithm (metaclass = abc .ABCMeta ):
27
+ args : Any = None
28
+ kwargs : Any = None
29
+
30
+
31
+ class WeightCompressor (metaclass = abc .ABCMeta ):
42
32
"""Interface for weight compression algorithm that acts on a per-layer basis.
43
33
44
34
This allows both options of either decompressing during inference or
@@ -48,93 +38,127 @@ class WeightCompressionAlgorithm(metaclass=abc.ABCMeta):
48
38
This interface is a purely functional one.
49
39
"""
50
40
51
- @ abc . abstractmethod
52
- def init_training_weights_repr (
53
- self , pretrained_weight : tf .Tensor ) -> List [WeightRepr ]:
54
- """Create training weight representations for initializing layer variables .
41
+ # TODO(tfmot): Consider separate from algorithm API for custom layer supports.
42
+ def get_compressible_weights (
43
+ self , original_layer : tf .keras . layers . Layer ) -> List [tf . Variable ]:
44
+ """Define compressible weights for each layer.
55
45
56
46
Args:
57
- pretrained_weight : tf.Tensor of a pretrained weight of a layer that will
58
- be compressed eventually .
47
+ original_layer : tf.keras.layers.Layer representing a layer from the
48
+ original model .
59
49
60
50
Returns:
61
- A list of `WeightRepr`, a container for arguments to
62
- `tf.keras.layers.Layer.add_weight`for each tf.Variable to create.
51
+ List of compressible weights for the given layer.
63
52
"""
53
+ del original_layer
54
+ return []
64
55
65
- def compress (self , * training_weights : tf .Tensor ) -> List [tf .Tensor ]:
66
- """Define the operations to compress a single weight after training.
67
-
68
- 'Compress' can refer to making the weight more amenable to compression
69
- or actually compress the weight.
56
+ @abc .abstractmethod
57
+ def init_training_weights (
58
+ self , pretrained_weight : tf .Tensor ):
59
+ """Initialize training weights for the compressible weight.
70
60
71
- The default is an identity.
61
+ It calls the `add_training_weight` to add a training weight for a given
62
+ `pretrained_weight`. A `pretrained_weight` can have multiple training
63
+ weights. We initialize the training weights for each compressible
64
+ weight by just calling this function for each.
72
65
73
66
Args:
74
- *training_weights: tf.Tensors representing all variables used during
75
- training, for a single compressible weight, in the order returned in
76
- `init_training_weights_repr`.
77
-
78
- Returns:
79
- List of tf.Tensors to set to compressed or more compressible form.
67
+ pretrained_weight: tf.Tensor of a pretrained weight of a layer that will
68
+ be compressed eventually.
80
69
"""
81
- return list (training_weights )
82
70
83
- def decompress (self , * compressed_weights : tf .Tensor ) -> tf .Tensor :
84
- """Define the operations to decompress a single weight’s compressed form during inference.
71
+ def add_training_weight (
72
+ self , * args , ** kwargs ):
73
+ """Add a training weight for the compressible weight.
85
74
86
- The default is an identity. TODO(): actually isn't.
75
+ When this method is called from the `init_training_weights`, this adds
76
+ training weights for the pretrained_weight that is the input of the
77
+ `init_training_weights`.
87
78
88
79
Args:
89
- *compressed_weights: tf.Tensors representing a single weight’s compressed
90
- form, coming from what’s returned in `compress`.
91
-
92
- Returns:
93
- A tf.Tensor representing the decompressed `compressed_weights`.
80
+ *args: Passed through to training_model.add_weight.
81
+ **kwargs: Passed through to training_model.add_weight.
94
82
"""
95
- return compressed_weights [0 ]
83
+ weight_repr = WeightRepr (args = args , kwargs = kwargs )
84
+ if hasattr (self , 'weight_reprs' ):
85
+ self .weight_reprs .append (weight_repr )
86
+ else :
87
+ self .weight_reprs = [weight_repr ]
96
88
97
89
@abc .abstractmethod
98
- def training (self , * training_weights : tf .Tensor ) -> tf .Tensor :
99
- """Define a piece of the forward pass during training, which operates on a single compressible weight.
90
+ def project_training_weights (
91
+ self , * training_weights : tf .Tensor ) -> tf .Tensor :
92
+ """Define a piece of the forward pass during training.
100
93
101
- TODO(tfmot): throw this error .
94
+ It operates on a single compressible weight .
102
95
The default throws an error when training occurs.
103
96
104
97
Args:
105
98
*training_weights: tf.Tensors representing any variables used during
106
99
training, for a single compressible weight, in the order returned in
107
- `init_training_weights_repr `.
100
+ `init_training_weights `.
108
101
109
102
Returns:
110
103
tf.Tensor to set the compressible weight to.
111
104
"""
112
105
113
- # TODO(tfmot): Consider separate from algorithm API for custom layer supports.
114
- def get_compressible_weights (
115
- self , original_layer : tf .keras .layers .Layer ) -> List [str ]:
116
- """Define compressible weights for each layer.
106
+ def update_training_weight (
107
+ self , training_weight : tf .Tensor , tensor : tf .Tensor ):
108
+ """Update a training weight to a given tensor value.
109
+
110
+ This method is for the case that training weight should update to a specific
111
+ value not from the model optimizer. It will throw an error if it can't
112
+ find the training weight.
117
113
118
114
Args:
119
- original_layer: tf.keras.layers.Layer representing a layer from the
120
- original model.
115
+ training_weight: tf.Tensor representing a training weight.
116
+ tensor: tf.Tensor representing a value to be assigned to the training
117
+ weight.
118
+ """
119
+
120
+ def compress_training_weights (
121
+ self , * training_weights : tf .Tensor ) -> List [tf .Tensor ]:
122
+ """Define the operations to compress a single weight’s training form.
123
+
124
+ 'compress_training_weights' can refer to making the weight more amenable to
125
+ compression or actually compress the weight.
126
+
127
+ The default is an identity.
128
+
129
+ Args:
130
+ *training_weights: tf.Tensors representing all variables used during
131
+ training, for a single compressible weight, in the order returned in
132
+ `init_training_weights`.
121
133
122
134
Returns:
123
- List of atrribute names as string representing list of compressible
124
- weights for the given layer. (e.g. return value ['kernel'] means
125
- layer.kernel is compressible.)
135
+ List of tf.Tensors to set to compressed or more compressible form.
136
+ """
137
+ return list (training_weights )
138
+
139
+ @abc .abstractmethod
140
+ def decompress_weights (
141
+ self , * compressed_weights : tf .Tensor ) -> tf .Tensor :
142
+ """Define the operations to decompress a single weight’s compressed form.
143
+
144
+ The default is an identity.
145
+
146
+ Args:
147
+ *compressed_weights: tf.Tensors representing a single weight’s compressed
148
+ form, coming from what’s returned in `compress`.
149
+
150
+ Returns:
151
+ A tf.Tensor representing the decompressed `compressed_weights`.
126
152
"""
127
- del original_layer
128
- return []
129
153
130
154
131
155
def create_layer_for_training (
132
156
layer : tf .keras .layers .Layer ,
133
- algorithm : WeightCompressionAlgorithm ) -> tf .keras .layers .Layer :
157
+ algorithm : WeightCompressor ) -> tf .keras .layers .Layer :
134
158
return optimize .create_layer_for_training (layer , algorithm )
135
159
136
160
137
161
def create_layer_for_inference (
138
162
layer_for_training : tf .keras .layers .Layer ,
139
- algorithm : WeightCompressionAlgorithm ) -> tf .keras .layers .Layer :
163
+ algorithm : WeightCompressor ) -> tf .keras .layers .Layer :
140
164
return optimize .create_layer_for_inference (layer_for_training , algorithm )
0 commit comments