Skip to content

Commit 15eccc0

Browse files
alanchiaotensorflower-gardener
authored andcommitted
Weight compression API implementation for simplest case where original weights and graph
differ from training/inference weights and graph, but training/inference graphs are the same. Test that weights are converted and that pretrained weights are preserved. The TFLite prevention of constant folding currently doesn't work. PiperOrigin-RevId: 338138622
1 parent 99b3fd6 commit 15eccc0

File tree

8 files changed

+760
-0
lines changed

8 files changed

+760
-0
lines changed
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
licenses(["notice"])
2+
3+
py_library(
4+
name = "algorithm",
5+
srcs = ["algorithm.py"],
6+
srcs_version = "PY3",
7+
visibility = ["//visibility:public"],
8+
deps = [
9+
"//tensorflow_model_optimization/python/core/common/keras/compression/internal:optimize",
10+
#TODO(tfmot): remove when we stick to Python 3.7, which includes this by default.
11+
# dataclasses dep1,
12+
# tensorflow dep1,
13+
],
14+
)
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
These APIs are not ready for public usage yet.
2+
3+
algorithms/ : end-to-end tests that demonstrate usage of algorithm developer
4+
public API.
5+
6+
algorithm.py : public API for algorithm developer
7+
8+
internal/ : internal parts of library. These should not be used anywhere outside
9+
this directory.
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Public APIs for algorithm developer using weight compression API."""
16+
import abc
17+
from typing import List, Any
18+
import dataclasses
19+
20+
import tensorflow as tf
21+
22+
from tensorflow_model_optimization.python.core.common.keras.compression.internal import optimize
23+
24+
25+
@dataclasses.dataclass
26+
class WeightRepr:
27+
"""Dataclass that wraps `tf.keras.layers.Layer.add_weight` parameters."""
28+
name: Any = None
29+
shape: Any = None
30+
dtype: Any = None
31+
initializer: Any = None
32+
regularizer: Any = None
33+
trainable: Any = None
34+
constraint: Any = None
35+
partitioner: Any = None
36+
use_resource: Any = None
37+
synchronization: Any = tf.VariableSynchronization.AUTO
38+
aggregation: Any = tf.compat.v1.VariableAggregation.NONE
39+
40+
41+
class WeightCompressionAlgorithm(metaclass=abc.ABCMeta):
42+
"""Interface for weight compression algorithm that acts on a per-layer basis.
43+
44+
This allows both options of either decompressing during inference or
45+
decompressing prior to inference (where compression occurs by applying a
46+
tool such as zip to the model file).
47+
48+
This interface is a purely functional one.
49+
"""
50+
51+
@abc.abstractmethod
52+
def init_training_weights_repr(
53+
self, pretrained_weight: tf.Tensor) -> List[WeightRepr]:
54+
"""Create training weight representations for initializing layer variables.
55+
56+
Args:
57+
pretrained_weight: tf.Tensor of a pretrained weight of a layer that will
58+
be compressed eventually.
59+
60+
Returns:
61+
A list of `WeightRepr`, a container for arguments to
62+
`tf.keras.layers.Layer.add_weight`for each tf.Variable to create.
63+
"""
64+
65+
def decompress(self, compressed_weights: List[tf.Tensor]) -> tf.Tensor:
66+
"""Define the operations to decompress a single weight’s compressed form during inference.
67+
68+
The default is an identity. TODO(): actually isn't.
69+
70+
Args:
71+
compressed_weights: tf.Tensors representing a single weight’s compressed
72+
form, coming from what’s returned in `compress`.
73+
74+
Returns:
75+
A tf.Tensor representing the decompressed `compressed_weights`.
76+
"""
77+
return compressed_weights[0]
78+
79+
@abc.abstractmethod
80+
def training(self, training_weights: List[tf.Tensor]) -> tf.Tensor:
81+
"""Define a piece of the forward pass during training, which operates on a single compressible weight.
82+
83+
The default throws an error when training occurs.
84+
85+
Args:
86+
training_weights: tf.Tensors representing any variables used during
87+
training, for a single compressible weight, in the order returned in
88+
`init_training_weights_repr`.
89+
90+
Returns:
91+
tf.Tensor to set the compressible weight to.
92+
"""
93+
94+
95+
def create_layer_for_training(
96+
layer: tf.keras.layers.Layer,
97+
algorithm: WeightCompressionAlgorithm) -> tf.keras.layers.Layer:
98+
return optimize.create_layer_for_training(layer, algorithm)
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
package(default_visibility = ["//visibility:private"])
2+
3+
licenses(["notice"])
4+
5+
py_library(
6+
name = "same_training_and_inference",
7+
srcs = ["same_training_and_inference.py"],
8+
srcs_version = "PY3",
9+
deps = [
10+
# tensorflow dep1,
11+
"//tensorflow_model_optimization/python/core/common/keras/compression:algorithm",
12+
],
13+
)
14+
15+
py_test(
16+
name = "same_training_and_inference_test",
17+
srcs = ["same_training_and_inference_test.py"],
18+
python_version = "PY3",
19+
deps = [
20+
":same_training_and_inference",
21+
# numpy dep1,
22+
# tensorflow dep1,
23+
],
24+
)
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""SVD algorithm, where the training and inference graphs are the same."""
16+
from typing import List
17+
18+
import tensorflow as tf
19+
20+
from tensorflow_model_optimization.python.core.common.keras.compression import algorithm
21+
22+
23+
class SVDParams(object):
24+
"""Define container for parameters for SVD algorithm."""
25+
26+
def __init__(self, rank):
27+
self.rank = rank
28+
29+
30+
class SVD(algorithm.WeightCompressionAlgorithm):
31+
"""Define how to apply SVD algorithm."""
32+
33+
def __init__(self, params):
34+
self.params = params
35+
36+
# TODO(tfmot): communicate that `pretrained_weight` will sometimes
37+
# be a dummy tensor and sometimes be actual pretrained values during
38+
# its actual usage.
39+
def init_training_weights_repr(
40+
self, pretrained_weight: tf.Tensor) -> List[algorithm.WeightRepr]:
41+
rank = self.params.rank
42+
s, u, v = tf.linalg.svd(pretrained_weight)
43+
44+
if len(pretrained_weight.shape) == 2:
45+
# FC Layer
46+
s = s[:rank]
47+
u = u[:, :rank]
48+
v = v[:, :rank]
49+
elif len(pretrained_weight.shape) == 4:
50+
# Conv2D Layer
51+
s = s[:, :, :rank]
52+
u = u[:, :, :, :rank]
53+
v = v[:, :, :, :rank]
54+
else:
55+
raise NotImplementedError('Only for dimension=2 or 4 is supported.')
56+
57+
sv = tf.matmul(tf.linalg.diag(s), v, adjoint_b=True)
58+
59+
# TODO(tfmot): note that it does not suffice to just have the initializer
60+
# to derive the shape from, in the case of a constant initializer.
61+
# The unit test fail without providing the shape.
62+
return [
63+
algorithm.WeightRepr(
64+
name='u',
65+
shape=u.shape,
66+
initializer=tf.keras.initializers.Constant(u)),
67+
algorithm.WeightRepr(
68+
name='sv',
69+
shape=sv.shape,
70+
initializer=tf.keras.initializers.Constant(sv))
71+
]
72+
73+
def decompress(self, u: tf.Tensor, sv: tf.Tensor) -> tf.Tensor:
74+
return tf.matmul(u, sv)
75+
76+
def training(self, training_weights: List[tf.Tensor]) -> tf.Tensor:
77+
u = training_weights[0]
78+
sv = training_weights[1]
79+
return self.decompress(u, sv)
80+
81+
82+
def optimize(to_optimize: tf.keras.Model, params: SVDParams) -> tf.keras.Model:
83+
"""Model developer API for optimizing a model."""
84+
85+
def _optimize_layer(layer):
86+
# Require layer to be built so that the SVD-factorized weights
87+
# can be initialized from the weights.
88+
if not layer.built:
89+
raise ValueError(
90+
'Applying SVD currently requires passing in a built model')
91+
92+
return algorithm.create_layer_for_training(layer, algorithm=SVD(params))
93+
94+
return tf.keras.models.clone_model(
95+
to_optimize, clone_function=_optimize_layer)

0 commit comments

Comments
 (0)