
Commit 3e7592f

Xhark authored and tensorflower-gardener committed
Update compression API core to match with the RFC (tensorflow/community#342).
PiperOrigin-RevId: 361948873
1 parent 392b0b3 commit 3e7592f

10 files changed: +363 -315 lines changed

tensorflow_model_optimization/python/core/common/keras/compression/algorithm.py

Lines changed: 86 additions & 62 deletions
@@ -24,21 +24,11 @@
 
 @dataclasses.dataclass
 class WeightRepr:
-  """Dataclass that wraps `tf.keras.layers.Layer.add_weight` parameters."""
-  name: Any = None
-  shape: Any = None
-  dtype: Any = None
-  initializer: Any = None
-  regularizer: Any = None
-  trainable: Any = None
-  constraint: Any = None
-  partitioner: Any = None
-  use_resource: Any = None
-  synchronization: Any = tf.VariableSynchronization.AUTO
-  aggregation: Any = tf.compat.v1.VariableAggregation.NONE
-
-
-class WeightCompressionAlgorithm(metaclass=abc.ABCMeta):
+  args: Any = None
+  kwargs: Any = None
+
+
+class WeightCompressor(metaclass=abc.ABCMeta):
   """Interface for weight compression algorithm that acts on a per-layer basis.
 
   This allows both options of either decompressing during inference or
@@ -48,93 +38,127 @@ class WeightCompressionAlgorithm(metaclass=abc.ABCMeta):
   This interface is a purely functional one.
   """
 
-  @abc.abstractmethod
-  def init_training_weights_repr(
-      self, pretrained_weight: tf.Tensor) -> List[WeightRepr]:
-    """Create training weight representations for initializing layer variables.
+  # TODO(tfmot): Consider separate from algorithm API for custom layer supports.
+  def get_compressible_weights(
+      self, original_layer: tf.keras.layers.Layer) -> List[tf.Variable]:
+    """Define compressible weights for each layer.
 
     Args:
-      pretrained_weight: tf.Tensor of a pretrained weight of a layer that will
-        be compressed eventually.
+      original_layer: tf.keras.layers.Layer representing a layer from the
+        original model.
 
     Returns:
-      A list of `WeightRepr`, a container for arguments to
-      `tf.keras.layers.Layer.add_weight`for each tf.Variable to create.
+      List of compressible weights for the given layer.
     """
+    del original_layer
+    return []
 
-  def compress(self, *training_weights: tf.Tensor) -> List[tf.Tensor]:
-    """Define the operations to compress a single weight after training.
-
-    'Compress' can refer to making the weight more amenable to compression
-    or actually compress the weight.
+  @abc.abstractmethod
+  def init_training_weights(
+      self, pretrained_weight: tf.Tensor):
+    """Initialize training weights for the compressible weight.
 
-    The default is an identity.
+    It calls the `add_training_weight` to add a training weight for a given
+    `pretrained_weight`. A `pretrained_weight` can have multiple training
+    weights. We initialize the training weights for each compressible
+    weight by just calling this function for each.
 
     Args:
-      *training_weights: tf.Tensors representing all variables used during
-        training, for a single compressible weight, in the order returned in
-        `init_training_weights_repr`.
-
-    Returns:
-      List of tf.Tensors to set to compressed or more compressible form.
+      pretrained_weight: tf.Tensor of a pretrained weight of a layer that will
+        be compressed eventually.
     """
-    return list(training_weights)
 
-  def decompress(self, *compressed_weights: tf.Tensor) -> tf.Tensor:
-    """Define the operations to decompress a single weight’s compressed form during inference.
+  def add_training_weight(
+      self, *args, **kwargs):
+    """Add a training weight for the compressible weight.
 
-    The default is an identity. TODO(): actually isn't.
+    When this method is called from the `init_training_weights`, this adds
+    training weights for the pretrained_weight that is the input of the
+    `init_training_weights`.
 
     Args:
-      *compressed_weights: tf.Tensors representing a single weight’s compressed
-        form, coming from what’s returned in `compress`.
-
-    Returns:
-      A tf.Tensor representing the decompressed `compressed_weights`.
+      *args: Passed through to training_model.add_weight.
+      **kwargs: Passed through to training_model.add_weight.
     """
-    return compressed_weights[0]
+    weight_repr = WeightRepr(args=args, kwargs=kwargs)
+    if hasattr(self, 'weight_reprs'):
+      self.weight_reprs.append(weight_repr)
+    else:
+      self.weight_reprs = [weight_repr]
 
   @abc.abstractmethod
-  def training(self, *training_weights: tf.Tensor) -> tf.Tensor:
-    """Define a piece of the forward pass during training, which operates on a single compressible weight.
+  def project_training_weights(
+      self, *training_weights: tf.Tensor) -> tf.Tensor:
+    """Define a piece of the forward pass during training.
 
-    TODO(tfmot): throw this error.
+    It operates on a single compressible weight.
     The default throws an error when training occurs.
 
     Args:
       *training_weights: tf.Tensors representing any variables used during
        training, for a single compressible weight, in the order returned in
-        `init_training_weights_repr`.
+        `init_training_weights`.
 
     Returns:
       tf.Tensor to set the compressible weight to.
     """
 
-  # TODO(tfmot): Consider separate from algorithm API for custom layer supports.
-  def get_compressible_weights(
-      self, original_layer: tf.keras.layers.Layer) -> List[str]:
-    """Define compressible weights for each layer.
+  def update_training_weight(
+      self, training_weight: tf.Tensor, tensor: tf.Tensor):
+    """Update a training weight to a given tensor value.
+
+    This method is for the case that training weight should update to a specific
+    value not from the model optimizer. It will throw an error if it can't
+    find the training weight.
 
     Args:
-      original_layer: tf.keras.layers.Layer representing a layer from the
-        original model.
+      training_weight: tf.Tensor representing a training weight.
+      tensor: tf.Tensor representing a value to be assigned to the training
+        weight.
+    """
+
+  def compress_training_weights(
+      self, *training_weights: tf.Tensor) -> List[tf.Tensor]:
+    """Define the operations to compress a single weight’s training form.
+
+    'compress_training_weights' can refer to making the weight more amenable to
+    compression or actually compress the weight.
+
+    The default is an identity.
+
+    Args:
+      *training_weights: tf.Tensors representing all variables used during
+        training, for a single compressible weight, in the order returned in
+        `init_training_weights`.
 
     Returns:
-      List of atrribute names as string representing list of compressible
-      weights for the given layer. (e.g. return value ['kernel'] means
-      layer.kernel is compressible.)
+      List of tf.Tensors to set to compressed or more compressible form.
+    """
+    return list(training_weights)
+
+  @abc.abstractmethod
+  def decompress_weights(
+      self, *compressed_weights: tf.Tensor) -> tf.Tensor:
+    """Define the operations to decompress a single weight’s compressed form.
+
+    The default is an identity.
+
+    Args:
+      *compressed_weights: tf.Tensors representing a single weight’s compressed
+        form, coming from what’s returned in `compress`.
+
+    Returns:
+      A tf.Tensor representing the decompressed `compressed_weights`.
     """
-    del original_layer
-    return []
 
 
 def create_layer_for_training(
     layer: tf.keras.layers.Layer,
-    algorithm: WeightCompressionAlgorithm) -> tf.keras.layers.Layer:
+    algorithm: WeightCompressor) -> tf.keras.layers.Layer:
   return optimize.create_layer_for_training(layer, algorithm)
 
 
 def create_layer_for_inference(
     layer_for_training: tf.keras.layers.Layer,
-    algorithm: WeightCompressionAlgorithm) -> tf.keras.layers.Layer:
+    algorithm: WeightCompressor) -> tf.keras.layers.Layer:
   return optimize.create_layer_for_inference(layer_for_training, algorithm)
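
To make the renamed `WeightCompressor` interface concrete, here is a minimal sketch of a custom algorithm subclassing it. This example is not part of the commit: the low-rank SVD scheme, the `LowRankSVD` class name, and the `rank` parameter are illustrative assumptions; only the overridden method names and the `add_training_weight` call come from the API above.

from typing import List

import tensorflow as tf

from tensorflow_model_optimization.python.core.common.keras.compression import algorithm


class LowRankSVD(algorithm.WeightCompressor):
  """Hypothetical low-rank compressor targeting Dense kernels."""

  def __init__(self, rank: int = 4):
    self.rank = rank

  def get_compressible_weights(
      self, original_layer: tf.keras.layers.Layer) -> List[tf.Variable]:
    if isinstance(original_layer, tf.keras.layers.Dense):
      return [original_layer.kernel]
    return []

  def init_training_weights(self, pretrained_weight: tf.Tensor):
    # One compressible weight becomes two training weights (U*S and V).
    s, u, v = tf.linalg.svd(pretrained_weight)
    u_r = u[:, :self.rank] * s[:self.rank]
    v_r = v[:, :self.rank]
    self.add_training_weight(
        name='u', shape=u_r.shape, dtype=u_r.dtype,
        initializer=tf.keras.initializers.Constant(u_r))
    self.add_training_weight(
        name='v', shape=v_r.shape, dtype=v_r.dtype,
        initializer=tf.keras.initializers.Constant(v_r))

  def project_training_weights(self, u: tf.Tensor, v: tf.Tensor) -> tf.Tensor:
    # Training-time forward pass reconstructs the full kernel.
    return tf.matmul(u, v, transpose_b=True)

  def compress_training_weights(
      self, u: tf.Tensor, v: tf.Tensor) -> List[tf.Tensor]:
    # Already in compressed (factored) form.
    return [u, v]

  def decompress_weights(self, u: tf.Tensor, v: tf.Tensor) -> tf.Tensor:
    # Inference-time reconstruction from the factored form.
    return tf.matmul(u, v, transpose_b=True)

Under the new interface, `init_training_weights`, `project_training_weights`, and `decompress_weights` are abstract, while `get_compressible_weights` and `compress_training_weights` ship with defaults (no compressible weights, identity).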

tensorflow_model_optimization/python/core/common/keras/compression/algorithms/bias_only.py

Lines changed: 35 additions & 29 deletions
@@ -23,58 +23,64 @@
 # TODO(tfmot): This algorithm is showcase for bias only compression. if we find
 # better algorithm that can show better compressible weights coverage, then
 # we can remove this algorithm.
-class BiasOnly(algorithm.WeightCompressionAlgorithm):
+class BiasOnly(algorithm.WeightCompressor):
   """Define how to apply BiasOnly algorithm."""
 
   # TODO(tfmot): communicate that `pretrained_weight` will sometimes
   # be a dummy tensor and sometimes be actual pretrained values during
   # its actual usage.
-  def init_training_weights_repr(
-      self, pretrained_weight: tf.Tensor) -> List[algorithm.WeightRepr]:
+  def init_training_weights(
+      self, pretrained_weight: tf.Tensor):
     bias_mean = tf.reduce_mean(pretrained_weight)
     bias_shape = tf.shape(pretrained_weight)
 
     # TODO(tfmot): note that it does not suffice to just have the initializer
     # to derive the shape from, in the case of a constant initializer.
     # The unit test fail without providing the shape.
-    return [
-        algorithm.WeightRepr(
-            name='bias_mean',
-            shape=(),
-            initializer=tf.keras.initializers.Constant(bias_mean)),
-        algorithm.WeightRepr(
-            name='bias_shape',
-            shape=bias_shape.shape,
-            dtype=bias_shape.dtype,
-            initializer=tf.keras.initializers.Constant(bias_shape))
-    ]
+    self.add_training_weight(
+        name='bias_mean',
+        shape=bias_mean.shape,
+        dtype=bias_mean.dtype,
+        initializer=tf.keras.initializers.Constant(bias_mean))
+    self.add_training_weight(
+        name='bias_shape',
+        shape=bias_shape.shape,
+        dtype=bias_shape.dtype,
+        initializer=tf.keras.initializers.Constant(bias_shape))
 
-  def decompress(
+  def decompress_weights(
       self, bias_mean: tf.Tensor, bias_shape: tf.Tensor) -> tf.Tensor:
     return tf.broadcast_to(bias_mean, bias_shape)
 
-  def training(
+  def project_training_weights(
       self, bias_mean: tf.Tensor, bias_shape: tf.Tensor) -> tf.Tensor:
-    return self.decompress(bias_mean, bias_shape)
+    return self.decompress_weights(bias_mean, bias_shape)
 
   def get_compressible_weights(
       self, original_layer: tf.keras.layers.Layer) -> List[str]:
     if isinstance(original_layer, tf.keras.layers.Conv2D) or \
        isinstance(original_layer, tf.keras.layers.Dense):
-      return ['bias']
+      return [original_layer.bias]
     return []
 
-
-def optimize(to_optimize: tf.keras.Model) -> tf.keras.Model:
-  """Model developer API for optimizing a model."""
-
-  def _optimize_layer(layer):
-    # Require layer to be built so that the average of bias can be initialized.
-    if not layer.built:
+  def compress_model(self, to_optimize: tf.keras.Model) -> tf.keras.Model:
+    """Model developer API for optimizing a model."""
+    # pylint: disable=protected-access
+    if not isinstance(to_optimize, tf.keras.Sequential) \
+        and not to_optimize._is_graph_network:
       raise ValueError(
-          'Applying BiasOnly currently requires passing in a built model')
+          '`compress_model` can only either be a tf.keras Sequential or '
+          'Functional model.')
+    # pylint: enable=protected-access
+
+    def _optimize_layer(layer):
+      # Require layer to be built so that the average of bias can be
+      # initialized.
+      if not layer.built:
+        raise ValueError(
+            'Applying BiasOnly currently requires passing in a built model')
 
-    return algorithm.create_layer_for_training(layer, algorithm=BiasOnly())
+      return algorithm.create_layer_for_training(layer, algorithm=self)
 
-  return tf.keras.models.clone_model(
-      to_optimize, clone_function=_optimize_layer)
+    return tf.keras.models.clone_model(
+        to_optimize, clone_function=_optimize_layer)
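
With this change the model-developer entry point moves from the module-level `bias_only.optimize(model)` to a `compress_model` method on the algorithm instance, which additionally validates that the input is a Sequential or Functional model before cloning it. A minimal usage sketch; the toy architecture below is illustrative only and not from this commit:

import tensorflow as tf

from tensorflow_model_optimization.python.core.common.keras.compression.algorithms import bias_only

# Any built Sequential or Functional model works; this toy model is only
# here to make the call concrete.
inputs = tf.keras.Input(shape=(28, 28, 1))
x = tf.keras.layers.Conv2D(8, 3, activation='relu')(inputs)
x = tf.keras.layers.Flatten()(x)
outputs = tf.keras.layers.Dense(10)(x)
model = tf.keras.Model(inputs, outputs)

# Before this commit: compressed = bias_only.optimize(model)
compressed = bias_only.BiasOnly().compress_model(model)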

tensorflow_model_optimization/python/core/common/keras/compression/algorithms/bias_only_test.py

Lines changed: 14 additions & 11 deletions
@@ -136,15 +136,15 @@ class FunctionalTest(tf.test.TestCase):
 
   def testBiasOnly_ReducesParamaters(self):
     model = _build_model()
-    compressed_model = bias_only.optimize(model)
+    compressed_model = bias_only.BiasOnly().compress_model(model)
 
     self.assertEqual(model.count_params(), 431080)
     self.assertEqual(compressed_model.count_params(), 430508)
 
   def testBiasOnly_HasReasonableAccuracy_TF(self):
     model = _build_model()
 
-    compressed_model = bias_only.optimize(model)
+    compressed_model = bias_only.BiasOnly().compress_model(model)
 
     _train_model(compressed_model)
 
@@ -162,7 +162,7 @@ def testBiasOnly_HasReasonableAccuracy_TF(self):
   def testBiasOnly_HasReasonableAccuracy_TFLite(self):
     model = _build_model()
 
-    compressed_model = bias_only.optimize(model)
+    compressed_model = bias_only.BiasOnly().compress_model(model)
 
     _train_model(compressed_model)
 
@@ -180,7 +180,7 @@ def testBiasOnly_BreaksDownLayerWeights(self):
     first_conv_layer = model.layers[2]
     self.assertLen(first_conv_layer.weights, 2)
 
-    compressed_model = bias_only.optimize(model)
+    compressed_model = bias_only.BiasOnly().compress_model(model)
 
     first_conv_layer = compressed_model.layers[2]
 
@@ -194,20 +194,23 @@ def testBiasOnly_PreservesPretrainedWeights(self):
 
     dense_layer_weights = model.layers[1].get_weights()
 
-    compressed_model = bias_only.optimize(model)
+    algorithm = bias_only.BiasOnly()
+    compressed_model = algorithm.compress_model(model)
 
     dense_layer_compressed_weights = compressed_model.layers[1].get_weights()
 
     # kernel
     assert (dense_layer_weights[0] == dense_layer_compressed_weights[2]).all()
 
     # bias
-    algorithm = bias_only.BiasOnly()
-    w1_repr, w2_repr = algorithm.init_training_weights_repr(
-        dense_layer_weights[1])
-
-    w1 = w1_repr.initializer(shape=None, dtype=w1_repr.dtype)
-    w2 = w2_repr.initializer(shape=None, dtype=w2_repr.dtype)
+    algorithm.weight_reprs = []
+    algorithm.init_training_weights(dense_layer_weights[1])
+    w1_repr, w2_repr = algorithm.weight_reprs
+
+    w1 = w1_repr.kwargs['initializer'](
+        shape=None, dtype=w1_repr.kwargs['dtype'])
+    w2 = w2_repr.kwargs['initializer'](
+        shape=None, dtype=w2_repr.kwargs['dtype'])
 
     assert (w1 == dense_layer_compressed_weights[0]).numpy().all()
     assert (w2 == dense_layer_compressed_weights[1]).numpy().all()
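
Tying this back to the `WeightRepr` change in algorithm.py, the updated test reads the recorded `add_weight` arguments off `weight_reprs` rather than named dataclass fields. The same pattern can be reproduced standalone; the stand-in bias tensor below is arbitrary and only for illustration:

import tensorflow as tf

from tensorflow_model_optimization.python.core.common.keras.compression.algorithms import bias_only

algorithm = bias_only.BiasOnly()
algorithm.weight_reprs = []
algorithm.init_training_weights(tf.constant([0.1, 0.2, 0.3]))  # stand-in bias

# Each WeightRepr now wraps the raw add_weight args/kwargs instead of
# exposing named fields such as `initializer` and `dtype` directly.
mean_repr, shape_repr = algorithm.weight_reprs
bias_mean = mean_repr.kwargs['initializer'](
    shape=None, dtype=mean_repr.kwargs['dtype'])
bias_shape = shape_repr.kwargs['initializer'](
    shape=None, dtype=shape_repr.kwargs['dtype'])
# bias_mean is the mean of the stand-in bias (about 0.2); bias_shape is [3].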
