Skip to content

Commit 2e85bc3

Browse files
Xhark and tensorflower-gardener
authored and committed
Implement update_training_weight functionality and add an example for it.
PiperOrigin-RevId: 368583345
1 parent 23b274b commit 2e85bc3

File tree

5 files changed

+476
-8
lines changed

5 files changed

+476
-8
lines changed

tensorflow_model_optimization/python/core/common/keras/compression/algorithm.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class WeightCompressor(metaclass=abc.ABCMeta):
3737
3838
This interface is a purely functional one.
3939
"""
40+
update_ops = [] # type: List
4041

4142
# TODO(tfmot): Consider separating this from the algorithm API to support custom layers.
4243
def get_compressible_weights(
@@ -103,19 +104,40 @@ def project_training_weights(
103104
tf.Tensor to set the compressible weight to.
104105
"""
105106

107+
def init_update_ops(self, tensor_weight_pairs):
  """Reset the pending update ops and remember the tensor/weight pairing.

  Args:
    tensor_weight_pairs: Iterable of `(tensor, weight)` pairs. It is stored
      as-is (not copied) and is searched by `update_training_weight` to find
      the weight that backs a given training-weight tensor.
  """
  self.update_ops, self.tensor_weight_pairs = [], tensor_weight_pairs
110+
106111
def update_training_weight(
    self, training_weight: tf.Tensor, value: tf.Tensor):
  """Queue an assign op that sets a training weight to `value`.

  This method is for the case where a training weight should be updated to a
  specific value that does not come from the model optimizer.

  This method should be called in `project_training_weights`. During
  training, every `update_training_weight` call is collected and turned into
  one update op; finally, all of these update ops are passed to
  `model.add_update`.

  Args:
    training_weight: tf.Tensor representing a training weight. It is matched
      by identity (`is`) against the tensors registered through
      `init_update_ops`.
    value: tf.Tensor representing a value to be assigned to the training
      weight.

  Raises:
    ValueError: If the training weight cannot be found, including the case
      where no tensor/weight pairs have been registered yet.
  """
  # Without a prior `init_update_ops` call there is no `tensor_weight_pairs`
  # attribute; treat that as "nothing registered" so the documented
  # ValueError is raised instead of an AttributeError.
  registered_pairs = getattr(self, 'tensor_weight_pairs', ())
  for tensor, weight in registered_pairs:
    if training_weight is tensor:
      self.update_ops.append(weight.assign(value))
      return

  raise ValueError('Training weight not found. Please call '
                   'the update_training_weight with given training '
                   'weight tensor.')
138+
139+
def get_update_ops(self):
  """Return the list of update ops collected so far (the list itself, not a copy)."""
  collected_ops = self.update_ops
  return collected_ops
119141

120142
def compress_training_weights(
121143
self, *training_weights: tf.Tensor) -> List[tf.Tensor]:

