
Commit 8c875fc

alanchiao authored and tensorflower-gardener committed
Increase algorithm coverage for algorithms that need to modify the weights after training (compress API) or have an inference graph that differs from the training graph (decompress vs training).
PiperOrigin-RevId: 338149184
1 parent 15eccc0 commit 8c875fc
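
For orientation, here is a minimal usage sketch of the model-developer entry point added in this change, `optimize()` from the new `different_training_and_inference` module. The toy model and the rank value are illustrative assumptions, not part of the commit; the only requirement the code imposes is that the model is already built.

import tensorflow as tf

from tensorflow_model_optimization.python.core.common.keras.compression.algorithms import different_training_and_inference as svd

# Illustrative model: any built Keras model whose layers hold 2-D or 4-D
# kernels (Dense, Conv2D) would do.
inputs = tf.keras.Input(shape=(784,))
x = tf.keras.layers.Dense(64, activation='relu')(inputs)
outputs = tf.keras.layers.Dense(10)(x)
model = tf.keras.Model(inputs, outputs)

# Rank is an illustrative choice: a lower rank gives smaller factors and a
# larger approximation error.
compressed_model = svd.optimize(model, svd.SVDParams(rank=8))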

File tree: 5 files changed, +563 -2 lines changed


tensorflow_model_optimization/python/core/common/keras/compression/algorithm.py

Lines changed: 25 additions & 0 deletions
@@ -62,6 +62,24 @@ def init_training_weights_repr(
       `tf.keras.layers.Layer.add_weight` for each tf.Variable to create.
     """

+  def compress(self, training_weights: List[tf.Tensor]) -> List[tf.Tensor]:
+    """Define the operations to compress a single weight after training.
+
+    'Compress' can refer to making the weight more amenable to compression
+    or actually compress the weight.
+
+    The default is an identity.
+
+    Args:
+      training_weights: tf.Tensors representing all variables used during
+        training, for a single compressible weight, in the order returned in
+        `init_training_weights_repr`.
+
+    Returns:
+      List of tf.Tensors to set to compressed or more compressible form.
+    """
+    return training_weights
+
   def decompress(self, compressed_weights: List[tf.Tensor]) -> tf.Tensor:
     """Define the operations to decompress a single weight’s compressed form during inference.

@@ -80,6 +98,7 @@ def decompress(self, compressed_weights: List[tf.Tensor]) -> tf.Tensor:
   def training(self, training_weights: List[tf.Tensor]) -> tf.Tensor:
     """Define a piece of the forward pass during training, which operates on a single compressible weight.

+    TODO(tfmot): throw this error.
     The default throws an error when training occurs.

     Args:
@@ -96,3 +115,9 @@ def create_layer_for_training(
     layer: tf.keras.layers.Layer,
     algorithm: WeightCompressionAlgorithm) -> tf.keras.layers.Layer:
   return optimize.create_layer_for_training(layer, algorithm)
+
+
+def create_layer_for_inference(
+    layer_for_training: tf.keras.layers.Layer,
+    algorithm: WeightCompressionAlgorithm) -> tf.keras.layers.Layer:
+  return optimize.create_layer_for_inference(layer_for_training, algorithm)
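
To show what the new `compress` hook enables on its own, here is a minimal sketch of an algorithm that leaves the training graph untouched and only post-processes the weight after training. It follows the same `WeightRepr`/`Constant`-initializer pattern as the SVD example further down; the thresholding rule and the class itself are illustrative, not part of this commit.

from typing import List

import tensorflow as tf

from tensorflow_model_optimization.python.core.common.keras.compression import algorithm


class ZeroSmallWeights(algorithm.WeightCompressionAlgorithm):
  """Illustrative only: zero out near-zero values once training is done."""

  def init_training_weights_repr(
      self, pretrained_weight: tf.Tensor) -> List[algorithm.WeightRepr]:
    # Train on a single variable initialized from the pretrained weight.
    return [
        algorithm.WeightRepr(
            name='w',
            shape=pretrained_weight.shape,
            initializer=tf.keras.initializers.Constant(pretrained_weight))
    ]

  def training(self, training_weights: List[tf.Tensor]) -> tf.Tensor:
    # The training forward pass uses the weight unchanged.
    return training_weights[0]

  def compress(self, training_weights: List[tf.Tensor]) -> List[tf.Tensor]:
    # Post-training step introduced by this change: make the weight more
    # amenable to compression by zeroing small magnitudes.
    w = training_weights[0]
    return [tf.where(tf.abs(w) < 1e-3, tf.zeros_like(w), w)]

  def decompress(self, w: tf.Tensor) -> tf.Tensor:
    # The inference graph simply reads the post-processed weight back.
    return w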

tensorflow_model_optimization/python/core/common/keras/compression/algorithms/BUILD

Lines changed: 21 additions & 0 deletions
@@ -22,3 +22,24 @@ py_test(
         # tensorflow dep1,
     ],
 )
+
+py_library(
+    name = "different_training_and_inference",
+    srcs = ["different_training_and_inference.py"],
+    srcs_version = "PY3",
+    deps = [
+        # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/common/keras/compression:algorithm",
+    ],
+)
+
+py_test(
+    name = "different_training_and_inference_test",
+    srcs = ["different_training_and_inference_test.py"],
+    python_version = "PY3",
+    deps = [
+        ":different_training_and_inference",
+        # numpy dep1,
+        # tensorflow dep1,
+    ],
+)
tensorflow_model_optimization/python/core/common/keras/compression/algorithms/different_training_and_inference.py

Lines changed: 101 additions & 0 deletions

@@ -0,0 +1,101 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""SVD algorithm, where the training and inference graphs are different."""
+from typing import List
+
+import tensorflow as tf
+
+from tensorflow_model_optimization.python.core.common.keras.compression import algorithm
+
+
+class SVDParams(object):
+  """Define container for parameters for SVD algorithm."""
+
+  def __init__(self, rank):
+    self.rank = rank
+
+
+class SVD(algorithm.WeightCompressionAlgorithm):
+  """Define how to apply SVD algorithm."""
+
+  def __init__(self, params):
+    self.params = params
+
+  # TODO(tfmot): communicate that `pretrained_weight` will sometimes
+  # be a dummy tensor and sometimes be actual pretrained values during
+  # its actual usage.
+  def init_training_weights_repr(
+      self, pretrained_weight: tf.Tensor) -> List[algorithm.WeightRepr]:
+    return [
+        algorithm.WeightRepr(
+            name='w',
+            shape=pretrained_weight.shape,
+            initializer=tf.keras.initializers.Constant(pretrained_weight))
+    ]
+
+  def decompress(self, u: tf.Tensor, sv: tf.Tensor) -> tf.Tensor:
+    return tf.matmul(u, sv)
+
+  def compress(self, training_weights: List[tf.Tensor]) -> List[tf.Tensor]:
+    assert len(training_weights) == 1
+    weight = training_weights[0]
+
+    rank = self.params.rank
+    s, u, v = tf.linalg.svd(weight)
+
+    if len(weight.shape) == 2:
+      # FC Layer
+      s = s[:rank]
+      u = u[:, :rank]
+      v = v[:, :rank]
+    elif len(weight.shape) == 4:
+      # Conv2D Layer
+      s = s[:, :, :rank]
+      u = u[:, :, :, :rank]
+      v = v[:, :, :, :rank]
+    else:
+      raise NotImplementedError('Only for dimension=2 or 4 is supported.')
+
+    sv = tf.matmul(tf.linalg.diag(s), v, adjoint_b=True)
+
+    return [u, sv]
+
+  # TODO(tfmot): remove in this example, which is just post-training.
+  def training(self, training_weights: List[tf.Tensor]) -> tf.Tensor:
+    return training_weights[0]
+
+
+# TODO(tfmot): consider if we can simplify `create_model_for_training` and
+# `create_model_for_inference` into a single API for algorithm developers.
+def optimize(to_optimize: tf.keras.Model, params: SVDParams) -> tf.keras.Model:
+  """Model developer API for optimizing a model."""
+
+  def _create_layer_for_training(layer):
+    # Require layer to be built so that the SVD-factorized weights
+    # can be initialized from the weights.
+    if not layer.built:
+      raise ValueError(
+          'Applying SVD currently requires passing in a built model')
+
+    return algorithm.create_layer_for_training(layer, algorithm=SVD(params))
+
+  def _create_layer_for_inference(layer):
+    return algorithm.create_layer_for_inference(layer, algorithm=SVD(params))
+
+  intermediate_model = tf.keras.models.clone_model(
+      to_optimize, clone_function=_create_layer_for_training)
+
+  return tf.keras.models.clone_model(
+      intermediate_model, clone_function=_create_layer_for_inference)
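
As a sanity check on the factorization above: for a 2-D kernel of shape (m, n), `compress` stores `u` of shape (m, rank) and `sv` of shape (rank, n), so the stored parameter count drops from m*n to rank*(m + n), and `decompress` rebuilds an approximation via `tf.matmul(u, sv)`. A standalone sketch with illustrative sizes (not taken from the commit):

import tensorflow as tf

# Illustrative sizes: a 64x32 kernel factorized at rank 8.
w = tf.random.normal([64, 32])
rank = 8

# Same steps as SVD.compress for the 2-D (FC layer) case.
s, u, v = tf.linalg.svd(w)
u = u[:, :rank]
sv = tf.matmul(tf.linalg.diag(s[:rank]), v[:, :rank], adjoint_b=True)

# Same step as SVD.decompress.
w_approx = tf.matmul(u, sv)

print(u.shape, sv.shape, w_approx.shape)   # (64, 8) (8, 32) (64, 32)
print(64 * 32, '->', 64 * 8 + 8 * 32)      # 2048 -> 768 stored parameters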
