Skip to content

Commit eec0e4c

Browse files
Merge pull request #616 from wwwind:clusterable_layer
PiperOrigin-RevId: 368158644
2 parents 3d73ed2 + cd9b2a3 commit eec0e4c

File tree

10 files changed

+805
-163
lines changed

10 files changed

+805
-163
lines changed

tensorflow_model_optimization/python/core/api/clustering/keras/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@
1919
from tensorflow_model_optimization.python.core.clustering.keras.cluster import cluster_scope
2020
from tensorflow_model_optimization.python.core.clustering.keras.cluster import cluster_weights
2121
from tensorflow_model_optimization.python.core.clustering.keras.cluster import strip_clustering
22+
from tensorflow_model_optimization.python.core.clustering.keras import clusterable_layer
2223

2324
from tensorflow_model_optimization.python.core.clustering.keras.cluster_config import CentroidInitialization
25+
from tensorflow_model_optimization.python.core.clustering.keras.clustering_algorithm import AbstractClusteringAlgorithm
2426
from tensorflow_model_optimization.python.core.clustering.keras.clustering_callbacks import ClusteringSummaries
27+
from tensorflow_model_optimization.python.core.clustering.keras.clusterable_layer import ClusterableLayer
2528
# pylint: enable=g-bad-import-order

tensorflow_model_optimization/python/core/clustering/keras/BUILD

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ py_library(
4444
visibility = ["//visibility:public"],
4545
deps = [
4646
":clusterable_layer",
47+
":clustering_algorithm",
4748
],
4849
)
4950

@@ -74,6 +75,16 @@ py_library(
7475
],
7576
)
7677

78+
py_library(
79+
name = "clustering_algorithm",
80+
srcs = ["clustering_algorithm.py"],
81+
srcs_version = "PY3",
82+
visibility = ["//visibility:public"],
83+
deps = [
84+
# tensorflow dep1,
85+
],
86+
)
87+
7788
py_library(
7889
name = "clustering_callbacks",
7990
srcs = ["clustering_callbacks.py"],
@@ -165,3 +176,25 @@ py_test(
165176
"//tensorflow_model_optimization/python/core/keras:test_utils",
166177
],
167178
)
179+
180+
py_test(
181+
name = "mnist_clusterable_layer_test",
182+
srcs = ["mnist_clusterable_layer_test.py"],
183+
python_version = "PY3",
184+
visibility = ["//visibility:public"],
185+
deps = [
186+
":cluster",
187+
# tensorflow dep1,
188+
],
189+
)
190+
191+
py_test(
192+
name = "mnist_clustering_test",
193+
srcs = ["mnist_clustering_test.py"],
194+
python_version = "PY3",
195+
visibility = ["//visibility:public"],
196+
deps = [
197+
":cluster",
198+
# tensorflow dep1,
199+
],
200+
)

tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py

Lines changed: 68 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -54,26 +54,64 @@ def get_clusterable_weights(self):
5454
class CustomNonClusterableLayer(layers.Dense):
5555
pass
5656

57+
5758
class KerasCustomLayer(keras.layers.Layer):
59+
5860
def __init__(self, units=32):
5961
super(KerasCustomLayer, self).__init__()
6062
self.units = units
6163

6264
def build(self, input_shape):
6365
self.w = self.add_weight(
64-
shape=(input_shape[-1], self.units),
65-
initializer="random_normal",
66-
trainable=True,
66+
shape=(input_shape[-1], self.units),
67+
initializer='random_normal',
68+
trainable=True,
6769
)
6870
self.b = self.add_weight(
69-
shape=(self.units,),
70-
initializer="random_normal",
71-
trainable=False
72-
)
71+
shape=(self.units,), initializer='random_normal', trainable=False)
7372

7473
def call(self, inputs):
7574
return tf.matmul(inputs, self.w) + self.b
7675

76+
77+
class MyClusterableLayer(keras.layers.Dense,
78+
clusterable_layer.ClusterableLayer):
79+
80+
def get_clusterable_weights(self):
81+
# Cluster kernel and bias.
82+
return [('kernel', self.kernel), ('bias', self.bias)]
83+
84+
85+
class MyClusterableLayerInvalid(keras.layers.Dense,
86+
clusterable_layer.ClusterableLayer):
87+
"""This layer is invalid: it does not implement get_clusterable_weights(self).
88+
"""
89+
pass
90+
91+
92+
class TestCustomerableWeightsCA(clustering_registry.AbstractClusteringAlgorithm
93+
):
94+
"""Dummy class derived from AbstractClusteringAlgorithm."""
95+
96+
def get_pulling_indices(self, weight):
97+
return [1, 2, 3]
98+
99+
100+
class KerasCustomLayerClusterable(keras.layers.Layer,
101+
clusterable_layer.ClusterableLayer):
102+
"""Custom Keras clusterable layer, providing its own clustering algorithm."""
103+
104+
def __init__(self):
105+
super().__init__()
106+
self.kernel = None
107+
108+
def get_clusterable_weights(self):
109+
return [('kernel', self.kernel)]
110+
111+
def get_clusterable_algorithm(self, weight_name):
112+
return TestCustomerableWeightsCA
113+
114+
77115
class ClusterTest(test.TestCase, parameterized.TestCase):
78116
"""Unit tests for the cluster module."""
79117

