Commit c05ce9e

Xhark authored and tensorflower-gardener committed
Add weight clustering implementation using the compression API.
It reuses the clustering modules from the existing tfmot weight clustering implementation. A dtype field was added for the init variable because the compressed weight (the pulling indices) has dtype int64. In the example code, the gzip-compressed TFLite file size is reduced from 1592595 bytes to 224167 bytes (about 14% of the original size).

PiperOrigin-RevId: 341767464
1 parent 3c2c2e3 commit c05ce9e
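
Usage of the new API, as exercised by the end-to-end test in this commit, follows the existing compression-API pattern: build and train a Keras model, wrap it with weight clustering parameters, then fine-tune the returned model. A minimal sketch (the 8-cluster setting and the `model` variable are illustrative, taken from the test added below):

    from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
    from tensorflow_model_optimization.python.core.common.keras.compression.algorithms import weight_clustering

    # `model` is assumed to be a built (and typically pre-trained) tf.keras.Model.
    params = weight_clustering.WeightClusteringParams(
        number_of_clusters=8,
        cluster_centroids_init=cluster_config.CentroidInitialization.DENSITY_BASED)
    compressed_model = weight_clustering.optimize(model, params)
    # Fine-tune `compressed_model` as usual before saving or converting it.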

File tree

3 files changed: +301 -0 lines changed

tensorflow_model_optimization/python/core/common/keras/compression/algorithms/BUILD
tensorflow_model_optimization/python/core/common/keras/compression/algorithms/weight_clustering.py
tensorflow_model_optimization/python/core/common/keras/compression/algorithms/weight_clustering_test.py

tensorflow_model_optimization/python/core/common/keras/compression/algorithms/BUILD

Lines changed: 23 additions & 0 deletions
@@ -67,3 +67,26 @@ py_test(
        # tensorflow dep1,
    ],
)

py_library(
    name = "weight_clustering",
    srcs = ["weight_clustering.py"],
    srcs_version = "PY3",
    deps = [
        # tensorflow dep1,
        "//tensorflow_model_optimization/python/core/clustering/keras:clustering_centroids",
        "//tensorflow_model_optimization/python/core/clustering/keras:clustering_registry",
        "//tensorflow_model_optimization/python/core/common/keras/compression:algorithm",
    ],
)

py_test(
    name = "weight_clustering_test",
    srcs = ["weight_clustering_test.py"],
    python_version = "PY3",
    deps = [
        ":weight_clustering",
        # tensorflow dep1,
        "//tensorflow_model_optimization/python/core/clustering/keras:cluster_config",
    ],
)
tensorflow_model_optimization/python/core/common/keras/compression/algorithms/weight_clustering.py

Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Weight clustering algorithm using the tfmot compression API."""
from typing import List

import tensorflow as tf

# TODO(tfmot): Make sure weight clustering APIs can be used in this place or
# move the APIs into the same directory.
from tensorflow_model_optimization.python.core.clustering.keras import clustering_centroids
from tensorflow_model_optimization.python.core.clustering.keras import clustering_registry
from tensorflow_model_optimization.python.core.common.keras.compression import algorithm


class WeightClusteringParams(object):
  """Weight clustering parameters."""

  def __init__(self,
               number_of_clusters,
               cluster_centroids_init):
    self.number_of_clusters = number_of_clusters
    self.cluster_centroids_init = cluster_centroids_init


class WeightClustering(algorithm.WeightCompressionAlgorithm):
  """Weight clustering compression algorithm."""

  def __init__(self, params):
    self.params = params

  def init_training_weights_repr(
      self, pretrained_weight: tf.Tensor) -> List[algorithm.WeightRepr]:
    """Initializes the training weights from a pre-trained weight tensor."""
    centroid_initializer = clustering_centroids.CentroidsInitializerFactory.\
        get_centroid_initializer(
            self.params.cluster_centroids_init
        )(pretrained_weight, self.params.number_of_clusters)

    cluster_centroids = centroid_initializer.get_cluster_centroids()

    if len(pretrained_weight.shape) == 2:
      clustering_impl_cls = clustering_registry.DenseWeightsCA
    elif len(pretrained_weight.shape) == 4:
      clustering_impl_cls = clustering_registry.ConvolutionalWeightsCA
    else:
      raise NotImplementedError(
          'Only 2-dimensional and 4-dimensional weights are supported.')

    clustering_impl = clustering_impl_cls(cluster_centroids)

    # Find the nearest cluster centroid for each weight and store the indices
    # so that ops can build their weights from them. These indices are
    # calculated once and kept fixed; they are used for look-ups into
    # cluster_centroids. Note that the indices are integer-typed (int64).
    pulling_indices = clustering_impl.get_pulling_indices(pretrained_weight)

    return [
        algorithm.WeightRepr(
            name='cluster_centroids',
            shape=cluster_centroids.shape,
            dtype=cluster_centroids.dtype,
            initializer=tf.keras.initializers.Constant(cluster_centroids)),
        algorithm.WeightRepr(
            name='pulling_indices',
            shape=pulling_indices.shape,
            dtype=pulling_indices.dtype,
            initializer=tf.keras.initializers.Constant(pulling_indices))
    ]

  def decompress(self,
                 cluster_centroids: tf.Tensor,
                 pulling_indices: tf.Tensor) -> tf.Tensor:
    # Rebuild the dense weight by looking up each index in the centroid table.
    return tf.reshape(
        tf.gather(cluster_centroids,
                  tf.reshape(pulling_indices, shape=(-1,))),
        pulling_indices.shape)

  def training(self,
               cluster_centroids: tf.Tensor,
               pulling_indices: tf.Tensor) -> tf.Tensor:
    return self.decompress(cluster_centroids, pulling_indices)

  def get_compressible_weights(
      self, original_layer: tf.keras.layers.Layer) -> List[str]:
    if isinstance(original_layer,
                  (tf.keras.layers.Conv2D, tf.keras.layers.Dense)):
      return ['kernel']
    return []


def optimize(
    to_optimize: tf.keras.Model,
    params: WeightClusteringParams) -> tf.keras.Model:
  """Model developer API for optimizing a model."""

  def _optimize_layer(layer):
    # Require the layer to be built so that the clustered weights
    # can be initialized from the pre-trained weights.
    if not layer.built:
      raise ValueError(
          'Applying weight clustering currently '
          'requires passing in a built model')

    return algorithm.create_layer_for_training(
        layer, algorithm=WeightClustering(params))

  return tf.keras.models.clone_model(
      to_optimize, clone_function=_optimize_layer)
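
For intuition about the representation above: decompress() rebuilds a dense kernel by gathering each weight's centroid value via its stored index and reshaping back to the kernel shape. A minimal standalone sketch of that lookup (the centroid and index values are made up for illustration):

    import tensorflow as tf

    cluster_centroids = tf.constant([-0.5, 0.0, 0.5])           # 3 cluster values
    pulling_indices = tf.constant([[0, 2], [1, 2]], tf.int64)   # per-weight cluster ids
    kernel = tf.reshape(
        tf.gather(cluster_centroids, tf.reshape(pulling_indices, shape=(-1,))),
        pulling_indices.shape)
    # kernel == [[-0.5, 0.5], [0.0, 0.5]]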
tensorflow_model_optimization/python/core/common/keras/compression/algorithms/weight_clustering_test.py

Lines changed: 158 additions & 0 deletions
@@ -0,0 +1,158 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the weight clustering algorithm."""

import os
import tempfile

import tensorflow as tf

from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
from tensorflow_model_optimization.python.core.common.keras.compression.algorithms import weight_clustering


def _build_model():
  i = tf.keras.layers.Input(shape=(28, 28), name='input')
  x = tf.keras.layers.Reshape((28, 28, 1))(i)
  x = tf.keras.layers.Conv2D(
      20, 5, activation='relu', padding='valid', name='conv1')(x)
  x = tf.keras.layers.MaxPool2D(2, 2)(x)
  x = tf.keras.layers.Conv2D(
      50, 5, activation='relu', padding='valid', name='conv2')(x)
  x = tf.keras.layers.MaxPool2D(2, 2)(x)
  x = tf.keras.layers.Flatten()(x)
  x = tf.keras.layers.Dense(500, activation='relu', name='fc1')(x)
  output = tf.keras.layers.Dense(10, name='fc2')(x)

  model = tf.keras.Model(inputs=[i], outputs=[output])
  return model


def _get_dataset():
  mnist = tf.keras.datasets.mnist
  (x_train, y_train), (x_test, y_test) = mnist.load_data()
  x_train, x_test = x_train / 255.0, x_test / 255.0
  # Use a subset of the 60000 training examples to keep the unit test fast.
  x_train = x_train[:1000]
  y_train = y_train[:1000]

  return (x_train, y_train), (x_test, y_test)


def _train_model(model):
  loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
  model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])
  (x_train, y_train), _ = _get_dataset()
  model.fit(x_train, y_train, epochs=1)


def _save_as_saved_model(model):
  saved_model_dir = tempfile.mkdtemp()
  model.save(saved_model_dir)
  return saved_model_dir


def _get_directory_size_in_bytes(directory):
  total = 0
  try:
    for entry in os.scandir(directory):
      if entry.is_file():
        # If it's a file, use stat() to get its size.
        total += entry.stat().st_size
      elif entry.is_dir():
        # If it's a directory, recursively call this function.
        total += _get_directory_size_in_bytes(entry.path)
  except NotADirectoryError:
    # If `directory` isn't a directory, return the file size instead.
    return os.path.getsize(directory)
  except PermissionError:
    # If for whatever reason we can't open the folder, return 0.
    return 0
  return total


class FunctionalTest(tf.test.TestCase):

  def testWeightClustering_TrainingE2E(self):
    number_of_clusters = 8
    model = _build_model()
    _train_model(model)
    original_saved_model_dir = _save_as_saved_model(model)

    params = weight_clustering.WeightClusteringParams(
        number_of_clusters=number_of_clusters,
        cluster_centroids_init=\
            cluster_config.CentroidInitialization.DENSITY_BASED)
    compressed_model = weight_clustering.optimize(model, params)

    _train_model(compressed_model)

    saved_model_dir = _save_as_saved_model(compressed_model)

    _, (x_test, y_test) = _get_dataset()

    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

    compressed_model.compile(
        optimizer='adam', loss=loss_fn, metrics=['accuracy'])

    results = compressed_model.evaluate(x_test, y_test)

    # Accuracy test.
    self.assertGreater(results[1], 0.85)  # 0.8708 observed.

    original_size = _get_directory_size_in_bytes(original_saved_model_dir)
    compressed_size = _get_directory_size_in_bytes(saved_model_dir)

    # Compressed model size test.
    # TODO(tfmot): gzip compression can reduce file size much better.
    self.assertLess(compressed_size, original_size / 1.3)

  def testWeightClustering_SingleLayer(self):
    number_of_clusters = 8
    i = tf.keras.layers.Input(shape=(2,), name='input')
    output = tf.keras.layers.Dense(3, name='fc1')(i)
    model = tf.keras.Model(inputs=[i], outputs=[output])

    dense_layer_weights = model.layers[1].get_weights()

    params = weight_clustering.WeightClusteringParams(
        number_of_clusters=number_of_clusters,
        cluster_centroids_init=\
            cluster_config.CentroidInitialization.DENSITY_BASED)
    compressed_model = weight_clustering.optimize(model, params)

    dense_layer_compressed_weights = compressed_model.layers[1].get_weights()

    # cluster_centroids.
    self.assertEqual(
        dense_layer_compressed_weights[0].shape, (number_of_clusters,))

    # pulling_indices.
    self.assertEqual(
        dense_layer_compressed_weights[1].shape,
        dense_layer_weights[0].shape)
    self.assertEqual(str(dense_layer_compressed_weights[1].dtype), 'int64')
    self.assertAllInRange(
        dense_layer_compressed_weights[1], 0, number_of_clusters - 1)

    # bias.
    self.assertAllEqual(
        dense_layer_weights[1], dense_layer_compressed_weights[2])


if __name__ == '__main__':
  tf.test.main()