tensorflow_model_optimization/python/core/common/keras/compression/algorithms/BUILD

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,25 @@ py_test(
9090
"//tensorflow_model_optimization/python/core/clustering/keras:cluster_config",
9191
],
9292
)
93+
94+
py_library(
95+
name = "periodical_update_and_scheduling",
96+
srcs = ["periodical_update_and_scheduling.py"],
97+
srcs_version = "PY3",
98+
deps = [
99+
# tensorflow dep1,
100+
"//tensorflow_model_optimization/python/core/common/keras/compression:algorithm",
101+
],
102+
)
103+
104+
py_test(
105+
name = "periodical_update_and_scheduling_test",
106+
timeout = "long",
107+
srcs = ["periodical_update_and_scheduling_test.py"],
108+
python_version = "PY3",
109+
deps = [
110+
":periodical_update_and_scheduling",
111+
# numpy dep1,
112+
# tensorflow dep1,
113+
],
114+
)
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""SVD algorithm, where the training and inference graphs are the same."""
16+
from typing import List
17+
18+
import tensorflow as tf
19+
20+
from tensorflow_model_optimization.python.core.common.keras.compression import algorithm
21+
22+
23+
class SVD(algorithm.WeightCompressor):
  """Defines how to apply the SVD algorithm with periodic updates.

  Every `update_freq` steps, the training weight is replaced by a lower-rank
  reconstruction computed with SVD. While `step < warmup_step`, the rank used
  for that reconstruction is interpolated linearly from the weight's original
  rank towards the target `rank`; afterwards the target `rank` is used.
  """

  def __init__(self, rank, update_freq=100, warmup_step=1000):
    """Stores the algorithm hyperparameters.

    Args:
      rank: Target rank of the factorization.
      update_freq: Number of steps between two low-rank projections.
      warmup_step: Number of initial steps during which the rank is moved
        gradually from the original rank to `rank`.
    """
    self.rank = rank
    self.update_freq = update_freq
    self.warmup_step = warmup_step

  # TODO(tfmot): communicate that `pretrained_weight` will sometimes
  # be a dummy tensor and sometimes be actual pretrained values during
  # its actual usage.
  def init_training_weights(
      self, pretrained_weight: tf.Tensor):
    """Registers the training weights: the kernel `w` and a `step` counter.

    Args:
      pretrained_weight: Tensor whose shape, dtype and values initialize `w`.
    """
    self.add_training_weight(
        name='w',
        shape=pretrained_weight.shape,
        dtype=pretrained_weight.dtype,
        initializer=tf.keras.initializers.Constant(pretrained_weight))
    # Scalar int32 counter, starting at 0, that drives the update schedule.
    self.add_training_weight(
        name='step',
        shape=(),
        dtype=tf.int32,
        initializer=tf.keras.initializers.Constant(0))

  def decompress_weights(self, u: tf.Tensor, sv: tf.Tensor) -> tf.Tensor:
    """Rebuilds the weight as the matrix product `u @ sv`."""
    return tf.matmul(u, sv)

  def _truncated_svd(self, weight: tf.Tensor, rank) -> List[tf.Tensor]:
    """Returns `[u, sv]`, the factors of `weight` truncated to `rank`."""
    if len(weight.shape) not in (2, 4):
      raise NotImplementedError('Only for dimension=2 or 4 is supported.')

    s, u, v = tf.linalg.svd(weight)

    # For both the 2-D (FC layer) and 4-D (Conv2D layer) cases, truncation
    # keeps the first `rank` entries along the last axis of each factor.
    s = s[..., :rank]
    u = u[..., :rank]
    v = v[..., :rank]

    sv = tf.matmul(tf.linalg.diag(s), v, adjoint_b=True)
    return [u, sv]

  def project_training_weights(
      self, weight: tf.Tensor, step: tf.Tensor) -> tf.Tensor:
    """Schedules the step increment and the periodic low-rank projection.

    Args:
      weight: The training weight `w`.
      step: The scalar step counter.

    Returns:
      `weight` itself; the new values are applied through the update ops
      queued with `update_training_weight`.
    """
    weight_rank = tf.math.minimum(weight.shape[-1], weight.shape[-2])
    self.update_training_weight(step, step + 1)
    if step % self.update_freq == 0:
      rank = self.rank
      if step < self.warmup_step:
        # Linear interpolation between the original rank (at step 0) and the
        # target rank (at warmup_step). Divide first and round afterwards:
        # rounding the integer numerator would be a no-op and the final cast
        # would then truncate instead of rounding.
        interpolated_rank = (
            weight_rank * (self.warmup_step - step)
            + self.rank * step
        ) / self.warmup_step
        rank = tf.cast(tf.math.round(interpolated_rank), tf.int32)
      rank = tf.math.minimum(rank, weight_rank)

      u, sv = self._truncated_svd(weight, rank)
      self.update_training_weight(weight, tf.matmul(u, sv))

    return weight

  def compress_training_weights(self, weight: tf.Tensor, _) -> List[tf.Tensor]:
    """Factorizes `weight` into `[u, sv]` at the target rank.

    Args:
      weight: The training weight `w`.
      _: The step counter; unused.

    Returns:
      The list `[u, sv]` accepted by `decompress_weights`.
    """
    return self._truncated_svd(weight, self.rank)

  def get_compressible_weights(
      self, original_layer: tf.keras.layers.Layer) -> List[str]:
    """Returns the kernel of Conv2D/Dense layers, or nothing for other layers."""
    # NOTE(review): the annotation says List[str] but the kernel object itself
    # is returned - confirm which one the base API expects.
    if isinstance(original_layer, (tf.keras.layers.Conv2D,
                                   tf.keras.layers.Dense)):
      return [original_layer.kernel]
    return []

  def optimize_model(self, to_optimize: tf.keras.Model) -> tf.keras.Model:
    """Model developer API for optimizing a model for training.

    The returned model should be used for compression-aware training.

    Args:
      to_optimize: The model to be optimized.

    Returns:
      A wrapped model that has compression optimizers.

    Raises:
      ValueError: If the model is neither Sequential nor Functional, or if one
        of its layers is not built.
    """
    # pylint: disable=protected-access
    if not isinstance(
        to_optimize, tf.keras.Sequential) and not to_optimize._is_graph_network:
      raise ValueError(
          '`optimize_model` can only either be a tf.keras Sequential or '
          'Functional model.')
    # pylint: enable=protected-access

    def _optimize_layer(layer):
      # Require layer to be built so that the SVD-factorized weights
      # can be initialized from the weights.
      if not layer.built:
        raise ValueError(
            'Applying SVD currently requires passing in a built model')

      return algorithm.create_layer_for_training(layer, algorithm=self)

    return tf.keras.models.clone_model(
        to_optimize, clone_function=_optimize_layer)

  def compress_model(self, to_compress: tf.keras.Model) -> tf.keras.Model:
    """Model developer API for optimizing a model for inference.

    Args:
      to_compress: The model that was trained for compression. This model
        should have been generated by the `optimize_model` method.

    Returns:
      A compressed model for inference.

    Raises:
      ValueError: If one of the model's layers is not built.
    """
    def _compress_layer(layer):
      # Require layer to be built so that the SVD-factorized weights
      # can be initialized from the weights.
      if not layer.built:
        raise ValueError(
            'Applying SVD currently requires passing in a built model')

      return algorithm.create_layer_for_inference(layer, algorithm=self)

    return tf.keras.models.clone_model(
        to_compress, clone_function=_compress_layer)

0 commit comments

Comments
 (0)