@@ -87,6 +125,7 @@ def setUp(self):
87125
self.custom_non_clusterable_layer = CustomNonClusterableLayer(10)
88126
self.keras_depthwiseconv2d_layer = layers.DepthwiseConv2D((3, 3), (1, 1))
89127
self.keras_custom_layer = KerasCustomLayer()
128+
self.clusterable_layer = MyClusterableLayer(10)
90129

91130
clustering_registry.ClusteringLookupRegistry.register_new_implementation(
92131
{
@@ -204,24 +243,37 @@ def testClusterCustomNonClusterableLayer(self):
204243
**self.params)
205244

206245
def testClusterKerasCustomLayer(self):
207-
"""
208-
Verifies that attempting to cluster a keras custom layer raises
209-
an exception.
210-
"""
246+
"""Verifies that attempting to cluster a keras custom layer raises an exception."""
211247
# If layer is not built, it has no weights, so
212248
# we just skip it.
213249
keras_custom_layer = self.keras_custom_layer
214-
cluster_wrapper.ClusterWeights(keras_custom_layer,
215-
**self.params)
250+
cluster_wrapper.ClusterWeights(keras_custom_layer, **self.params)
216251
# We need to build weights before checking that clustering is not supported.
217252
keras_custom_layer.build(input_shape=(10, 10))
218253
with self.assertRaises(ValueError):
219-
cluster_wrapper.ClusterWeights(keras_custom_layer,
220-
**self.params)
254+
cluster_wrapper.ClusterWeights(keras_custom_layer, **self.params)
255+
256+
def testClusterMyClusterableLayer(self):
257+
# we have weights to cluster.
258+
layer = self.clusterable_layer
259+
layer.build(input_shape=(10, 10))
260+
261+
wrapped_layer = cluster_wrapper.ClusterWeights(layer, **self.params)
262+
self.assertIsInstance(wrapped_layer, cluster_wrapper.ClusterWeights)
263+
264+
def testKerasCustomLayerClusterable(self):
265+
"""Verifies that we can wrap a keras custom layer that is clusterable."""
266+
layer = KerasCustomLayerClusterable()
267+
wrapped_layer = cluster_wrapper.ClusterWeights(layer, **self.params)
268+
self.assertIsInstance(wrapped_layer, cluster_wrapper.ClusterWeights)
269+
270+
def testClusterMyClusterableLayerInvalid(self):
271+
"""Verifies that an exception is raised when get_clusterable_weights() is not implemented."""
272+
with self.assertRaises(TypeError):
273+
MyClusterableLayerInvalid(10) # pylint: disable=abstract-class-instantiated
221274

222275
@keras_parameterized.run_all_keras_modes
223276
def testClusterSequentialModelSelectively(self):
224-
"""Verifies that layers within a sequential model can be clustered selectively."""
225277
clustered_model = keras.Sequential()
226278
clustered_model.add(
227279
cluster.cluster_weights(self.keras_clusterable_layer, **self.params))

tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,10 @@ def build(self, input_shape):
178178
# variable is either in the self._trainable_weights or in
179179
# self._non_trainable_weights and self.weights is the result of
180180
# concatenation of those arrays
181-
original_index = self.layer.weights.index(weight)
181+
original_index = 0
182+
for i in range(len(self.layer.weights)):
183+
if self.layer.weights[i].name == weight.name:
184+
original_index = i
182185
self.gone_variables.append(original_index)
183186

184187
# Again, not sure if this is needed. Leaving for now.

tensorflow_model_optimization/python/core/clustering/keras/clusterable_layer.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,13 @@
2222
class ClusterableLayer:
2323
"""Abstract Base Class for making your own keras layer clusterable.
2424
25-
Custom keras layers that need to support clustering should implement this
26-
class.
25+
Your layer could be derived from a keras built-in layer or
26+
it could be a keras custom layer.
27+
28+
The function get_clusterable_weights should be provided in both cases.
29+
30+
The function get_clusterable_algorithm should be provided when weights for
31+
clustering are added in the keras layer.
2732
2833
"""
2934

@@ -40,3 +45,22 @@ def get_clusterable_weights(self):
4045
kernel object itself.
4146
"""
4247
raise NotImplementedError('Must be implemented in subclasses.')
48+
49+
def get_clusterable_algorithm(self, weight_name): # pylint: disable=unused-argument
50+
"""Returns class with the clustering algorithm for the given weight_name.
51+
52+
This function needs to be implemented for custom clusterable layers.
53+
If the layer is derived from the built-in keras layer, the clustering
54+
algorithm for the base built-in keras layer is used.
55+
56+
The returned class should be derived from AbstractClusteringAlgorithm and
57+
implement the function get_pulling_indices.
58+
This function is used to provide a special lookup function for the custom
59+
weights.
60+
It reshapes and tiles centroids the same way as the weights. This allows us
61+
to find pulling indices efficiently.
62+
63+
Args:
64+
weight_name ([string]): The name of the weight variable.
65+
"""
66+
return None
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Abstract base class for clustering algorithm."""
16+
17+
import abc
18+
import six
19+
import tensorflow as tf
20+
21+
22+
@six.add_metaclass(abc.ABCMeta)
23+
class AbstractClusteringAlgorithm(object):
24+
"""Abstract class to implement highly efficient vectorised look-ups.
25+
26+
We do not utilise looping for that purpose, instead we `smartly` reshape and
27+
tile arrays. The trade-off is that we are potentially using way more memory
28+
than we would have if looping is used.
29+
30+
Each class that inherits from this class is supposed to implement a
31+
particular lookup function for a certain shape.
32+
33+
For example, look-ups for 2D table will be different in the case of 3D.
34+
"""
35+
36+
def __init__(self, clusters_centroids):
37+
"""Generating clustered tensors.
38+
39+
For generating clustered tensors we will need two things: cluster
40+
centroids and the final shape the tensor must have.
41+
42+
Args:
43+
clusters_centroids: An array of shape (N,) that contains initial
44+
values of clusters centroids.
45+
"""
46+
self.cluster_centroids = clusters_centroids
47+
48+
@abc.abstractmethod
49+
def get_pulling_indices(self, weight):
50+
"""Returns indices of closest cluster centroids.
51+
52+
Takes a weight (can be 1D, 2D or ND) and creates a tf.int32 array of the
53+
same shape that will hold indices of cluster centroids clustered arrays
54+
elements will be pulled from.
55+
56+
In the current setup pulling indices are meant to be created once and
57+
used everywhere
58+
59+
Args:
60+
weight: ND array of weights. For each weight in this array the
61+
closest cluster centroid is found.
62+
63+
Returns:
64+
ND array of the same shape as `weight` parameter of the type
65+
tf.int32. The returned array contains weight lookup indices.
66+
"""
67+
pass
68+
69+
@tf.custom_gradient
70+
def add_custom_gradients(self, clst_weights, weights):
71+
"""Adds custom gradients in the backprop stage.
72+
73+
This function overrides gradients in the backprop stage: original mul
74+
becomes add, tf.sign becomes tf.identity. It is to update the original
75+
weights with the gradients updates directly from the layer wrapped. We
76+
assume the gradients updates on individual elements inside a cluster
77+
will be different so that there is no point of mapping the gradient
78+
updates back to original weight matrix using the LUT.
79+
80+
Args:
81+
clst_weights: cluster weights
82+
weights: weights
83+
Returns:
84+
custom gradient
85+
"""
86+
override_weights = tf.sign(tf.reshape(weights, shape=(-1,)) + 1e+6)
87+
z = clst_weights * override_weights
88+
89+
def grad(dz):
90+
return dz, dz
91+
92+
return z, grad
93+
94+
def get_clustered_weight(self, pulling_indices):
95+
"""Returns clustered weights.
96+
97+
Takes an array of integers that represent lookup indices and
98+
forms a new array according to the given indices.
99+
100+
Args:
101+
pulling_indices: an array of indices used for lookup.
102+
Returns:
103+
An array with the same shape as `pulling_indices`. Each array element
104+
is a member of self.cluster_centroids.
105+
"""
106+
107+
return tf.reshape(
108+
tf.gather(self.cluster_centroids,
109+
tf.reshape(pulling_indices, shape=(-1,))),
110+
shape=pulling_indices.shape)
111+
112+
def get_clustered_weight_forward(self, pulling_indices, weight):
113+
"""Returns clustered weights with custom gradient.
114+
115+
Takes indices (pulling_indices) and original weights (weight) as inputs
116+
117+
and then forms a new array according to the given indices. The original
118+
weights (weight) here are added to the graph since we want the backprop
119+
to update their values via the new implementation using tf.custom_gradient
120+
121+
Args:
122+
pulling_indices: an array of indices used for lookup.
123+
weight: the original weights of the wrapped layer.
124+
Returns:
125+
array with the same shape as `pulling_indices`. Each array element
126+
is a member of self.cluster_centroids
127+
"""
128+
129+
x = tf.reshape(self.get_clustered_weight(pulling_indices), shape=(-1,))
130+
131+
return tf.reshape(
132+
self.add_custom_gradients(x, tf.reshape(weight, shape=(-1,))),
133+
pulling_indices.shape)

0 commit comments

Comments
 (0)